# models.py
  1. import torch
  2. from torch import nn
  3. from fish_speech.models.vqgan.modules.decoder import Generator
  4. from fish_speech.models.vqgan.modules.encoders import (
  5. PosteriorEncoder,
  6. SpeakerEncoder,
  7. TextEncoder,
  8. VQEncoder,
  9. )
  10. from fish_speech.models.vqgan.modules.flow import ResidualCouplingBlock
  11. from fish_speech.models.vqgan.utils import rand_slice_segments, sequence_mask
  12. class SynthesizerTrn(nn.Module):
  13. """
  14. Synthesizer for Training
  15. """
  16. def __init__(
  17. self,
  18. *,
  19. in_channels,
  20. spec_channels,
  21. segment_size,
  22. inter_channels,
  23. hidden_channels,
  24. filter_channels,
  25. n_heads,
  26. n_layers,
  27. n_flows,
  28. n_layers_q,
  29. n_layers_spk,
  30. n_layers_flow,
  31. kernel_size,
  32. p_dropout,
  33. speaker_cond_layer,
  34. resblock,
  35. resblock_kernel_sizes,
  36. resblock_dilation_sizes,
  37. upsample_rates,
  38. upsample_initial_channel,
  39. upsample_kernel_sizes,
  40. gin_channels,
  41. codebook_size,
  42. kmeans_ckpt=None,
  43. ):
  44. super().__init__()
  45. self.segment_size = segment_size
  46. # self.vq = VQEncoder(
  47. # in_channels=in_channels,
  48. # vq_channels=in_channels,
  49. # codebook_size=codebook_size,
  50. # kmeans_ckpt=kmeans_ckpt,
  51. # )
  52. self.enc_p = TextEncoder(
  53. in_channels,
  54. inter_channels,
  55. hidden_channels,
  56. filter_channels,
  57. n_heads,
  58. n_layers,
  59. kernel_size,
  60. p_dropout,
  61. gin_channels=gin_channels,
  62. speaker_cond_layer=speaker_cond_layer,
  63. )
  64. self.enc_spk = SpeakerEncoder(
  65. in_channels=spec_channels,
  66. hidden_channels=inter_channels,
  67. out_channels=gin_channels,
  68. num_heads=n_heads,
  69. num_layers=n_layers_spk,
  70. p_dropout=p_dropout,
  71. )
  72. self.flow = ResidualCouplingBlock(
  73. channels=inter_channels,
  74. hidden_channels=hidden_channels,
  75. kernel_size=5,
  76. dilation_rate=1,
  77. n_layers=n_layers_flow,
  78. n_flows=n_flows,
  79. gin_channels=gin_channels,
  80. )
  81. self.enc_q = PosteriorEncoder(
  82. spec_channels,
  83. inter_channels,
  84. hidden_channels,
  85. 5,
  86. 1,
  87. n_layers_q,
  88. gin_channels=gin_channels,
  89. )
  90. self.dec = Generator(
  91. inter_channels,
  92. resblock,
  93. resblock_kernel_sizes,
  94. resblock_dilation_sizes,
  95. upsample_rates,
  96. upsample_initial_channel,
  97. upsample_kernel_sizes,
  98. gin_channels=gin_channels,
  99. )
  100. def forward(self, x, x_lengths, specs):
  101. # x = x.mT
  102. min_length = min(x.shape[1], specs.shape[2])
  103. if min_length % 2 != 0:
  104. min_length -= 1
  105. x = x[:, :min_length]
  106. specs = specs[:, :, :min_length]
  107. x_lengths = torch.clamp(x_lengths, max=min_length)
  108. spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
  109. specs.dtype
  110. )
  111. x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
  112. g = self.enc_spk(specs, spec_masks)
  113. # with torch.no_grad():
  114. # x, _ = self.vq(x, x_masks)
  115. # vq_loss = 0
  116. _, m_p, logs_p, _, _ = self.enc_p(x, x_masks, g=g)
  117. z_q, m_q, logs_q, _ = self.enc_q(specs, spec_masks, g=g)
  118. z_p = self.flow(z_q, spec_masks, g=g, reverse=False)
  119. z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
  120. o = self.dec(z_slice, g=g)
  121. return (
  122. o,
  123. ids_slice,
  124. x_masks,
  125. spec_masks,
  126. (z_q, z_p),
  127. (m_p, logs_p),
  128. (m_q, logs_q),
  129. # vq_loss,
  130. )
  131. def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
  132. # x = x.mT
  133. spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
  134. specs.dtype
  135. )
  136. # print(x_lengths, x.shape)
  137. x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
  138. g = self.enc_spk(specs, spec_masks)
  139. # x, vq_loss = self.vq(x, x_masks)
  140. z_p, m_p, logs_p, h_text, _ = self.enc_p(
  141. x, x_masks, g=g, noise_scale=noise_scale
  142. )
  143. z_p = self.flow(z_p, x_masks, g=g, reverse=True)
  144. o = self.dec((z_p * x_masks)[:, :, :max_len], g=g)
  145. return o
  146. def reconstruct(self, specs, spec_lengths, max_len=None, noise_scale=0.35):
  147. spec_masks = torch.unsqueeze(sequence_mask(spec_lengths, specs.shape[2]), 1).to(
  148. specs.dtype
  149. )
  150. g = self.enc_spk(specs, spec_masks)
  151. z_q, m_q, logs_q, _ = self.enc_q(
  152. specs, spec_masks, g=g, noise_scale=noise_scale
  153. )
  154. o = self.dec((z_q * spec_masks)[:, :, :max_len], g=g)
  155. return o