@@ -1,25 +1,17 @@
-from dataclasses import dataclass
 from typing import Any, Optional
 
 import lightning as L
-import loralib as lora
 import torch
 import torch.nn.functional as F
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 
 import fish_speech.utils as utils
 from fish_speech.models.text2semantic.llama import NaiveTransformer
+from fish_speech.models.text2semantic.lora_utils import LoraConfig, setup_lora
 
 log = utils.RankedLogger(__name__, rank_zero_only=True)
 
 
-@dataclass
-class LoraConfig:
-    r: int
-    lora_alpha: float
-    lora_dropout: float = 0.0
-
-
 class TextToSemantic(L.LightningModule):
     def __init__(
         self,
@@ -27,7 +19,6 @@ class TextToSemantic(L.LightningModule):
         optimizer: Any,
         lr_scheduler: Any,
         lora_config: Optional[LoraConfig] = None,
-        save_lora_only: bool = False,
         use_dpo: bool = False,
         dpo_beta: float = 0.2,
     ):
@@ -37,70 +28,17 @@ class TextToSemantic(L.LightningModule):
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
         self.lora_config = lora_config
-        self.save_lora_only = save_lora_only
         self.use_dpo = use_dpo  # We don't support reference model yet
         self.dpo_beta = dpo_beta
 
         if self.lora_config is not None:
-            self.setup_lora()
-
-    def setup_lora(self):
-        # Replace the embedding layer with a LoRA layer
-        self.model.embeddings = lora.Embedding(
-            num_embeddings=self.model.embeddings.num_embeddings,
-            embedding_dim=self.model.embeddings.embedding_dim,
-            padding_idx=self.model.embeddings.padding_idx,
-            r=self.lora_config.r,
-            lora_alpha=self.lora_config.lora_alpha,
-        )
-
-        # Replace output layer with a LoRA layer
-        linears = [(self.model, "output")]
-
-        # Replace all linear layers with LoRA layers
-        for layer in self.model.layers:
-            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
-            linears.extend(
-                [
-                    (layer.feed_forward, "w1"),
-                    (layer.feed_forward, "w2"),
-                    (layer.feed_forward, "w3"),
-                ]
-            )
-
-        if hasattr(self.model, "fast_layers"):
-            # Dual-AR model
-            linears.extend([(self.model, "fast_output")])
-
-            for layer in self.model.fast_layers:
-                linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
-                linears.extend(
-                    [
-                        (layer.feed_forward, "w1"),
-                        (layer.feed_forward, "w2"),
-                        (layer.feed_forward, "w3"),
-                    ]
-                )
-
-        for module, layer in linears:
-            updated_linear = lora.Linear(
-                in_features=getattr(module, layer).in_features,
-                out_features=getattr(module, layer).out_features,
-                bias=getattr(module, layer).bias,
-                r=self.lora_config.r,
-                lora_alpha=self.lora_config.lora_alpha,
-                lora_dropout=self.lora_config.lora_dropout,
-            )
-            setattr(module, layer, updated_linear)
-
-        # Mark only the LoRA layers as trainable
-        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
+            setup_lora(self.model, self.lora_config)
 
     def forward(self, x):
        return self.model(x)
 
     def on_save_checkpoint(self, checkpoint):
-        if self.lora_config is None or self.save_lora_only is False:
+        if self.lora_config is None:
             return
 
         # Save only LoRA parameters
@@ -178,6 +116,11 @@ class TextToSemantic(L.LightningModule):
     def _step(self, batch, batch_idx, stage: str):
         is_train = stage == "train"
 
+        if is_train:
+            # Key part to make LoRA work
+            # Otherwise the parameters are merged, which leads to incorrect gradients
+            self.model.train()
+
         # Do positive and negative samples in the same batch to speed up training
         labels = batch["labels"]
         outputs = self.model(
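
Note: the definitions of `LoraConfig` and `setup_lora` now live in `fish_speech/models/text2semantic/lora_utils.py`, which this diff only imports. The sketch below is not the actual file; it assumes the removed `setup_lora()` logic was moved there largely unchanged, reparameterized as `setup_lora(model, lora_config)` to match the call site. The one deliberate deviation is passing `old.bias is not None` instead of the bias tensor itself, since loralib forwards that argument to `nn.Linear`, which expects a bool.

```python
# Hypothetical lora_utils.py, reconstructed from the removed method above.
from dataclasses import dataclass

import loralib as lora


@dataclass
class LoraConfig:
    r: int
    lora_alpha: float
    lora_dropout: float = 0.0


def setup_lora(model, lora_config: LoraConfig) -> None:
    # Replace the embedding layer with a LoRA embedding
    model.embeddings = lora.Embedding(
        num_embeddings=model.embeddings.num_embeddings,
        embedding_dim=model.embeddings.embedding_dim,
        padding_idx=model.embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Collect the output head and every attention / feed-forward linear
    linears = [(model, "output")]
    for layer in model.layers:
        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
        linears.extend(
            [
                (layer.feed_forward, "w1"),
                (layer.feed_forward, "w2"),
                (layer.feed_forward, "w3"),
            ]
        )

    if hasattr(model, "fast_layers"):
        # Dual-AR model: also wrap the fast output head and fast layers
        linears.append((model, "fast_output"))
        for layer in model.fast_layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

    # Swap each collected nn.Linear for a loralib Linear of the same shape
    for module, name in linears:
        old = getattr(module, name)
        setattr(
            module,
            name,
            lora.Linear(
                in_features=old.in_features,
                out_features=old.out_features,
                bias=old.bias is not None,  # forwarded to nn.Linear, which expects a bool
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
            ),
        )

    # Freeze everything except the LoRA matrices (and LoRA-only biases)
    lora.mark_only_lora_as_trainable(model, bias="lora_only")
```

This also explains the new `self.model.train()` call in `_step`: with loralib's default `merge_weights=True`, the LoRA update is merged into the base weight in eval mode and unmerged in train mode, so computing the training step on a merged model would produce incorrect gradients.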