Преглед изворни кода

fix: preserve pretrained weights in setup_lora (#1170)

setup_lora was called after load_state_dict, causing lora.Embedding and
lora.Linear constructors to overwrite loaded weights with random init.
Now weights are copied into the new LoRA layers after construction.
Also fixes bias parameter being passed as a tensor instead of a bool.
zhuxiaoxuhit пре 1 месец
родитељ
комит
cccb997abb
1 измењених фајлова са 23 додато и 24 уклоњено
  1. 23 24
      fish_speech/models/text2semantic/lora.py

+ 23 - 24
fish_speech/models/text2semantic/lora.py

@@ -10,22 +10,23 @@ class LoraConfig:
     lora_dropout: float = 0.0
 
 
-def setup_lora(model, lora_config):
-    # Replace the embedding layer with a LoRA layer
-    model.embeddings = lora.Embedding(
-        num_embeddings=model.embeddings.num_embeddings,
-        embedding_dim=model.embeddings.embedding_dim,
-        padding_idx=model.embeddings.padding_idx,
+def _replace_embedding(old_embed, lora_config):
+    new_embed = lora.Embedding(
+        num_embeddings=old_embed.num_embeddings,
+        embedding_dim=old_embed.embedding_dim,
+        padding_idx=old_embed.padding_idx,
         r=lora_config.r,
         lora_alpha=lora_config.lora_alpha,
     )
+    new_embed.weight.data.copy_(old_embed.weight.data)
+    return new_embed
 
-    model.codebook_embeddings = lora.Embedding(
-        num_embeddings=model.codebook_embeddings.num_embeddings,
-        embedding_dim=model.codebook_embeddings.embedding_dim,
-        padding_idx=model.codebook_embeddings.padding_idx,
-        r=lora_config.r,
-        lora_alpha=lora_config.lora_alpha,
+
+def setup_lora(model, lora_config):
+    # Replace the embedding layer with a LoRA layer, preserving pretrained weights
+    model.embeddings = _replace_embedding(model.embeddings, lora_config)
+    model.codebook_embeddings = _replace_embedding(
+        model.codebook_embeddings, lora_config
     )
 
     # Replace output layer with a LoRA layer
@@ -43,13 +44,7 @@ def setup_lora(model, lora_config):
         )
 
     if hasattr(model, "fast_layers"):
-        model.fast_embeddings = lora.Embedding(
-            num_embeddings=model.fast_embeddings.num_embeddings,
-            embedding_dim=model.fast_embeddings.embedding_dim,
-            padding_idx=model.fast_embeddings.padding_idx,
-            r=lora_config.r,
-            lora_alpha=lora_config.lora_alpha,
-        )
+        model.fast_embeddings = _replace_embedding(model.fast_embeddings, lora_config)
 
         # Dual-AR model
         linears.append((model, "fast_output"))
@@ -64,16 +59,20 @@ def setup_lora(model, lora_config):
                 ]
             )
 
-    for module, layer in linears:
+    for module, layer_name in linears:
+        old_linear = getattr(module, layer_name)
         updated_linear = lora.Linear(
-            in_features=getattr(module, layer).in_features,
-            out_features=getattr(module, layer).out_features,
-            bias=getattr(module, layer).bias,
+            in_features=old_linear.in_features,
+            out_features=old_linear.out_features,
+            bias=old_linear.bias is not None,
             r=lora_config.r,
             lora_alpha=lora_config.lora_alpha,
             lora_dropout=lora_config.lora_dropout,
         )
-        setattr(module, layer, updated_linear)
+        updated_linear.weight.data.copy_(old_linear.weight.data)
+        if old_linear.bias is not None:
+            updated_linear.bias.data.copy_(old_linear.bias.data)
+        setattr(module, layer_name, updated_linear)
 
     # Mark only the LoRA layers as trainable
     lora.mark_only_lora_as_trainable(model, bias="none")