# adamos.py — ADaMoS HiFi-GAN V1 vocoder wrapper.
import librosa
import torch
from torch import nn

from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram

from .encoder import ConvNeXtEncoder
from .hifigan import HiFiGANGenerator
  7. class ADaMoSHiFiGANV1(nn.Module):
  8. def __init__(
  9. self,
  10. checkpoint_path: str = "checkpoints/adamos-generator-1640000.pth",
  11. ):
  12. super().__init__()
  13. self.backbone = ConvNeXtEncoder(
  14. input_channels=128,
  15. depths=[3, 3, 9, 3],
  16. dims=[128, 256, 384, 512],
  17. drop_path_rate=0,
  18. kernel_sizes=(7,),
  19. )
  20. self.head = HiFiGANGenerator(
  21. hop_length=512,
  22. upsample_rates=(4, 4, 2, 2, 2, 2, 2),
  23. upsample_kernel_sizes=(8, 8, 4, 4, 4, 4, 4),
  24. resblock_kernel_sizes=(3, 7, 11, 13),
  25. resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
  26. num_mels=512,
  27. upsample_initial_channel=1024,
  28. use_template=False,
  29. pre_conv_kernel_size=13,
  30. post_conv_kernel_size=13,
  31. )
  32. self.sampling_rate = 44100
  33. ckpt_state = torch.load(checkpoint_path, map_location="cpu")
  34. if "state_dict" in ckpt_state:
  35. ckpt_state = ckpt_state["state_dict"]
  36. if any(k.startswith("generator.") for k in ckpt_state):
  37. ckpt_state = {
  38. k.replace("generator.", ""): v
  39. for k, v in ckpt_state.items()
  40. if k.startswith("generator.")
  41. }
  42. self.load_state_dict(ckpt_state)
  43. self.eval()
  44. self.mel_transform = LogMelSpectrogram(
  45. sample_rate=44100,
  46. n_fft=2048,
  47. win_length=2048,
  48. hop_length=512,
  49. f_min=40,
  50. f_max=16000,
  51. n_mels=128,
  52. )
  53. @torch.no_grad()
  54. def decode(self, mel):
  55. y = self.backbone(mel)
  56. y = self.head(y)
  57. return y
  58. @torch.no_grad()
  59. def encode(self, x):
  60. return self.mel_transform(x)
  61. if __name__ == "__main__":
  62. import soundfile as sf
  63. x = "data/StarRail/Chinese/罗刹/archive_luocha_2.wav"
  64. model = ADaMoSHiFiGANV1()
  65. wav, sr = librosa.load(x, sr=44100, mono=True)
  66. wav = torch.from_numpy(wav).float()[None]
  67. mel = model.encode(wav)
  68. wav = model.decode(mel)[0].mT
  69. sf.write("test.wav", wav.cpu().numpy(), 44100)