# infer_vq.py — VQ-GAN inference: encode audio to VQ codes, then decode
# codes back to a waveform. (Header residue from a file-viewer paste removed.)
  1. import librosa
  2. import numpy as np
  3. import soundfile as sf
  4. import torch
  5. import torch.nn.functional as F
  6. from einops import rearrange
  7. from hydra import compose, initialize
  8. from hydra.utils import instantiate
  9. from lightning import LightningModule
  10. from loguru import logger
  11. from omegaconf import OmegaConf
  12. from fish_speech.models.vqgan.utils import sequence_mask
# Register the "eval" resolver so Hydra/OmegaConf configs may use
# ${eval:...} expressions (the repo's configs rely on this).
# NOTE(review): `eval` executes arbitrary Python — acceptable only because
# configs are trusted, repo-local files; never feed untrusted YAML here.
OmegaConf.register_new_resolver("eval", eval)
@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=True)
def main() -> None:
    """Run a round-trip VQ-GAN inference pass.

    Pipeline: load the model from a hard-coded checkpoint, mel-encode
    ``test.wav`` and print the VQ indices it produces, then decode the
    indices stored in ``codes_0.npy`` back to audio and write ``fake.wav``.

    Requires a CUDA device (``model.cuda()`` / autocast on "cuda") and the
    files ``checkpoints/vqgan/step_000380000.ckpt``, ``test.wav`` and
    ``codes_0.npy`` in the working directory.
    """
    # Build the model from the Hydra config tree shipped with the repo.
    with initialize(version_base="1.3", config_path="../fish_speech/configs"):
        cfg = compose(config_name="vqgan")

    model: LightningModule = instantiate(cfg.model)
    state_dict = torch.load(
        "checkpoints/vqgan/step_000380000.ckpt",
        map_location=model.device,
    )["state_dict"]
    # strict=True: fail loudly if the checkpoint and config disagree.
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    model.cuda()
    logger.info("Restored model from checkpoint")

    # Load audio as a mono float waveform at the model's sample rate,
    # shaped (batch=1, channels=1, samples).
    audio = librosa.load("test.wav", sr=model.sampling_rate, mono=True)[0]
    audios = torch.from_numpy(audio).to(model.device)[None, None, :]
    logger.info(
        f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
    )

    # --- VQ Encoder ---------------------------------------------------
    audio_lengths = torch.tensor(
        [audios.shape[2]], device=model.device, dtype=torch.long
    )
    # gt_mels keeps the full-resolution mels; features may be further
    # downsampled before entering the encoder.
    features = gt_mels = model.mel_transform(audios, sample_rate=model.sampling_rate)
    if model.downsample is not None:
        features = model.downsample(features)
    mel_lengths = audio_lengths // model.hop_length
    # Feature frames = mel frames divided by the downsampler's total stride
    # (1 when no downsampler is configured).
    feature_lengths = (
        audio_lengths
        / model.hop_length
        / (model.downsample.total_strides if model.downsample is not None else 1)
    ).long()
    # Masks shaped (batch, 1, time) in the mel dtype for masked encoding.
    feature_masks = torch.unsqueeze(
        sequence_mask(feature_lengths, features.shape[2]), 1
    ).to(gt_mels.dtype)
    mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
        gt_mels.dtype
    )

    # vq_features is 50 hz, need to convert to true mel size
    text_features = model.mel_encoder(features, feature_masks)
    _, indices, _ = model.vq_encoder(text_features, feature_masks)
    print(indices.shape)

    # --- Restore ------------------------------------------------------
    # NOTE(review): the indices computed above are only printed — they are
    # immediately discarded and replaced by codes loaded from disk. Looks
    # like a deliberate debugging/round-trip setup; confirm before reuse.
    indices = np.load("codes_0.npy")
    indices = torch.from_numpy(indices).to(model.device).long()
    # Reshape to what vq_encoder.decode expects — presumably
    # (batch, groups, time, 1); TODO confirm against the decode signature.
    indices = indices.unsqueeze(1).unsqueeze(-1)

    # Recompute mel length from the code count and the downsample stride.
    mel_lengths = indices.shape[2] * (
        model.downsample.total_strides if model.downsample is not None else 1
    )
    mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
    # All-ones mask: every decoded frame is valid (single utterance, no padding).
    # NOTE(review): mel_lengths is a 1-element tensor here, used directly
    # inside a size tuple — relies on torch's implicit int conversion; confirm.
    mel_masks = torch.ones(
        (1, 1, mel_lengths), device=model.device, dtype=torch.float32
    )
    print(mel_lengths)
    text_features = model.vq_encoder.decode(indices)

    logger.info(
        f"VQ Encoded, indices: {indices.shape} equivalent to "
        + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
    )

    # Stretch the (lower-rate) decoded features to the mel frame count.
    text_features = F.interpolate(text_features, size=mel_lengths[0], mode="nearest")

    # --- Sample mels --------------------------------------------------
    decoded_mels = model.decoder(text_features, mel_masks)
    fake_audios = model.generator(decoded_mels)

    # --- Save audio ---------------------------------------------------
    fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
    sf.write("fake.wav", fake_audio, model.sampling_rate)
# Script entry point — run inference only when executed directly.
if __name__ == "__main__":
    main()