|
|
@@ -10,22 +10,23 @@ class LoraConfig:
|
|
|
lora_dropout: float = 0.0
|
|
|
|
|
|
|
|
|
-def setup_lora(model, lora_config):
|
|
|
- # Replace the embedding layer with a LoRA layer
|
|
|
- model.embeddings = lora.Embedding(
|
|
|
- num_embeddings=model.embeddings.num_embeddings,
|
|
|
- embedding_dim=model.embeddings.embedding_dim,
|
|
|
- padding_idx=model.embeddings.padding_idx,
|
|
|
+def _replace_embedding(old_embed, lora_config):
|
|
|
+ new_embed = lora.Embedding(
|
|
|
+ num_embeddings=old_embed.num_embeddings,
|
|
|
+ embedding_dim=old_embed.embedding_dim,
|
|
|
+ padding_idx=old_embed.padding_idx,
|
|
|
r=lora_config.r,
|
|
|
lora_alpha=lora_config.lora_alpha,
|
|
|
)
|
|
|
+ new_embed.weight.data.copy_(old_embed.weight.data)
|
|
|
+ return new_embed
|
|
|
|
|
|
- model.codebook_embeddings = lora.Embedding(
|
|
|
- num_embeddings=model.codebook_embeddings.num_embeddings,
|
|
|
- embedding_dim=model.codebook_embeddings.embedding_dim,
|
|
|
- padding_idx=model.codebook_embeddings.padding_idx,
|
|
|
- r=lora_config.r,
|
|
|
- lora_alpha=lora_config.lora_alpha,
|
|
|
+
|
|
|
+def setup_lora(model, lora_config):
|
|
|
+ # Replace the embedding layer with a LoRA layer, preserving pretrained weights
|
|
|
+ model.embeddings = _replace_embedding(model.embeddings, lora_config)
|
|
|
+ model.codebook_embeddings = _replace_embedding(
|
|
|
+ model.codebook_embeddings, lora_config
|
|
|
)
|
|
|
|
|
|
# Replace output layer with a LoRA layer
|
|
|
@@ -43,13 +44,7 @@ def setup_lora(model, lora_config):
|
|
|
)
|
|
|
|
|
|
if hasattr(model, "fast_layers"):
|
|
|
- model.fast_embeddings = lora.Embedding(
|
|
|
- num_embeddings=model.fast_embeddings.num_embeddings,
|
|
|
- embedding_dim=model.fast_embeddings.embedding_dim,
|
|
|
- padding_idx=model.fast_embeddings.padding_idx,
|
|
|
- r=lora_config.r,
|
|
|
- lora_alpha=lora_config.lora_alpha,
|
|
|
- )
|
|
|
+ model.fast_embeddings = _replace_embedding(model.fast_embeddings, lora_config)
|
|
|
|
|
|
# Dual-AR model
|
|
|
linears.append((model, "fast_output"))
|
|
|
@@ -64,16 +59,20 @@ def setup_lora(model, lora_config):
|
|
|
]
|
|
|
)
|
|
|
|
|
|
- for module, layer in linears:
|
|
|
+ for module, layer_name in linears:
|
|
|
+ old_linear = getattr(module, layer_name)
|
|
|
updated_linear = lora.Linear(
|
|
|
- in_features=getattr(module, layer).in_features,
|
|
|
- out_features=getattr(module, layer).out_features,
|
|
|
- bias=getattr(module, layer).bias,
|
|
|
+ in_features=old_linear.in_features,
|
|
|
+ out_features=old_linear.out_features,
|
|
|
+ bias=old_linear.bias is not None,
|
|
|
r=lora_config.r,
|
|
|
lora_alpha=lora_config.lora_alpha,
|
|
|
lora_dropout=lora_config.lora_dropout,
|
|
|
)
|
|
|
- setattr(module, layer, updated_linear)
|
|
|
+ updated_linear.weight.data.copy_(old_linear.weight.data)
|
|
|
+ if old_linear.bias is not None:
|
|
|
+ updated_linear.bias.data.copy_(old_linear.bias.data)
|
|
|
+ setattr(module, layer_name, updated_linear)
|
|
|
|
|
|
# Mark only the LoRA layers as trainable
|
|
|
lora.mark_only_lora_as_trainable(model, bias="none")
|