Ver Fonte

Add NEFTune noise to embeddings and save only LoRA weights in checkpoints

Lengyue há 2 anos
pai
commit
7ac4d4b918

+ 10 - 0
fish_speech/models/text2semantic/lit_module.py

@@ -80,6 +80,16 @@ class TextToSemantic(L.LightningModule):
     def forward(self, x):
         return self.model(x)
 
+    def on_save_checkpoint(self, checkpoint):
+        if self.lora_config is None:
+            return
+
+        # Save the LoRA parameters
+        state_dict = checkpoint["state_dict"]
+        for name in list(state_dict.keys()):
+            if "lora" not in name:
+                state_dict.pop(name)
+
     def configure_optimizers(self) -> OptimizerLRScheduler:
         optimizer = self.optimizer_builder(self.parameters())
         lr_scheduler = self.lr_scheduler_builder(optimizer)

+ 12 - 1
fish_speech/models/text2semantic/llama.py

@@ -1,3 +1,4 @@
+import math
 from dataclasses import dataclass
 from typing import Optional
 
@@ -44,6 +45,9 @@ class ModelArgs:
     # Gradient checkpointing
     use_gradient_checkpointing: bool = True
 
+    # NEFT
+    neft_alpha: float = 0
+
     def __post_init__(self):
         if self.n_local_heads == -1:
             self.n_local_heads = self.n_head
@@ -156,7 +160,14 @@ class Transformer(nn.Module):
             vocab_embeds.append(emb)
 
         x = torch.stack(vocab_embeds, dim=3)
-        return x.sum(dim=3)
+        x = x.sum(dim=3)
+
+        if self.config.neft_alpha > 0:
+            # alpha / sqrt(L * D)
+            scaled_alpha = self.config.neft_alpha / math.sqrt(
+                self.config.dim * x.shape[2]
+            )
+            x += torch.rand_like(x) * scaled_alpha
 
     def compute(
         self,