# lora.py — LoRA adapter setup utilities.
  1. from dataclasses import dataclass
  2. import loralib as lora
@dataclass
class LoraConfig:
    """Hyper-parameters controlling LoRA (Low-Rank Adaptation) fine-tuning."""

    # Rank of the low-rank update matrices A and B.
    r: int
    # Scaling factor applied to the LoRA update (effective scale is alpha / r in loralib).
    lora_alpha: float
    # Dropout probability on the LoRA branch; 0.0 disables dropout.
    lora_dropout: float = 0.0
  8. def setup_lora(model, lora_config):
  9. # Replace the embedding layer with a LoRA layer
  10. model.embeddings = lora.Embedding(
  11. num_embeddings=model.embeddings.num_embeddings,
  12. embedding_dim=model.embeddings.embedding_dim,
  13. padding_idx=model.embeddings.padding_idx,
  14. r=lora_config.r,
  15. lora_alpha=lora_config.lora_alpha,
  16. )
  17. model.codebook_embeddings = lora.Embedding(
  18. num_embeddings=model.codebook_embeddings.num_embeddings,
  19. embedding_dim=model.codebook_embeddings.embedding_dim,
  20. padding_idx=model.codebook_embeddings.padding_idx,
  21. r=lora_config.r,
  22. lora_alpha=lora_config.lora_alpha,
  23. )
  24. # Replace output layer with a LoRA layer
  25. linears = [(model, "output")]
  26. # Replace all linear layers with LoRA layers
  27. for layer in model.layers:
  28. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  29. linears.extend(
  30. [
  31. (layer.feed_forward, "w1"),
  32. (layer.feed_forward, "w2"),
  33. (layer.feed_forward, "w3"),
  34. ]
  35. )
  36. if hasattr(model, "fast_layers"):
  37. model.fast_embeddings = lora.Embedding(
  38. num_embeddings=model.fast_embeddings.num_embeddings,
  39. embedding_dim=model.fast_embeddings.embedding_dim,
  40. padding_idx=model.fast_embeddings.padding_idx,
  41. r=lora_config.r,
  42. lora_alpha=lora_config.lora_alpha,
  43. )
  44. # Dual-AR model
  45. linears.append((model, "fast_output"))
  46. for layer in model.fast_layers:
  47. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  48. linears.extend(
  49. [
  50. (layer.feed_forward, "w1"),
  51. (layer.feed_forward, "w2"),
  52. (layer.feed_forward, "w3"),
  53. ]
  54. )
  55. for module, layer in linears:
  56. updated_linear = lora.Linear(
  57. in_features=getattr(module, layer).in_features,
  58. out_features=getattr(module, layer).out_features,
  59. bias=getattr(module, layer).bias,
  60. r=lora_config.r,
  61. lora_alpha=lora_config.lora_alpha,
  62. lora_dropout=lora_config.lora_dropout,
  63. )
  64. setattr(module, layer, updated_linear)
  65. # Mark only the LoRA layers as trainable
  66. lora.mark_only_lora_as_trainable(model, bias="none")
  67. def get_merged_state_dict(model):
  68. # This line will merge the state dict of the model and the LoRA parameters
  69. model.eval()
  70. # Then we need to remove the LoRA parameters from the state dict
  71. state_dict = model.state_dict()
  72. for name in list(state_dict.keys()):
  73. if "lora" in name:
  74. state_dict.pop(name)
  75. return state_dict