|
|
@@ -1,4 +1,4 @@
|
|
|
-from dataclasses import dataclass
|
|
|
+from dataclasses import dataclass, field
|
|
|
|
|
|
import loralib as lora
|
|
|
|
|
|
@@ -8,6 +8,13 @@ class LoraConfig:
|
|
|
r: int
|
|
|
lora_alpha: float
|
|
|
lora_dropout: float = 0.0
|
|
|
+ # Valid values: "attention", "mlp", "embeddings", "output",
|
|
|
+ # "fast_attention", "fast_mlp", "fast_embeddings", "fast_output"
|
|
|
+ # Unprefixed names target the slow transformer (and, for backwards compatibility, the fast one too).
|
|
|
+ # "fast_*" names target only the fast transformer.
|
|
|
+ target_modules: list = field(
|
|
|
+ default_factory=lambda: ["attention", "mlp", "embeddings", "output"]
|
|
|
+ )
|
|
|
|
|
|
|
|
|
def _replace_embedding(old_embed, lora_config):
|
|
|
@@ -23,34 +30,34 @@ def _replace_embedding(old_embed, lora_config):
|
|
|
|
|
|
|
|
|
def setup_lora(model, lora_config):
|
|
|
- # Replace the embedding layer with a LoRA layer, preserving pretrained weights
|
|
|
- model.embeddings = _replace_embedding(model.embeddings, lora_config)
|
|
|
- model.codebook_embeddings = _replace_embedding(
|
|
|
- model.codebook_embeddings, lora_config
|
|
|
- )
|
|
|
-
|
|
|
- # Replace output layer with a LoRA layer
|
|
|
- linears = [(model, "output")]
|
|
|
-
|
|
|
- # Replace all linear layers with LoRA layers
|
|
|
- for layer in model.layers:
|
|
|
- linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
|
|
|
- linears.extend(
|
|
|
- [
|
|
|
- (layer.feed_forward, "w1"),
|
|
|
- (layer.feed_forward, "w2"),
|
|
|
- (layer.feed_forward, "w3"),
|
|
|
- ]
|
|
|
+ targets = set(lora_config.target_modules)
|
|
|
+ linears = []
|
|
|
+
|
|
|
+ # Slow transformer: targeted by unprefixed names (e.g. "attention")
|
|
|
+ slow_attention = "attention" in targets
|
|
|
+ slow_mlp = "mlp" in targets
|
|
|
+ slow_embeddings = "embeddings" in targets
|
|
|
+ slow_output = "output" in targets
|
|
|
+
|
|
|
+ # Fast transformer: targeted by unprefixed names (backwards compatibility) OR "fast_*" names
|
|
|
+ fast_attention = slow_attention or "fast_attention" in targets
|
|
|
+ fast_mlp = slow_mlp or "fast_mlp" in targets
|
|
|
+ fast_embeddings = slow_embeddings or "fast_embeddings" in targets
|
|
|
+ fast_output = slow_output or "fast_output" in targets
|
|
|
+
|
|
|
+ if slow_embeddings:
|
|
|
+ model.embeddings = _replace_embedding(model.embeddings, lora_config)
|
|
|
+ model.codebook_embeddings = _replace_embedding(
|
|
|
+ model.codebook_embeddings, lora_config
|
|
|
)
|
|
|
|
|
|
- if hasattr(model, "fast_layers"):
|
|
|
- model.fast_embeddings = _replace_embedding(model.fast_embeddings, lora_config)
|
|
|
-
|
|
|
- # Dual-AR model
|
|
|
- linears.append((model, "fast_output"))
|
|
|
+ if slow_output and hasattr(model, "output"):
|
|
|
+ linears.append((model, "output"))
|
|
|
|
|
|
- for layer in model.fast_layers:
|
|
|
+ for layer in model.layers:
|
|
|
+ if slow_attention:
|
|
|
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
|
|
|
+ if slow_mlp:
|
|
|
linears.extend(
|
|
|
[
|
|
|
(layer.feed_forward, "w1"),
|
|
|
@@ -59,6 +66,26 @@ def setup_lora(model, lora_config):
|
|
|
]
|
|
|
)
|
|
|
|
|
|
+ if hasattr(model, "fast_layers"):
|
|
|
+ if fast_embeddings:
|
|
|
+ model.fast_embeddings = _replace_embedding(
|
|
|
+ model.fast_embeddings, lora_config
|
|
|
+ )
|
|
|
+ if fast_output:
|
|
|
+ linears.append((model, "fast_output"))
|
|
|
+
|
|
|
+ for layer in model.fast_layers:
|
|
|
+ if fast_attention:
|
|
|
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
|
|
|
+ if fast_mlp:
|
|
|
+ linears.extend(
|
|
|
+ [
|
|
|
+ (layer.feed_forward, "w1"),
|
|
|
+ (layer.feed_forward, "w2"),
|
|
|
+ (layer.feed_forward, "w3"),
|
|
|
+ ]
|
|
|
+ )
|
|
|
+
|
|
|
for module, layer_name in linears:
|
|
|
old_linear = getattr(module, layer_name)
|
|
|
updated_linear = lora.Linear(
|