# vq_encoder.py

import math

import torch
from torch import nn

from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize
from fish_speech.models.vqgan.modules.wavenet import WaveNet
from fish_speech.models.vqgan.utils import sequence_mask
from fish_speech.utils.spectrogram import LogMelSpectrogram

class VQEncoder(nn.Module):
    def __init__(self):
        super().__init__()

        # WaveNet encoder operating on 128-bin log-mel spectrograms.
        self.encoder = WaveNet(
            input_channels=128,
            residual_channels=768,
            residual_layers=20,
            dilation_cycle=4,
        )

        # Grouped FSQ quantizer: 2 groups with levels [8, 5, 5, 5]
        # (8 * 5 * 5 * 5 = 1000 codes per group).
        self.quantizer = DownsampleFiniteScalarQuantize(
            input_dim=768, n_codebooks=1, n_groups=2, levels=[8, 5, 5, 5]
        )

        self.spec = LogMelSpectrogram(
            sample_rate=44100,
            n_fft=2048,
            win_length=2048,
            hop_length=512,
            n_mels=128,
            f_min=0.0,
            f_max=8000.0,
        )

        self.eval()

        # Load only the encoder/quantizer weights. The VQ-GAN decoder,
        # quality projection, and discriminator heads are expected to be
        # the only unexpected keys and are deliberately dropped.
        e = self.load_state_dict(
            torch.load("checkpoints/vq-gan-group-fsq-2x1024.pth", map_location="cpu"),
            strict=False,
        )
        assert len(e.missing_keys) == 0, e.missing_keys
        assert all(
            k.startswith("decoder.")
            or k.startswith("quality_projection.")
            or k.startswith("discriminator.")
            for k in e.unexpected_keys
        ), e.unexpected_keys

    @torch.no_grad()
    def forward(self, audios, audio_lengths, sr=None):
        mel_spec = self.spec(audios, sample_rate=sr)

        # The spectrogram module resamples to 44.1 kHz internally, so
        # rescale the lengths before deriving per-item frame counts.
        if sr is not None:
            audio_lengths = audio_lengths * 44100 // sr

        mel_lengths = audio_lengths // self.spec.hop_length
        mel_masks = (
            torch.arange(mel_spec.shape[2], device=mel_spec.device)
            < mel_lengths[:, None]
        )
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mel_spec * mel_masks_float_conv

        # Encode, then quantize; re-mask padded frames after each step.
        encoded_features = self.encoder(mels) * mel_masks_float_conv
        encoded_features = self.quantizer(encoded_features).z * mel_masks_float_conv

        return encoded_features

    @torch.no_grad()
    def indicies_to_vq_features(
        self,
        indices,
        feature_lengths,
    ):
        # Each code index covers `factor` mel frames after downsampling,
        # so upsample the lengths before building the frame-level mask.
        factor = math.prod(self.quantizer.downsample_factor)
        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
        mel_masks_float_conv = mel_masks[:, None, :].float()
        z = self.quantizer.decode(indices) * mel_masks_float_conv

        return z

    @torch.no_grad()
    def encode(self, audios, audio_lengths, sr=None):
        audios = audios.float()

        mels = self.spec(audios, sample_rate=sr)

        # Mirror forward(): the spectrogram is computed at 44.1 kHz, so
        # lengths given at another sample rate must be rescaled first
        # (the original omitted this, leaving mel_lengths wrong for sr != 44100).
        if sr is not None:
            audio_lengths = audio_lengths * 44100 // sr

        mel_lengths = audio_lengths // self.spec.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv

        # Encode, then return the code indices plus per-item code lengths.
        encoded_features = self.encoder(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)

        return self.quantizer.encode(encoded_features), feature_lengths
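

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): round-trips a dummy
# batch through encode() and indicies_to_vq_features(). Assumptions: the
# checkpoint at checkpoints/vq-gan-group-fsq-2x1024.pth is present (since
# __init__ loads it unconditionally) and audio is shaped (batch, samples),
# consistent with forward() indexing mel frames at mel_spec.shape[2].
if __name__ == "__main__":
    model = VQEncoder()

    # One second of silence at the native 44.1 kHz sample rate.
    audios = torch.zeros(1, 44100)
    audio_lengths = torch.tensor([44100])

    indices, feature_lengths = model.encode(audios, audio_lengths)
    print("indices:", tuple(indices.shape), "lengths:", feature_lengths.tolist())

    # Decode the indices back into continuous VQ features.
    features = model.indicies_to_vq_features(indices, feature_lengths)
    print("features:", tuple(features.shape))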