| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- import shutil
- from copy import deepcopy
- from pathlib import Path
- import click
- import hydra
- import torch
- from hydra import compose, initialize
- from hydra.utils import instantiate
- from loguru import logger
- from fish_speech.models.text2semantic.llama import BaseTransformer
- from fish_speech.models.text2semantic.lora import get_merged_state_dict
def _strip_model_prefix(state_dict):
    """Strip a leading ``model.`` from state-dict keys, if any key has it.

    When at least one key starts with ``model.``, only the prefixed keys are
    kept (other entries are dropped) and the prefix is removed from the
    *front* of each key only; an inner ``model.`` substring is left intact.
    A dict with no prefixed keys is returned unchanged.
    NOTE(review): presumably the prefix comes from a training wrapper that
    stores the transformer as ``.model`` -- confirm against the trainer.
    """
    prefix = "model."
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {
        key.removeprefix(prefix): value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def _validate_merged_weights(model_path, original_state_dict):
    """Check that the saved merged weights differ from the original weights.

    Args:
        model_path: Path of the saved merged checkpoint.
        original_state_dict: Snapshot of the base (non-LoRA) weights taken
            before merging.

    Raises:
        SystemExit: with code 1 if the key sets differ, or if every tensor is
            identical to the original (i.e. the LoRA delta was not applied).
    """
    new_state_dict = torch.load(model_path, map_location="cpu")
    original_keys = set(original_state_dict)
    merged_keys = set(new_state_dict)

    # Explicit check instead of `assert` so it still runs under `python -O`.
    if original_keys != merged_keys:
        logger.error(
            "Keys should be same: "
            f"missing {sorted(original_keys - merged_keys)}, "
            f"unexpected {sorted(merged_keys - original_keys)}"
        )
        raise SystemExit(1)

    # Short-circuits on the first tensor whose L1 difference is non-zero.
    changed = any(
        (new_state_dict[key] - original_state_dict[key]).abs().sum().item() != 0
        for key in original_keys
    )
    if not changed:
        logger.error("Merged model is same as the original model")
        raise SystemExit(1)

    logger.info("Merged model is different from the original model, check passed")


@click.command()
@click.option("--lora-config", type=str, default="r_8_alpha_16")
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.2-sft")
@click.option("--lora-weight", type=str, required=True)
@click.option("--output", type=str, required=True)
def merge(lora_config, base_weight, lora_weight, output):
    """Merge LoRA weights into a base model and save the merged model."""
    output = Path(output)
    logger.info(
        f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
    )

    with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
        cfg = compose(config_name=lora_config)

    lora_config = instantiate(cfg)
    logger.info(f"Loaded lora model with config {lora_config}")

    llama_model = BaseTransformer.from_pretrained(
        path=base_weight,
        load_weights=True,
        lora_config=lora_config,
    )
    logger.info("Loaded llama model")

    # Base weights only: drop every entry whose key mentions "lora".
    llama_state_dict = {
        key: value
        for key, value in llama_model.state_dict().items()
        if "lora" not in key
    }
    # state_dict() tensors reference the model's parameters, which are
    # overwritten in place by load_state_dict()/eval() below, so snapshot an
    # independent copy for the post-save validation.
    llama_state_dict_copy = deepcopy(llama_state_dict)

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    lora_state_dict = torch.load(lora_weight, map_location="cpu")
    if "state_dict" in lora_state_dict:
        lora_state_dict = lora_state_dict["state_dict"]

    llama_state_dict = _strip_model_prefix(llama_state_dict)
    lora_state_dict = _strip_model_prefix(lora_state_dict)

    logger.info(f"Found {len(llama_state_dict)} keys in llama model")
    logger.info(f"Found {len(lora_state_dict)} keys in lora model")

    # Entries from the LoRA checkpoint win for any key present in both.
    merged_state_dict = llama_state_dict | lora_state_dict
    llama_model.load_state_dict(merged_state_dict, strict=True)
    logger.info("Merged model loaded")

    # Trigger eval mode to merge lora
    llama_model.eval()
    llama_model.save_pretrained(output, drop_lora=True)
    logger.info(f"Saved merged model to {output}, validating")

    _validate_merged_weights(output / "model.pth", llama_state_dict_copy)
if __name__ == "__main__":
    # Run the click command: parses sys.argv and exits with its status code.
    merge.main()
|