# lora.py
  1. from dataclasses import dataclass, field
  2. import loralib as lora
@dataclass
class LoraConfig:
    """Hyper-parameters controlling how LoRA is injected into the model.

    Passed to ``setup_lora`` which swaps targeted layers for their
    ``loralib`` equivalents.
    """

    # LoRA rank of the low-rank update matrices.
    r: int
    # Scaling factor applied to the LoRA update (loralib's `lora_alpha`).
    lora_alpha: float
    # Dropout on the LoRA path; only used for Linear layers (loralib's
    # Embedding takes no dropout).
    lora_dropout: float = 0.0
    # Valid values: "attention", "mlp", "embeddings", "output",
    # "fast_attention", "fast_mlp", "fast_embeddings", "fast_output"
    # Unprefixed names target the slow transformer (and fast too for backwards compat).
    # "fast_*" names target only the fast transformer.
    target_modules: list[str] = field(
        default_factory=lambda: ["attention", "mlp", "embeddings", "output"]
    )
  15. def _replace_embedding(old_embed, lora_config):
  16. new_embed = lora.Embedding(
  17. num_embeddings=old_embed.num_embeddings,
  18. embedding_dim=old_embed.embedding_dim,
  19. padding_idx=old_embed.padding_idx,
  20. r=lora_config.r,
  21. lora_alpha=lora_config.lora_alpha,
  22. )
  23. new_embed.weight.data.copy_(old_embed.weight.data)
  24. return new_embed
  25. def setup_lora(model, lora_config):
  26. targets = set(lora_config.target_modules)
  27. linears = []
  28. # Slow transformer: targeted by unprefixed names (e.g. "attention")
  29. slow_attention = "attention" in targets
  30. slow_mlp = "mlp" in targets
  31. slow_embeddings = "embeddings" in targets
  32. slow_output = "output" in targets
  33. # Fast transformer: targeted by unprefixed names (backwards compat) OR "fast_*"
  34. fast_attention = slow_attention or "fast_attention" in targets
  35. fast_mlp = slow_mlp or "fast_mlp" in targets
  36. fast_embeddings = slow_embeddings or "fast_embeddings" in targets
  37. fast_output = slow_output or "fast_output" in targets
  38. if slow_embeddings:
  39. model.embeddings = _replace_embedding(model.embeddings, lora_config)
  40. model.codebook_embeddings = _replace_embedding(
  41. model.codebook_embeddings, lora_config
  42. )
  43. if slow_output and hasattr(model, "output"):
  44. linears.append((model, "output"))
  45. for layer in model.layers:
  46. if slow_attention:
  47. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  48. if slow_mlp:
  49. linears.extend(
  50. [
  51. (layer.feed_forward, "w1"),
  52. (layer.feed_forward, "w2"),
  53. (layer.feed_forward, "w3"),
  54. ]
  55. )
  56. if hasattr(model, "fast_layers"):
  57. if fast_embeddings:
  58. model.fast_embeddings = _replace_embedding(
  59. model.fast_embeddings, lora_config
  60. )
  61. if fast_output:
  62. linears.append((model, "fast_output"))
  63. for layer in model.fast_layers:
  64. if fast_attention:
  65. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  66. if fast_mlp:
  67. linears.extend(
  68. [
  69. (layer.feed_forward, "w1"),
  70. (layer.feed_forward, "w2"),
  71. (layer.feed_forward, "w3"),
  72. ]
  73. )
  74. for module, layer_name in linears:
  75. old_linear = getattr(module, layer_name)
  76. updated_linear = lora.Linear(
  77. in_features=old_linear.in_features,
  78. out_features=old_linear.out_features,
  79. bias=old_linear.bias is not None,
  80. r=lora_config.r,
  81. lora_alpha=lora_config.lora_alpha,
  82. lora_dropout=lora_config.lora_dropout,
  83. )
  84. updated_linear.weight.data.copy_(old_linear.weight.data)
  85. if old_linear.bias is not None:
  86. updated_linear.bias.data.copy_(old_linear.bias.data)
  87. setattr(module, layer_name, updated_linear)
  88. # Mark only the LoRA layers as trainable
  89. lora.mark_only_lora_as_trainable(model, bias="none")
  90. def get_merged_state_dict(model):
  91. # This line will merge the state dict of the model and the LoRA parameters
  92. model.eval()
  93. # Then we need to remove the LoRA parameters from the state dict
  94. state_dict = model.state_dict()
  95. for name in list(state_dict.keys()):
  96. if "lora" in name:
  97. state_dict.pop(name)
  98. return state_dict