# lora_utils.py — utilities for applying LoRA adapters (via loralib) to a model.
  1. from dataclasses import dataclass
  2. import loralib as lora
@dataclass
class LoraConfig:
    """Hyperparameters for LoRA (Low-Rank Adaptation) fine-tuning."""

    # Rank of the low-rank update matrices.
    r: int
    # LoRA scaling factor, forwarded to loralib layers.
    lora_alpha: float
    # Dropout probability for LoRA inputs; used only for Linear layers in setup_lora.
    lora_dropout: float = 0.0
  8. def setup_lora(model, lora_config):
  9. # Replace the embedding layer with a LoRA layer
  10. model.embeddings = lora.Embedding(
  11. num_embeddings=model.embeddings.num_embeddings,
  12. embedding_dim=model.embeddings.embedding_dim,
  13. padding_idx=model.embeddings.padding_idx,
  14. r=lora_config.r,
  15. lora_alpha=lora_config.lora_alpha,
  16. )
  17. # Replace output layer with a LoRA layer
  18. linears = [(model, "output")]
  19. # Replace all linear layers with LoRA layers
  20. for layer in model.layers:
  21. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  22. linears.extend(
  23. [
  24. (layer.feed_forward, "w1"),
  25. (layer.feed_forward, "w2"),
  26. (layer.feed_forward, "w3"),
  27. ]
  28. )
  29. if hasattr(model, "fast_layers"):
  30. model.fast_embeddings = lora.Embedding(
  31. num_embeddings=model.fast_embeddings.num_embeddings,
  32. embedding_dim=model.fast_embeddings.embedding_dim,
  33. padding_idx=model.fast_embeddings.padding_idx,
  34. r=lora_config.r,
  35. lora_alpha=lora_config.lora_alpha,
  36. )
  37. # Dual-AR model
  38. linears.append((model, "fast_output"))
  39. for layer in model.fast_layers:
  40. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  41. linears.extend(
  42. [
  43. (layer.feed_forward, "w1"),
  44. (layer.feed_forward, "w2"),
  45. (layer.feed_forward, "w3"),
  46. ]
  47. )
  48. for module, layer in linears:
  49. updated_linear = lora.Linear(
  50. in_features=getattr(module, layer).in_features,
  51. out_features=getattr(module, layer).out_features,
  52. bias=getattr(module, layer).bias,
  53. r=lora_config.r,
  54. lora_alpha=lora_config.lora_alpha,
  55. lora_dropout=lora_config.lora_dropout,
  56. )
  57. setattr(module, layer, updated_linear)
  58. # Mark only the LoRA layers as trainable
  59. lora.mark_only_lora_as_trainable(model, bias="none")
  60. def get_merged_state_dict(model):
  61. # This line will merge the state dict of the model and the LoRA parameters
  62. model.eval()
  63. # Then we need to remove the LoRA parameters from the state dict
  64. state_dict = model.state_dict()
  65. for name in list(state_dict.keys()):
  66. if "lora" in name:
  67. state_dict.pop(name)
  68. return state_dict