# merge_lora.py — merge trained LoRA weights into a base fish-speech llama checkpoint.
  1. import shutil
  2. from copy import deepcopy
  3. from pathlib import Path
  4. import click
  5. import hydra
  6. import torch
  7. from hydra import compose, initialize
  8. from hydra.utils import instantiate
  9. from loguru import logger
  10. from fish_speech.models.text2semantic.llama import BaseTransformer
  11. from fish_speech.models.text2semantic.lora import get_merged_state_dict
  12. @click.command()
  13. @click.option("--lora-config", type=str, default="r_8_alpha_16")
  14. @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
  15. @click.option("--lora-weight", type=str, required=True)
  16. @click.option("--output", type=str, required=True)
  17. def merge(lora_config, base_weight, lora_weight, output):
  18. output = Path(output)
  19. logger.info(
  20. f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
  21. )
  22. with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
  23. cfg = compose(config_name=lora_config)
  24. lora_config = instantiate(cfg)
  25. logger.info(f"Loaded lora model with config {lora_config}")
  26. llama_model = BaseTransformer.from_pretrained(
  27. path=base_weight,
  28. load_weights=True,
  29. lora_config=lora_config,
  30. )
  31. logger.info(f"Loaded llama model")
  32. llama_state_dict = llama_model.state_dict()
  33. llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
  34. llama_state_dict_copy = deepcopy(llama_state_dict)
  35. lora_state_dict = torch.load(lora_weight, map_location="cpu", weights_only=False)
  36. if "state_dict" in llama_state_dict:
  37. llama_state_dict = llama_state_dict["state_dict"]
  38. if "state_dict" in lora_state_dict:
  39. lora_state_dict = lora_state_dict["state_dict"]
  40. # remove prefix model.
  41. if any(k.startswith("model.") for k in llama_state_dict.keys()):
  42. llama_state_dict = {
  43. k.replace("model.", ""): v
  44. for k, v in llama_state_dict.items()
  45. if k.startswith("model.")
  46. }
  47. if any(k.startswith("model.") for k in lora_state_dict.keys()):
  48. lora_state_dict = {
  49. k.replace("model.", ""): v
  50. for k, v in lora_state_dict.items()
  51. if k.startswith("model.")
  52. }
  53. logger.info(f"Found {len(llama_state_dict)} keys in llama model")
  54. logger.info(f"Found {len(lora_state_dict)} keys in lora model")
  55. merged_state_dict = llama_state_dict | lora_state_dict
  56. llama_model.load_state_dict(merged_state_dict, strict=True)
  57. logger.info(f"Merged model loaded")
  58. # Trigger eval mode to merge lora
  59. llama_model.eval()
  60. llama_model.save_pretrained(output, drop_lora=True)
  61. logger.info(f"Saved merged model to {output}, validating")
  62. new_state_dict = torch.load(output / "model.pth", map_location="cpu")
  63. original_keys = set(llama_state_dict_copy.keys())
  64. tolerance = 1e-5
  65. for key in original_keys:
  66. diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
  67. if diff_l1 > tolerance:
  68. logger.info(f"Significant difference found in key: {key}")
  69. break
  70. if diff_l1 <= tolerance:
  71. logger.warning(
  72. "Merged model seems identical to the original model. Further validation might be needed."
  73. )
  74. else:
  75. logger.info("Merged model is different from the original model, check passed")
  76. if __name__ == "__main__":
  77. merge()