# infer_vq.py

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from hydra import compose, initialize
from hydra.utils import instantiate
from lightning import LightningModule
from loguru import logger
from omegaconf import OmegaConf

from fish_speech.models.vqgan.utils import sequence_mask

# Register the `eval` resolver so configs can use ${eval:...} expressions
OmegaConf.register_new_resolver("eval", eval)
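

# Inference only: no gradients are needed, and CUDA autocast runs the
# forward pass in mixed precision.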
@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=True)
def main():
    with initialize(version_base="1.3", config_path="../fish_speech/configs"):
        cfg = compose(config_name="vqgan_single_2x")

    model: LightningModule = instantiate(cfg.model)
    state_dict = torch.load(
        "results/vqgan_single_2x/checkpoints/step_000160000.ckpt",
        map_location=model.device,
    )["state_dict"]
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    model.cuda()
    logger.info("Restored model from checkpoint")

    # Load audio
    audio = librosa.load(
        "data/StarRail/Chinese/停云/chapter2_1_tingyun_142.wav",
        sr=model.sampling_rate,
        mono=True,
    )[0]
    audios = torch.from_numpy(audio).to(model.device)[None, None, :]
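    # audios has shape (batch=1, channels=1, num_samples).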
    logger.info(
        f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
    )

    # VQ Encoder
    audio_lengths = torch.tensor(
        [audios.shape[2]], device=model.device, dtype=torch.long
    )
    features = gt_mels = model.mel_transform(audios, sample_rate=model.sampling_rate)
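
    # An optional downsampling stack shrinks the mel features in time before
    # quantization; feature_lengths accounts for both the STFT hop and the
    # downsampler's total stride.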
    if model.downsample is not None:
        features = model.downsample(features)

    mel_lengths = audio_lengths // model.hop_length
    feature_lengths = (
        audio_lengths
        / model.hop_length
        / (model.downsample.total_strides if model.downsample is not None else 1)
    ).long()
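
    # Binary masks over the valid (non-padded) frames, broadcastable over channels.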
    feature_masks = torch.unsqueeze(
        sequence_mask(feature_lengths, features.shape[2]), 1
    ).to(gt_mels.dtype)
    mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
        gt_mels.dtype
    )

    # The VQ features run at ~50 Hz; they are upsampled back to the true mel
    # frame rate before decoding.
    text_features = model.mel_encoder(features, feature_masks)
    _, indices, _ = model.vq_encoder(text_features, feature_masks)
    print(indices.shape)
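
    # indices holds the discrete codebook IDs produced by the vector-quantized
    # bottleneck.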

    # Restore: alternatively, load precomputed codes from disk instead of
    # encoding the audio above.
    # indices = np.load("codes_0.npy")
    # indices = torch.from_numpy(indices).to(model.device).long()
    # indices = indices.unsqueeze(1).unsqueeze(-1)
    # mel_lengths = indices.shape[2] * (
    #     model.downsample.total_strides if model.downsample is not None else 1
    # )
    # mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
    # mel_masks = torch.ones(
    #     (1, 1, mel_lengths), device=model.device, dtype=torch.float32
    # )
    # print(mel_lengths)

    # Reference speaker
    ref_audio = librosa.load(
        "data/StarRail/Chinese/符玄/chapter2_8_fuxuan_104.wav",
        sr=model.sampling_rate,
        mono=True,
    )[0]
    ref_audios = torch.from_numpy(ref_audio).to(model.device)[None, None, :]
    ref_audio_lengths = torch.tensor(
        [ref_audios.shape[2]], device=model.device, dtype=torch.long
    )
    ref_mels = model.mel_transform(ref_audios, sample_rate=model.sampling_rate)
    ref_mel_lengths = ref_audio_lengths // model.hop_length
    ref_mel_masks = torch.unsqueeze(
        sequence_mask(ref_mel_lengths, ref_mels.shape[2]), 1
    ).to(gt_mels.dtype)
    speaker_features = model.speaker_encoder(ref_mels, ref_mel_masks)
    # Alternatively, condition on the source utterance's own speaker:
    # speaker_features = model.speaker_encoder(gt_mels, mel_masks)
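
    # speaker_features is a global speaker embedding (passed as `g` to the
    # decoder). Because the content codes come from one speaker and the
    # embedding from another, the output should follow the reference
    # speaker's timbre (voice conversion).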
    text_features = model.vq_encoder.decode(indices)
    logger.info(
        f"VQ Encoded, indices: {indices.shape} equivalent to "
        + f"{indices.shape[1] / (mel_lengths[0] * model.hop_length / model.sampling_rate):.2f} Hz"
    )
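
    # Nearest-neighbor interpolation stretches the decoded code features from
    # the code rate back to the full mel frame count.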
    text_features = F.interpolate(
        text_features, size=int(mel_lengths[0]), mode="nearest"
    )

    # Decode mels conditioned on the speaker embedding, then vocode to a waveform
    decoded_mels = model.decoder(text_features, mel_masks, g=speaker_features)
    fake_audios = model.generator(decoded_mels)
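    # The generator returns a batched waveform; [0, 0] below picks the single
    # batch item and channel.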

    # Save audio
    fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
    sf.write("fake.wav", fake_audio, model.sampling_rate)


if __name__ == "__main__":
    main()