# decoder.py
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.nn.utils.parametrizations import weight_norm
  5. from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
  6. from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
  7. from fish_speech.models.vqgan.utils import get_padding, init_weights
  8. class Generator(nn.Module):
  9. def __init__(
  10. self,
  11. initial_channel,
  12. resblock,
  13. resblock_kernel_sizes,
  14. resblock_dilation_sizes,
  15. upsample_rates,
  16. upsample_initial_channel,
  17. upsample_kernel_sizes,
  18. gin_channels=0,
  19. ):
  20. super(Generator, self).__init__()
  21. self.num_kernels = len(resblock_kernel_sizes)
  22. self.num_upsamples = len(upsample_rates)
  23. self.conv_pre = nn.Conv1d(
  24. initial_channel, upsample_initial_channel, 7, 1, padding=3
  25. )
  26. resblock = ResBlock1 if resblock == "1" else ResBlock2
  27. self.ups = nn.ModuleList()
  28. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  29. self.ups.append(
  30. weight_norm(
  31. nn.ConvTranspose1d(
  32. upsample_initial_channel // (2**i),
  33. upsample_initial_channel // (2 ** (i + 1)),
  34. k,
  35. u,
  36. padding=(k - u) // 2,
  37. )
  38. )
  39. )
  40. self.resblocks = nn.ModuleList()
  41. for i in range(len(self.ups)):
  42. ch = upsample_initial_channel // (2 ** (i + 1))
  43. for j, (k, d) in enumerate(
  44. zip(resblock_kernel_sizes, resblock_dilation_sizes)
  45. ):
  46. self.resblocks.append(resblock(ch, k, d))
  47. self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
  48. self.ups.apply(init_weights)
  49. if gin_channels != 0:
  50. self.cond = nn.Linear(gin_channels, upsample_initial_channel)
  51. def forward(self, x, g=None):
  52. x = self.conv_pre(x)
  53. if g is not None:
  54. x = x + self.cond(g.mT).mT
  55. for i in range(self.num_upsamples):
  56. x = F.leaky_relu(x, LRELU_SLOPE)
  57. x = self.ups[i](x)
  58. xs = None
  59. for j in range(self.num_kernels):
  60. if xs is None:
  61. xs = self.resblocks[i * self.num_kernels + j](x)
  62. else:
  63. xs += self.resblocks[i * self.num_kernels + j](x)
  64. x = xs / self.num_kernels
  65. x = F.leaky_relu(x)
  66. x = self.conv_post(x)
  67. x = torch.tanh(x)
  68. return x
  69. def remove_weight_norm(self):
  70. print("Removing weight norm...")
  71. for l in self.ups:
  72. remove_weight_norm(l)
  73. for l in self.resblocks:
  74. l.remove_weight_norm()
  75. class ResBlock1(nn.Module):
  76. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
  77. super(ResBlock1, self).__init__()
  78. self.convs1 = nn.ModuleList(
  79. [
  80. weight_norm(
  81. nn.Conv1d(
  82. channels,
  83. channels,
  84. kernel_size,
  85. 1,
  86. dilation=dilation[0],
  87. padding=get_padding(kernel_size, dilation[0]),
  88. )
  89. ),
  90. weight_norm(
  91. nn.Conv1d(
  92. channels,
  93. channels,
  94. kernel_size,
  95. 1,
  96. dilation=dilation[1],
  97. padding=get_padding(kernel_size, dilation[1]),
  98. )
  99. ),
  100. weight_norm(
  101. nn.Conv1d(
  102. channels,
  103. channels,
  104. kernel_size,
  105. 1,
  106. dilation=dilation[2],
  107. padding=get_padding(kernel_size, dilation[2]),
  108. )
  109. ),
  110. ]
  111. )
  112. self.convs1.apply(init_weights)
  113. self.convs2 = nn.ModuleList(
  114. [
  115. weight_norm(
  116. nn.Conv1d(
  117. channels,
  118. channels,
  119. kernel_size,
  120. 1,
  121. dilation=1,
  122. padding=get_padding(kernel_size, 1),
  123. )
  124. ),
  125. weight_norm(
  126. nn.Conv1d(
  127. channels,
  128. channels,
  129. kernel_size,
  130. 1,
  131. dilation=1,
  132. padding=get_padding(kernel_size, 1),
  133. )
  134. ),
  135. weight_norm(
  136. nn.Conv1d(
  137. channels,
  138. channels,
  139. kernel_size,
  140. 1,
  141. dilation=1,
  142. padding=get_padding(kernel_size, 1),
  143. )
  144. ),
  145. ]
  146. )
  147. self.convs2.apply(init_weights)
  148. def forward(self, x, x_mask=None):
  149. for c1, c2 in zip(self.convs1, self.convs2):
  150. xt = F.leaky_relu(x, LRELU_SLOPE)
  151. if x_mask is not None:
  152. xt = xt * x_mask
  153. xt = c1(xt)
  154. xt = F.leaky_relu(xt, LRELU_SLOPE)
  155. if x_mask is not None:
  156. xt = xt * x_mask
  157. xt = c2(xt)
  158. x = xt + x
  159. if x_mask is not None:
  160. x = x * x_mask
  161. return x
  162. def remove_weight_norm(self):
  163. for l in self.convs1:
  164. remove_weight_norm(l)
  165. for l in self.convs2:
  166. remove_weight_norm(l)
  167. class ResBlock2(nn.Module):
  168. def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
  169. super(ResBlock2, self).__init__()
  170. self.convs = nn.ModuleList(
  171. [
  172. weight_norm(
  173. nn.Conv1d(
  174. channels,
  175. channels,
  176. kernel_size,
  177. 1,
  178. dilation=dilation[0],
  179. padding=get_padding(kernel_size, dilation[0]),
  180. )
  181. ),
  182. weight_norm(
  183. nn.Conv1d(
  184. channels,
  185. channels,
  186. kernel_size,
  187. 1,
  188. dilation=dilation[1],
  189. padding=get_padding(kernel_size, dilation[1]),
  190. )
  191. ),
  192. ]
  193. )
  194. self.convs.apply(init_weights)
  195. def forward(self, x, x_mask=None):
  196. for c in self.convs:
  197. xt = F.leaky_relu(x, LRELU_SLOPE)
  198. if x_mask is not None:
  199. xt = xt * x_mask
  200. xt = c(xt)
  201. x = xt + x
  202. if x_mask is not None:
  203. x = x * x_mask
  204. return x
  205. def remove_weight_norm(self):
  206. for l in self.convs:
  207. remove_weight_norm(l)