merge_lora.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import click
  2. import hydra
  3. import torch
  4. from hydra import compose, initialize
  5. from hydra.utils import instantiate
  6. from loguru import logger
  7. from fish_speech.models.text2semantic.lora_utils import (
  8. get_merged_state_dict,
  9. setup_lora,
  10. )
  11. @click.command()
  12. @click.option("--llama-config", type=str, default="dual_ar_2_codebook_medium")
  13. @click.option("--lora-config", type=str, default="r_8_alpha_16")
  14. @click.option(
  15. "--llama-weight", type=str, default="checkpoints/text2semantic-sft-medium-v1-4k.pth"
  16. )
  17. @click.option("--lora-weight", type=str, required=True)
  18. @click.option("--output", type=str, required=True)
  19. def merge(llama_config, lora_config, llama_weight, lora_weight, output):
  20. logger.info(
  21. f"Merging {llama_weight} and {lora_weight} into {output} with configs {llama_config} and {lora_config}"
  22. )
  23. hydra.core.global_hydra.GlobalHydra.instance().clear()
  24. with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
  25. # The max_seq_len here doesn't matter.
  26. cfg = compose(config_name=llama_config, overrides=[f"config.max_seq_len=2048"])
  27. llama_model = instantiate(cfg)
  28. logger.info(f"Loaded llama model with config {llama_config}")
  29. hydra.core.global_hydra.GlobalHydra.instance().clear()
  30. with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
  31. cfg = compose(config_name=lora_config)
  32. lora_config = instantiate(cfg)
  33. logger.info(f"Loaded lora model with config {lora_config}")
  34. setup_lora(llama_model, lora_config)
  35. logger.info(f"Merged model setup complete")
  36. llama_state_dict = torch.load(llama_weight, map_location="cpu")
  37. lora_state_dict = torch.load(lora_weight, map_location="cpu")
  38. if "state_dict" in llama_state_dict:
  39. llama_state_dict = llama_state_dict["state_dict"]
  40. if "state_dict" in lora_state_dict:
  41. lora_state_dict = lora_state_dict["state_dict"]
  42. # remove prefix model.
  43. if any(k.startswith("model.") for k in llama_state_dict.keys()):
  44. llama_state_dict = {
  45. k.replace("model.", ""): v
  46. for k, v in llama_state_dict.items()
  47. if k.startswith("model.")
  48. }
  49. if any(k.startswith("model.") for k in lora_state_dict.keys()):
  50. lora_state_dict = {
  51. k.replace("model.", ""): v
  52. for k, v in lora_state_dict.items()
  53. if k.startswith("model.")
  54. }
  55. logger.info(f"Found {len(llama_state_dict)} keys in llama model")
  56. logger.info(f"Found {len(lora_state_dict)} keys in lora model")
  57. merged_state_dict = llama_state_dict | lora_state_dict
  58. llama_model.load_state_dict(merged_state_dict, strict=True)
  59. logger.info(f"Merged model loaded")
  60. state_dict = get_merged_state_dict(llama_model)
  61. torch.save(state_dict, output)
  62. logger.info(f"Merged model saved to {output}")
  63. if __name__ == "__main__":
  64. merge()