models.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import torch
  2. from torch import nn
  3. from fish_speech.models.vqgan.modules.decoder import Generator
  4. from fish_speech.models.vqgan.modules.encoders import (
  5. PosteriorEncoder,
  6. SpeakerEncoder,
  7. TextEncoder,
  8. VQEncoder,
  9. )
  10. from fish_speech.models.vqgan.modules.flow import ResidualCouplingBlock
  11. from fish_speech.models.vqgan.utils import rand_slice_segments
  12. class SynthesizerTrn(nn.Module):
  13. """
  14. Synthesizer for Training
  15. """
  16. def __init__(
  17. self,
  18. *,
  19. in_channels,
  20. spec_channels,
  21. segment_size,
  22. inter_channels,
  23. hidden_channels,
  24. filter_channels,
  25. n_heads,
  26. n_layers,
  27. n_flows,
  28. n_layers_q,
  29. n_layers_spk,
  30. n_layers_flow,
  31. kernel_size,
  32. p_dropout,
  33. speaker_cond_layer,
  34. resblock,
  35. resblock_kernel_sizes,
  36. resblock_dilation_sizes,
  37. upsample_rates,
  38. upsample_initial_channel,
  39. upsample_kernel_sizes,
  40. gin_channels,
  41. codebook_size,
  42. kmeans_ckpt=None,
  43. ):
  44. super().__init__()
  45. self.segment_size = segment_size
  46. self.vq = VQEncoder(
  47. in_channels=in_channels,
  48. vq_channels=in_channels,
  49. codebook_size=codebook_size,
  50. kmeans_ckpt=kmeans_ckpt,
  51. )
  52. self.enc_p = TextEncoder(
  53. in_channels,
  54. inter_channels,
  55. hidden_channels,
  56. filter_channels,
  57. n_heads,
  58. n_layers,
  59. kernel_size,
  60. p_dropout,
  61. gin_channels=gin_channels,
  62. speaker_cond_layer=speaker_cond_layer,
  63. )
  64. self.enc_spk = SpeakerEncoder(
  65. in_channels=spec_channels,
  66. hidden_channels=inter_channels,
  67. out_channels=gin_channels,
  68. num_heads=n_heads,
  69. num_layers=n_layers_spk,
  70. p_dropout=p_dropout,
  71. )
  72. self.flow = ResidualCouplingBlock(
  73. channels=inter_channels,
  74. hidden_channels=hidden_channels,
  75. kernel_size=5,
  76. dilation_rate=1,
  77. n_layers=n_layers_flow,
  78. n_flows=n_flows,
  79. gin_channels=gin_channels,
  80. )
  81. self.enc_q = PosteriorEncoder(
  82. spec_channels,
  83. inter_channels,
  84. hidden_channels,
  85. 5,
  86. 1,
  87. n_layers_q,
  88. gin_channels=gin_channels,
  89. )
  90. self.dec = Generator(
  91. inter_channels,
  92. resblock,
  93. resblock_kernel_sizes,
  94. resblock_dilation_sizes,
  95. upsample_rates,
  96. upsample_initial_channel,
  97. upsample_kernel_sizes,
  98. gin_channels=gin_channels,
  99. )
  100. def forward(self, x, x_lengths, specs):
  101. g = self.enc_spk(specs, x_lengths)
  102. x, vq_loss = self.vq(x)
  103. _, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
  104. z_q, m_q, logs_q, y_mask = self.enc_q(specs, x_lengths, g=g)
  105. z_p = self.flow(z_q, y_mask, g=g, reverse=False)
  106. z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
  107. o = self.dec(z_slice, g=g)
  108. return (
  109. o,
  110. ids_slice,
  111. x_mask,
  112. y_mask,
  113. (z_q, z_p),
  114. (m_p, logs_p),
  115. (m_q, logs_q),
  116. vq_loss,
  117. )
  118. def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
  119. g = self.enc_spk(specs, x_lengths)
  120. x, vq_loss = self.vq(x)
  121. z_p, m_p, logs_p, h_text, x_mask = self.enc_p(
  122. x, x_lengths, g=g, noise_scale=noise_scale
  123. )
  124. z_p = self.flow(z_p, x_mask, g=g, reverse=True)
  125. o = self.dec((z_p * x_mask)[:, :, :max_len], g=g)
  126. return o
  127. def reconstruct(self, x, x_lengths, max_len=None, noise_scale=0.35):
  128. g = self.enc_spk(x, x_lengths)
  129. z_q, m_q, logs_q, x_mask = self.enc_q(
  130. x, x_lengths, g=g, noise_scale=noise_scale
  131. )
  132. o = self.dec((z_q * x_mask)[:, :, :max_len], g=g)
  133. return o