Forráskód Böngészése

feat(lora): add target_modules to LoraConfig for selective fine-tuning (#1231)

* feat(lora): add target_modules to LoraConfig for selective fine-tuning

Adds a target_modules field to LoraConfig that controls which parts of
the model receive LoRA adapters. Valid values are "attention", "mlp",
"embeddings", and "output". Defaults to all four for full backwards
compatibility.

This enables training LoRA only on a subset of the model, for example
attention and MLP layers only (skipping embeddings and output), which
reduces trainable parameters and avoids modifying the token embedding
space.

The output layer is now only targeted when it actually exists on the
model (hasattr guard), which also fixes the AttributeError crash when
tie_word_embeddings=True (reported in #1195, also addressed by #1210,
#1213, #1220).

Example config for attention+MLP only LoRA:
    target_modules: [attention, mlp]

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(lora): extend target_modules with fast_* prefixes for fast-only LoRA

Adds "fast_attention", "fast_mlp", "fast_embeddings", "fast_output" as
valid target_modules values, targeting only the fast transformer of the
dual-AR model and leaving the slow transformer fully frozen.

Unprefixed names ("attention", "mlp", etc.) continue to target both slow
and fast layers as before, preserving full backwards compatibility.

Also adds r_32_alpha_16_fast.yaml as a ready-to-use config for fast-only
LoRA (r=32, alpha=16), which in practice converges in ~500 steps on a
single-speaker dataset.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Matteo 2 hete
szülő
commit
cb19a65dfb

+ 9 - 0
fish_speech/configs/lora/r_32_alpha_16_fast.yaml

@@ -0,0 +1,9 @@
+_target_: fish_speech.models.text2semantic.lora.LoraConfig
+r: 32
+lora_alpha: 16
+lora_dropout: 0.1
+target_modules:
+  - fast_attention
+  - fast_mlp
+  - fast_embeddings
+  - fast_output

+ 52 - 25
fish_speech/models/text2semantic/lora.py

@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 import loralib as lora
 
@@ -8,6 +8,13 @@ class LoraConfig:
     r: int
     lora_alpha: float
     lora_dropout: float = 0.0
+    # Valid values: "attention", "mlp", "embeddings", "output",
+    #               "fast_attention", "fast_mlp", "fast_embeddings", "fast_output"
+    # Unprefixed names target both the slow and fast transformers (backwards compat).
+    # "fast_*" names target only the fast transformer.
+    target_modules: list = field(
+        default_factory=lambda: ["attention", "mlp", "embeddings", "output"]
+    )
 
 
 def _replace_embedding(old_embed, lora_config):
@@ -23,34 +30,34 @@ def _replace_embedding(old_embed, lora_config):
 
 
 def setup_lora(model, lora_config):
-    # Replace the embedding layer with a LoRA layer, preserving pretrained weights
-    model.embeddings = _replace_embedding(model.embeddings, lora_config)
-    model.codebook_embeddings = _replace_embedding(
-        model.codebook_embeddings, lora_config
-    )
-
-    # Replace output layer with a LoRA layer
-    linears = [(model, "output")]
-
-    # Replace all linear layers with LoRA layers
-    for layer in model.layers:
-        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
-        linears.extend(
-            [
-                (layer.feed_forward, "w1"),
-                (layer.feed_forward, "w2"),
-                (layer.feed_forward, "w3"),
-            ]
+    targets = set(lora_config.target_modules)
+    linears = []
+
+    # Slow transformer: targeted by unprefixed names (e.g. "attention")
+    slow_attention = "attention" in targets
+    slow_mlp = "mlp" in targets
+    slow_embeddings = "embeddings" in targets
+    slow_output = "output" in targets
+
+    # Fast transformer: targeted by unprefixed names (backwards compat) OR "fast_*"
+    fast_attention = slow_attention or "fast_attention" in targets
+    fast_mlp = slow_mlp or "fast_mlp" in targets
+    fast_embeddings = slow_embeddings or "fast_embeddings" in targets
+    fast_output = slow_output or "fast_output" in targets
+
+    if slow_embeddings:
+        model.embeddings = _replace_embedding(model.embeddings, lora_config)
+        model.codebook_embeddings = _replace_embedding(
+            model.codebook_embeddings, lora_config
         )
 
-    if hasattr(model, "fast_layers"):
-        model.fast_embeddings = _replace_embedding(model.fast_embeddings, lora_config)
-
-        # Dual-AR model
-        linears.append((model, "fast_output"))
+    if slow_output and hasattr(model, "output"):
+        linears.append((model, "output"))
 
-        for layer in model.fast_layers:
+    for layer in model.layers:
+        if slow_attention:
             linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+        if slow_mlp:
             linears.extend(
                 [
                     (layer.feed_forward, "w1"),
@@ -59,6 +66,26 @@ def setup_lora(model, lora_config):
                 ]
             )
 
+    if hasattr(model, "fast_layers"):
+        if fast_embeddings:
+            model.fast_embeddings = _replace_embedding(
+                model.fast_embeddings, lora_config
+            )
+        if fast_output:
+            linears.append((model, "fast_output"))
+
+        for layer in model.fast_layers:
+            if fast_attention:
+                linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
+            if fast_mlp:
+                linears.extend(
+                    [
+                        (layer.feed_forward, "w1"),
+                        (layer.feed_forward, "w2"),
+                        (layer.feed_forward, "w3"),
+                    ]
+                )
+
     for module, layer_name in linears:
         old_linear = getattr(module, layer_name)
         updated_linear = lora.Linear(