models.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import torch
  2. from torch import nn
  3. from fish_speech.models.vqgan.modules.decoder import Generator
  4. from fish_speech.models.vqgan.modules.encoders import (
  5. PosteriorEncoder,
  6. SpeakerEncoder,
  7. TextEncoder,
  8. )
  9. from fish_speech.models.vqgan.utils import rand_slice_segments
  10. class SynthesizerTrn(nn.Module):
  11. """
  12. Synthesizer for Training
  13. """
  14. def __init__(
  15. self,
  16. in_channels,
  17. spec_channels,
  18. segment_size,
  19. inter_channels,
  20. hidden_channels,
  21. filter_channels,
  22. n_heads,
  23. n_layers,
  24. n_layers_q,
  25. n_layers_spk,
  26. kernel_size,
  27. p_dropout,
  28. speaker_cond_layer,
  29. resblock,
  30. resblock_kernel_sizes,
  31. resblock_dilation_sizes,
  32. upsample_rates,
  33. upsample_initial_channel,
  34. upsample_kernel_sizes,
  35. gin_channels,
  36. ):
  37. super().__init__()
  38. self.segment_size = segment_size
  39. self.enc_p = TextEncoder(
  40. in_channels,
  41. inter_channels,
  42. hidden_channels,
  43. filter_channels,
  44. n_heads,
  45. n_layers,
  46. kernel_size,
  47. p_dropout,
  48. gin_channels=gin_channels,
  49. speaker_cond_layer=speaker_cond_layer,
  50. )
  51. self.enc_spk = SpeakerEncoder(
  52. in_channels=spec_channels,
  53. hidden_channels=inter_channels,
  54. out_channels=gin_channels,
  55. num_heads=n_heads,
  56. num_layers=n_layers_spk,
  57. p_dropout=p_dropout,
  58. )
  59. self.enc_q = PosteriorEncoder(
  60. spec_channels,
  61. inter_channels,
  62. hidden_channels,
  63. 5,
  64. 1,
  65. n_layers_q,
  66. gin_channels=gin_channels,
  67. )
  68. self.dec = Generator(
  69. inter_channels,
  70. resblock,
  71. resblock_kernel_sizes,
  72. resblock_dilation_sizes,
  73. upsample_rates,
  74. upsample_initial_channel,
  75. upsample_kernel_sizes,
  76. gin_channels=gin_channels,
  77. )
  78. def forward(self, x, x_lengths, y):
  79. g = self.enc_spk(y, x_lengths)
  80. z_p, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
  81. z_q, m_q, logs_q, y_mask = self.enc_q(y, x_lengths, g=g)
  82. z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
  83. o = self.dec(z_slice, g=g)
  84. return (
  85. o,
  86. ids_slice,
  87. x_mask,
  88. y_mask,
  89. (z_q, z_p),
  90. (m_p, logs_p),
  91. (m_q, logs_q),
  92. )
  93. def infer(self, x, x_lengths, y, max_len=None):
  94. g = self.enc_spk(y, x_lengths)
  95. z_p, m_p, logs_p, h_text, x_mask = self.enc_p(x, x_lengths, g=g)
  96. # z_p_audio, m_p_audio, logs_p_audio = self.flow(z_p_text, m_p_text, logs_p_text, x_mask, g=g, reverse=True)
  97. o = self.dec((z_p * x_mask)[:, :, :max_len], g=g)
  98. return o
  99. def reconstruct(self, x, x_lengths, max_len=None):
  100. g = self.enc_spk(x, x_lengths)
  101. z_q, m_q, logs_q, x_mask = self.enc_q(x, x_lengths, g=g)
  102. o = self.dec((z_q * x_mask)[:, :, :max_len], g=g)
  103. return o