# migrate_from_vits.py
import hydra
import torch
from loguru import logger
from omegaconf import DictConfig, OmegaConf

# Register the "eval" resolver so config files can embed ${eval:...} expressions.
# NOTE(review): this evaluates arbitrary Python from the config — only load
# trusted configuration files.
OmegaConf.register_new_resolver("eval", eval)
  7. @hydra.main(
  8. version_base="1.3",
  9. config_path="../../fish_speech/configs",
  10. config_name="hubert_vq.yaml",
  11. )
  12. def main(cfg: DictConfig):
  13. generator_ckpt = cfg.get(
  14. "generator_ckpt", "results/hubert-vq-pretrain/rcell/G_23000.pth"
  15. )
  16. discriminator_ckpt = cfg.get(
  17. "discriminator_ckpt", "results/hubert-vq-pretrain/rcell/D_23000.pth"
  18. )
  19. model = hydra.utils.instantiate(cfg.model)
  20. # Generator
  21. logger.info(f"Model loaded, restoring from {generator_ckpt}")
  22. generator_weights = torch.load(generator_ckpt, map_location="cpu")["model"]
  23. # Decoder
  24. generator_state = {
  25. k[4:]: v
  26. for k, v in generator_weights.items()
  27. if k.startswith("dec.") and not k.startswith("dec.cond.")
  28. }
  29. logger.info(f"Found {len(generator_state)} HiFiGAN weights, restoring...")
  30. r = model.generator.dec.load_state_dict(generator_state, strict=False)
  31. logger.info(f"Generator weights restored. {r}")
  32. # Posterior Encoder
  33. # encoder_state = {
  34. # k[6:]: v
  35. # for k, v in generator_weights.items()
  36. # if k.startswith("enc_q.") and not k.startswith("enc_q.proj.")
  37. # }
  38. # logger.info(f"Found {len(encoder_state)} posterior encoder weights, restoring...")
  39. # x = model.generator.enc_q.load_state_dict(encoder_state, strict=False)
  40. # logger.info(f"Posterior encoder weights restored. {x}")
  41. # Flow
  42. # flow_state = {
  43. # k[5:]: v for k, v in generator_weights.items() if k.startswith("flow.")
  44. # }
  45. # logger.info(f"Found {len(flow_state)} flow weights, restoring...")
  46. # model.flow.load_state_dict(flow_state, strict=True)
  47. # logger.info("Flow weights restored.")
  48. # Discriminator
  49. logger.info(f"Model loaded, restoring from {discriminator_ckpt}")
  50. discriminator_weights = torch.load(discriminator_ckpt, map_location="cpu")["model"]
  51. logger.info(
  52. f"Found {len(discriminator_weights)} discriminator weights, restoring..."
  53. )
  54. model.discriminator.load_state_dict(discriminator_weights, strict=True)
  55. logger.info("Discriminator weights restored.")
  56. torch.save(model.state_dict(), cfg.ckpt_path)
  57. logger.info("Done")
if __name__ == "__main__":
    # Hydra parses CLI overrides and injects cfg when main() is invoked.
    main()