- from dataclasses import dataclass
- import loralib as lora
@dataclass
class LoraConfig:
    """Hyper-parameters controlling LoRA fine-tuning.

    Attributes:
        r: Rank of the low-rank update matrices.
        lora_alpha: Scaling factor applied to the LoRA update.
        lora_dropout: Dropout probability on the LoRA branch; disabled by
            default.
    """

    # Rank of the low-rank decomposition (higher = more trainable capacity).
    r: int
    # Scaling factor for the LoRA contribution.
    lora_alpha: float
    # Dropout applied only to the LoRA path, never the frozen base weights.
    lora_dropout: float = 0.0
def setup_lora(model, lora_config):
    """Swap the model's embedding and linear layers for LoRA equivalents and
    freeze every non-LoRA parameter.

    Mutates ``model`` in place. Handles both the single-stage layout and
    dual-AR models that additionally expose ``fast_embeddings`` /
    ``fast_layers`` / ``fast_output``.

    Args:
        model: Transformer-style model exposing ``embeddings``, ``layers``
            (each with ``attention.wqkv``/``attention.wo`` and
            ``feed_forward.w1/w2/w3``), and an ``output`` head.
        lora_config: ``LoraConfig`` carrying ``r``, ``lora_alpha`` and
            ``lora_dropout``.
    """
    # Replace the embedding layer with a LoRA-augmented embedding.
    model.embeddings = lora.Embedding(
        num_embeddings=model.embeddings.num_embeddings,
        embedding_dim=model.embeddings.embedding_dim,
        padding_idx=model.embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Collect every (parent module, attribute name) pair whose linear layer
    # should be swapped for a LoRA linear: the output head plus all
    # attention and feed-forward projections.
    linears = [(model, "output")]
    for layer in model.layers:
        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
        linears.extend(
            [
                (layer.feed_forward, "w1"),
                (layer.feed_forward, "w2"),
                (layer.feed_forward, "w3"),
            ]
        )

    # Dual-AR model: the fast stack gets the same treatment.
    if hasattr(model, "fast_layers"):
        model.fast_embeddings = lora.Embedding(
            num_embeddings=model.fast_embeddings.num_embeddings,
            embedding_dim=model.fast_embeddings.embedding_dim,
            padding_idx=model.fast_embeddings.padding_idx,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
        )
        linears.append((model, "fast_output"))
        for layer in model.fast_layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

    for module, name in linears:
        # Look the layer up once instead of four separate getattr calls.
        linear = getattr(module, name)
        updated_linear = lora.Linear(
            in_features=linear.in_features,
            out_features=linear.out_features,
            # BUG FIX: nn.Linear (which lora.Linear extends) expects a bool
            # here. Passing the bias *tensor* raises a RuntimeError on
            # truthiness evaluation whenever the layer actually has a bias;
            # it only appeared to work because these layers use bias=None.
            bias=linear.bias is not None,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, name, updated_linear)

    # Freeze everything except the LoRA parameters.
    lora.mark_only_lora_as_trainable(model, bias="none")
def get_merged_state_dict(model):
    """Return the model's state dict with LoRA weights folded in and the
    standalone LoRA parameters removed.

    Switching to eval mode is what triggers loralib to merge the LoRA
    updates into the base weights; afterwards the separate ``lora*``
    entries are redundant and are stripped from the returned dict.

    Args:
        model: The LoRA-augmented model (see ``setup_lora``).

    Returns:
        The merged state dict, containing only non-LoRA entries.
    """
    # eval() makes loralib merge the LoRA deltas into the base weights.
    model.eval()
    state_dict = model.state_dict()

    # Snapshot the offending keys first so we never mutate while iterating.
    lora_keys = [key for key in state_dict if "lora" in key]
    for key in lora_keys:
        del state_dict[key]

    return state_dict
|