# lora.py — utilities for wiring loralib LoRA adapters into a model.
  1. from dataclasses import dataclass
  2. import loralib as lora
  3. @dataclass
  4. class LoraConfig:
  5. r: int
  6. lora_alpha: float
  7. lora_dropout: float = 0.0
  8. def _replace_embedding(old_embed, lora_config):
  9. new_embed = lora.Embedding(
  10. num_embeddings=old_embed.num_embeddings,
  11. embedding_dim=old_embed.embedding_dim,
  12. padding_idx=old_embed.padding_idx,
  13. r=lora_config.r,
  14. lora_alpha=lora_config.lora_alpha,
  15. )
  16. new_embed.weight.data.copy_(old_embed.weight.data)
  17. return new_embed
  18. def setup_lora(model, lora_config):
  19. # Replace the embedding layer with a LoRA layer, preserving pretrained weights
  20. model.embeddings = _replace_embedding(model.embeddings, lora_config)
  21. model.codebook_embeddings = _replace_embedding(
  22. model.codebook_embeddings, lora_config
  23. )
  24. # Replace output layer with a LoRA layer
  25. linears = [(model, "output")]
  26. # Replace all linear layers with LoRA layers
  27. for layer in model.layers:
  28. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  29. linears.extend(
  30. [
  31. (layer.feed_forward, "w1"),
  32. (layer.feed_forward, "w2"),
  33. (layer.feed_forward, "w3"),
  34. ]
  35. )
  36. if hasattr(model, "fast_layers"):
  37. model.fast_embeddings = _replace_embedding(model.fast_embeddings, lora_config)
  38. # Dual-AR model
  39. linears.append((model, "fast_output"))
  40. for layer in model.fast_layers:
  41. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  42. linears.extend(
  43. [
  44. (layer.feed_forward, "w1"),
  45. (layer.feed_forward, "w2"),
  46. (layer.feed_forward, "w3"),
  47. ]
  48. )
  49. for module, layer_name in linears:
  50. old_linear = getattr(module, layer_name)
  51. updated_linear = lora.Linear(
  52. in_features=old_linear.in_features,
  53. out_features=old_linear.out_features,
  54. bias=old_linear.bias is not None,
  55. r=lora_config.r,
  56. lora_alpha=lora_config.lora_alpha,
  57. lora_dropout=lora_config.lora_dropout,
  58. )
  59. updated_linear.weight.data.copy_(old_linear.weight.data)
  60. if old_linear.bias is not None:
  61. updated_linear.bias.data.copy_(old_linear.bias.data)
  62. setattr(module, layer_name, updated_linear)
  63. # Mark only the LoRA layers as trainable
  64. lora.mark_only_lora_as_trainable(model, bias="none")
  65. def get_merged_state_dict(model):
  66. # This line will merge the state dict of the model and the LoRA parameters
  67. model.eval()
  68. # Then we need to remove the LoRA parameters from the state dict
  69. state_dict = model.state_dict()
  70. for name in list(state_dict.keys()):
  71. if "lora" in name:
  72. state_dict.pop(name)
  73. return state_dict