Lengyue 2 years ago
Parent
Commit
cdecc2abbc

+ 89 - 0
fish_speech/configs/text2semantic_finetune_lora.yaml

@@ -0,0 +1,89 @@
+defaults:
+  - base
+  - _self_
+
+project: text2semantic_400m_finetune_lora
+max_length: 4096
+ckpt_path: checkpoints/text2semantic-400m-v0.3-4k.pth
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 2
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  max_steps: 1000
+  precision: bf16-true
+  limit_val_batches: 10
+  log_every_n_steps: 10
+
+# Tokenizer Configuration
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: fishaudio/speech-lm-v1
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+val_dataset:
+  _target_: fish_speech.datasets.text.AutoAugTextDataset
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 8
+  tokenizer: ${tokenizer}
+  max_length: ${max_length}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.TextToSemantic
+
+  model:
+    _target_: fish_speech.models.text2semantic.llama.Transformer
+    config:
+      _target_: fish_speech.models.text2semantic.llama.ModelArgs
+      max_seq_len: 4096
+      vocab_size: 36408
+      n_layer: 24
+      n_head: 16
+      dim: 1024
+      rope_base: 10000
+      norm_eps: 1e-5
+      num_codebooks: 4  # four parallel codebooks
+      codebook_size: 168  # entries per codebook, including special tokens
+
+  lora_config:
+    _target_: fish_speech.models.text2semantic.lit_module.LoraConfig
+    r: 8
+    lora_alpha: 16
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 3e-4
+    weight_decay: 0.1
+    betas: [0.9, 0.95]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 100
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.1
+
+# Callbacks
+callbacks:
+  model_checkpoint:
+    every_n_train_steps: 200
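
For reference, the lr_scheduler block above wires a cosine-with-warmup lambda into torch.optim.lr_scheduler.LambdaLR: linear warmup over num_warmup_steps, then cosine decay toward final_lr_ratio of the base LR. The actual implementation lives in fish_speech.scheduler; the standalone sketch below is an assumption based only on the keyword arguments shown in the config.

import math

def cosine_schedule_with_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    final_lr_ratio: float = 0.0,
) -> float:
    # Warmup: scale the base LR linearly from 0 up to 1.
    if current_step < num_warmup_steps:
        return current_step / max(1, num_warmup_steps)

    # Decay: follow a half cosine from 1 down to final_lr_ratio.
    progress = (current_step - num_warmup_steps) / max(
        1, num_training_steps - num_warmup_steps
    )
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    return final_lr_ratio + (1.0 - final_lr_ratio) * cosine

Because the config marks the lambda with _partial_: true, Hydra binds the keyword arguments and LambdaLR calls the resulting partial with the current step.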

+ 60 - 2
fish_speech/models/text2semantic/lit_module.py

@@ -1,23 +1,81 @@
 import platform
-from typing import Any, Optional
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
 
 import lightning as L
+import loralib as lora
 import torch
 import torch.nn.functional as F
 from lightning.pytorch.utilities.types import OptimizerLRScheduler
 
 import fish_speech.utils as utils
+from fish_speech.models.text2semantic.llama import Transformer
 
 log = utils.RankedLogger(__name__, rank_zero_only=True)
 
 
+@dataclass
+class LoraConfig:
+    r: int
+    lora_alpha: float
+    lora_dropout: float = 0.0
+
+
 class TextToSemantic(L.LightningModule):
-    def __init__(self, model, optimizer: Any, lr_scheduler: Any):
+    def __init__(
+        self,
+        model: Transformer,
+        optimizer: Any,
+        lr_scheduler: Any,
+        lora_config: Optional[LoraConfig] = None,
+    ):
         super().__init__()
 
         self.model = model
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
+        self.lora_config = lora_config
+
+        if self.lora_config is not None:
+            self.setup_lora()
+
+    def setup_lora(self):
+        # Replace the embedding layer with a LoRA layer
+        self.model.embeddings = lora.Embedding(
+            num_embeddings=self.model.embeddings.num_embeddings,
+            embedding_dim=self.model.embeddings.embedding_dim,
+            padding_idx=self.model.embeddings.padding_idx,
+            r=self.lora_config.r,
+            lora_alpha=self.lora_config.lora_alpha,
+        )
+
+        # Collect the output projection and every attention / feed-forward
+        # linear layer, then swap each one for a LoRA layer below
+        linears = [(self.model, "output")]
+
+        for layer in self.model.layers:
+            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+            linears.extend(
+                [
+                    (layer.feed_forward, "w1"),
+                    (layer.feed_forward, "w2"),
+                    (layer.feed_forward, "w3"),
+                ]
+            )
+
+        for module, layer in linears:
+            updated_linear = lora.Linear(
+                in_features=getattr(module, layer).in_features,
+                out_features=getattr(module, layer).out_features,
+                bias=getattr(module, layer).bias is not None,
+                r=self.lora_config.r,
+                lora_alpha=self.lora_config.lora_alpha,
+                lora_dropout=self.lora_config.lora_dropout,
+            )
+            setattr(module, layer, updated_linear)
+
+        # Mark only the LoRA layers as trainable
+        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
 
     def forward(self, x):
         return self.model(x)
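
Since mark_only_lora_as_trainable(self.model, bias="lora_only") freezes everything except the LoRA matrices (and the biases tied to them), a checkpoint only needs to carry the adapter weights. Below is a minimal sketch using loralib's lora_state_dict helper; here model stands for the wrapped Transformer, and the save path and loading details are illustrative assumptions, not part of this commit:

import loralib as lora
import torch

# Persist only the LoRA parameters after finetuning; bias="lora_only"
# mirrors the mark_only_lora_as_trainable call in setup_lora.
torch.save(lora.lora_state_dict(model, bias="lora_only"), "lora_adapter.pth")

# To restore, load the frozen base weights first, then the adapter on top.
# strict=False because the adapter state dict covers only a subset of keys.
model.load_state_dict(torch.load("lora_adapter.pth"), strict=False)

How the base weights from ckpt_path are restored depends on how that checkpoint is packaged; with resume_weights_only: true, Lightning handles it before the adapter is applied.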