# inference.py — VQ-GAN audio reconstruction CLI (fish-speech)
  1. from pathlib import Path
  2. import click
  3. import librosa
  4. import numpy as np
  5. import soundfile as sf
  6. import torch
  7. import torch.nn.functional as F
  8. from einops import rearrange
  9. from hydra import compose, initialize
  10. from hydra.utils import instantiate
  11. from lightning import LightningModule
  12. from loguru import logger
  13. from omegaconf import OmegaConf
  14. from fish_speech.models.vqgan.utils import sequence_mask
  15. # register eval resolver
  16. OmegaConf.register_new_resolver("eval", eval)
  17. @torch.no_grad()
  18. @torch.autocast(device_type="cuda", enabled=True)
  19. @click.command()
  20. @click.option(
  21. "--input-path",
  22. "-i",
  23. default="data/Genshin/Chinese/派蒙/vo_WYLQ103_10_paimon_04.wav",
  24. type=click.Path(exists=True, path_type=Path),
  25. )
  26. @click.option(
  27. "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
  28. )
  29. @click.option("--config-name", "-cfg", default="vqgan_pretrain")
  30. @click.option(
  31. "--checkpoint-path", "-ckpt", default="checkpoints/vqgan/step_000380000_wo.ckpt"
  32. )
  33. def main(input_path, output_path, config_name, checkpoint_path):
  34. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  35. cfg = compose(config_name=config_name)
  36. model: LightningModule = instantiate(cfg.model)
  37. state_dict = torch.load(
  38. checkpoint_path,
  39. map_location=model.device,
  40. )
  41. if "state_dict" in state_dict:
  42. state_dict = state_dict["state_dict"]
  43. model.load_state_dict(state_dict, strict=True)
  44. model.eval()
  45. model.cuda()
  46. logger.info("Restored model from checkpoint")
  47. if input_path.suffix == ".wav":
  48. logger.info(f"Processing in-place reconstruction of {input_path}")
  49. # Load audio
  50. audio, _ = librosa.load(
  51. input_path,
  52. sr=model.sampling_rate,
  53. mono=True,
  54. )
  55. audios = torch.from_numpy(audio).to(model.device)[None, None, :]
  56. logger.info(
  57. f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
  58. )
  59. # VQ Encoder
  60. audio_lengths = torch.tensor(
  61. [audios.shape[2]], device=model.device, dtype=torch.long
  62. )
  63. features = gt_mels = model.mel_transform(
  64. audios, sample_rate=model.sampling_rate
  65. )
  66. if model.downsample is not None:
  67. features = model.downsample(features)
  68. mel_lengths = audio_lengths // model.hop_length
  69. feature_lengths = (
  70. audio_lengths
  71. / model.hop_length
  72. / (model.downsample.total_strides if model.downsample is not None else 1)
  73. ).long()
  74. feature_masks = torch.unsqueeze(
  75. sequence_mask(feature_lengths, features.shape[2]), 1
  76. ).to(gt_mels.dtype)
  77. mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
  78. gt_mels.dtype
  79. )
  80. # vq_features is 50 hz, need to convert to true mel size
  81. text_features = model.mel_encoder(features, feature_masks)
  82. _, indices, _ = model.vq_encoder(text_features, feature_masks)
  83. if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1:
  84. indices = indices[:, 0, :, 0]
  85. else:
  86. logger.error(f"Unknown indices shape: {indices.shape}")
  87. return
  88. logger.info(f"Generated indices of shape {indices.shape}")
  89. # Save indices
  90. np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
  91. elif input_path.suffix == ".npy":
  92. logger.info(f"Processing precomputed indices from {input_path}")
  93. indices = np.load(input_path)
  94. indices = torch.from_numpy(indices).to(model.device).long()
  95. assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
  96. else:
  97. raise ValueError(f"Unknown input type: {input_path}")
  98. # Restore
  99. indices = indices.unsqueeze(1).unsqueeze(-1)
  100. mel_lengths = indices.shape[2] * (
  101. model.downsample.total_strides if model.downsample is not None else 1
  102. )
  103. mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
  104. mel_masks = torch.ones(
  105. (1, 1, mel_lengths), device=model.device, dtype=torch.float32
  106. )
  107. text_features = model.vq_encoder.decode(indices)
  108. logger.info(
  109. f"VQ Encoded, indices: {indices.shape} equivalent to "
  110. + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
  111. )
  112. text_features = F.interpolate(text_features, size=mel_lengths[0], mode="nearest")
  113. # Sample mels
  114. decoded_mels = model.decoder(text_features, mel_masks)
  115. fake_audios = model.generator(decoded_mels)
  116. logger.info(
  117. f"Generated audio of shape {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
  118. )
  119. # Save audio
  120. fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
  121. sf.write("fake.wav", fake_audio, model.sampling_rate)
  122. logger.info(f"Saved audio to {output_path}")
  123. if __name__ == "__main__":
  124. main()