vq_encoder.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import torch
  2. from torch import nn
  3. from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize
  4. from fish_speech.models.vqgan.modules.wavenet import WaveNet
  5. from fish_speech.utils.spectrogram import LogMelSpectrogram
  6. class VQEncoder(nn.Module):
  7. def __init__(
  8. self,
  9. ):
  10. super().__init__()
  11. self.encoder = WaveNet(
  12. input_channels=128,
  13. residual_channels=768,
  14. residual_layers=20,
  15. dilation_cycle=4,
  16. )
  17. self.quantizer = DownsampleFiniteScalarQuantize(
  18. input_dim=768, n_codebooks=1, n_groups=2, levels=[8, 5, 5, 5]
  19. )
  20. self.spec = LogMelSpectrogram(
  21. sample_rate=44100,
  22. n_fft=2048,
  23. win_length=2048,
  24. hop_length=512,
  25. n_mels=128,
  26. f_min=0.0,
  27. f_max=8000.0,
  28. )
  29. self.eval()
  30. e = self.load_state_dict(
  31. torch.load("checkpoints/vq-gan-group-fsq-2x1024.pth", map_location="cpu"),
  32. strict=False,
  33. )
  34. assert len(e.missing_keys) == 0, e.missing_keys
  35. assert all(
  36. k.startswith("decoder.")
  37. or k.startswith("quality_projection.")
  38. or k.startswith("discriminator.")
  39. for k in e.unexpected_keys
  40. ), e.unexpected_keys
  41. @torch.no_grad()
  42. def forward(self, audios, audio_lengths, use_decoder=False, sr=None):
  43. mel_spec = self.spec(audios, sample_rate=sr)
  44. if sr is not None:
  45. audio_lengths = audio_lengths * 44100 // sr
  46. mel_lengths = audio_lengths // self.spec.hop_length
  47. mel_masks = (
  48. torch.arange(mel_spec.shape[2], device=mel_spec.device)
  49. < mel_lengths[:, None]
  50. )
  51. mel_masks_float_conv = mel_masks[:, None, :].float()
  52. mels = mel_spec * mel_masks_float_conv
  53. # Encode
  54. encoded_features = self.encoder(mels) * mel_masks_float_conv
  55. encoded_features = self.quantizer(encoded_features).z * mel_masks_float_conv
  56. return encoded_features