merge_lora.py

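"""
Merge trained LoRA adapter weights into a base text2semantic checkpoint and
save the result as a single, plain state dict.

Illustrative invocation (the checkpoint paths are examples, not taken from
this file):

    python merge_lora.py \
        --lora-weight results/lora/checkpoints/step_000010000.ckpt \
        --output checkpoints/merged.ckpt
"""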
import click
import hydra
import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger

from fish_speech.models.text2semantic.lora_utils import (
    get_merged_state_dict,
    setup_lora,
)


@click.command()
@click.option("--llama-config", type=str, default="dual_ar_2_codebook_medium")
@click.option("--lora-config", type=str, default="r_8_alpha_16")
@click.option(
    "--llama-weight", type=str, default="checkpoints/text2semantic-sft-medium-v1-4k.pth"
)
@click.option("--lora-weight", type=str, required=True)
@click.option("--output", type=str, required=True)
def merge(llama_config, lora_config, llama_weight, lora_weight, output):
    logger.info(
        f"Merging {llama_weight} and {lora_weight} into {output} "
        f"with configs {llama_config} and {lora_config}"
    )
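
    # Hydra keeps process-global state; clear it so initialize() can be called
    # more than once in this process, then compose() the requested config and
    # let instantiate() build the object named by its _target_ field.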
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
        # The max_seq_len here doesn't matter.
        cfg = compose(config_name=llama_config, overrides=["config.max_seq_len=2048"])

    llama_model = instantiate(cfg)
    logger.info(f"Loaded llama model with config {llama_config}")

    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
        cfg = compose(config_name=lora_config)

    lora_config = instantiate(cfg)
    logger.info(f"Loaded lora config {lora_config}")

    setup_lora(llama_model, lora_config)
    logger.info("LoRA setup complete")

    llama_state_dict = torch.load(llama_weight, map_location="cpu")
    lora_state_dict = torch.load(lora_weight, map_location="cpu")
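
    # A raw torch.save() dump is already a flat state dict, while trainer
    # checkpoints (Lightning-style) nest the weights under a "state_dict" key;
    # handle both layouts.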
    if "state_dict" in llama_state_dict:
        llama_state_dict = llama_state_dict["state_dict"]

    if "state_dict" in lora_state_dict:
        lora_state_dict = lora_state_dict["state_dict"]

    # Keep only parameters under the training wrapper's "model." namespace and
    # strip that prefix so the keys match the bare model. removeprefix() only
    # touches the leading occurrence, unlike replace().
    llama_state_dict = {
        k.removeprefix("model."): v
        for k, v in llama_state_dict.items()
        if k.startswith("model.")
    }
    lora_state_dict = {
        k.removeprefix("model."): v
        for k, v in lora_state_dict.items()
        if k.startswith("model.")
    }

    logger.info(f"Found {len(llama_state_dict)} keys in llama model")
    logger.info(f"Found {len(lora_state_dict)} keys in lora model")

    # Dict union (Python 3.9+): the right-hand operand wins on duplicate keys,
    # so tensors from the LoRA checkpoint take precedence over the base ones.
    merged_state_dict = llama_state_dict | lora_state_dict
    llama_model.load_state_dict(merged_state_dict, strict=True)
    logger.info("Merged model loaded")
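
    # get_merged_state_dict() is expected to fold each adapter back into its
    # base layer -- standard LoRA folding is W' = W + (alpha / r) * (B @ A),
    # i.e. rank 8 / alpha 16 under the default r_8_alpha_16 config -- and
    # return a plain state dict free of LoRA-specific keys. (A sketch of the
    # semantics; the implementation lives in
    # fish_speech.models.text2semantic.lora_utils.)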
    state_dict = get_merged_state_dict(llama_model)
    torch.save(state_dict, output)
    logger.info(f"Merged model saved to {output}")


if __name__ == "__main__":
    merge()
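
# Illustrative sanity check of the output (file name is an example):
#   >>> sd = torch.load("checkpoints/merged.ckpt", map_location="cpu")
#   >>> any("lora" in k for k in sd)   # expected False after a clean merge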