| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- import click
- import hydra
- import torch
- from hydra import compose, initialize
- from hydra.utils import instantiate
- from loguru import logger
- from fish_speech.models.text2semantic.lora_utils import (
- get_merged_state_dict,
- setup_lora,
- )
def _unwrap_state_dict(checkpoint):
    """Return ``checkpoint["state_dict"]`` if that key exists, else ``checkpoint``."""
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


def _strip_model_prefix(state_dict):
    """Remove a leading ``model.`` from keys, keeping only keys that carried it.

    If no key starts with ``model.`` the mapping is returned unchanged.
    """
    prefix = "model."
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict

    # removeprefix only touches the leading occurrence. str.replace would also
    # rewrite "model." appearing later in a key (e.g. "model.submodel.weight"
    # would collapse to "subweight") and break the strict load below.
    return {
        key.removeprefix(prefix): value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


@click.command()
@click.option("--llama-config", type=str, default="dual_ar_2_codebook_medium")
@click.option("--lora-config", type=str, default="r_8_alpha_16")
@click.option(
    "--llama-weight", type=str, default="checkpoints/text2semantic-sft-medium-v1-4k.pth"
)
@click.option("--lora-weight", type=str, required=True)
@click.option("--output", type=str, required=True)
def merge(llama_config, lora_config, llama_weight, lora_weight, output):
    """Merge a LoRA checkpoint into a base llama checkpoint and save the result.
    \f
    The form feed above stops click from printing the rest of this docstring
    in ``--help`` output.

    Args:
        llama_config: Hydra config name under ``fish_speech/configs/model``.
        lora_config: Hydra config name under ``fish_speech/configs/lora``.
        llama_weight: Path of the base checkpoint, read with ``torch.load``.
        lora_weight: Path of the LoRA checkpoint, read with ``torch.load``.
        output: Path the merged state dict is written to with ``torch.save``.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the combined
            keys do not match the LoRA-enabled model exactly.
    """
    logger.info(
        f"Merging {llama_weight} and {lora_weight} into {output} with configs {llama_config} and {lora_config}"
    )

    # Hydra keeps global state; clear it before each initialize() so the two
    # config roots can be loaded one after the other in the same process.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
        # The max_seq_len here doesn't matter.
        model_cfg = compose(
            config_name=llama_config, overrides=["config.max_seq_len=2048"]
        )

    llama_model = instantiate(model_cfg)
    logger.info(f"Loaded llama model with config {llama_config}")

    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
        lora_cfg = compose(config_name=lora_config)

    # Kept in its own name so the CLI argument is not shadowed; the log line
    # below reports the instantiated object, as it always has.
    lora_settings = instantiate(lora_cfg)
    logger.info(f"Loaded lora model with config {lora_settings}")

    # NOTE(review): presumably injects LoRA parameters into llama_model so the
    # combined key set below loads strictly -- confirm against lora_utils.
    setup_lora(llama_model, lora_settings)
    logger.info("Merged model setup complete")

    # SECURITY: torch.load unpickles the file; only use checkpoints from a
    # trusted source.
    llama_state_dict = _unwrap_state_dict(torch.load(llama_weight, map_location="cpu"))
    lora_state_dict = _unwrap_state_dict(torch.load(lora_weight, map_location="cpu"))

    # remove prefix model.
    llama_state_dict = _strip_model_prefix(llama_state_dict)
    lora_state_dict = _strip_model_prefix(lora_state_dict)

    logger.info(f"Found {len(llama_state_dict)} keys in llama model")
    logger.info(f"Found {len(lora_state_dict)} keys in lora model")

    # On a key collision the LoRA checkpoint's tensor wins.
    merged_state_dict = llama_state_dict | lora_state_dict
    llama_model.load_state_dict(merged_state_dict, strict=True)
    logger.info("Merged model loaded")

    state_dict = get_merged_state_dict(llama_model)
    torch.save(state_dict, output)
    logger.info(f"Merged model saved to {output}")
if __name__ == "__main__":
    # Run the click command only when executed as a script; click reads the
    # options from the command line.
    merge()
|