# decoder_v2.py
  1. from functools import partial
  2. from math import prod
  3. from typing import Callable
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from torch.nn import Conv1d
  9. from torch.nn.utils.parametrizations import weight_norm
  10. from torch.nn.utils.parametrize import remove_parametrizations
  11. def init_weights(m, mean=0.0, std=0.01):
  12. classname = m.__class__.__name__
  13. if classname.find("Conv") != -1:
  14. m.weight.data.normal_(mean, std)
  15. def get_padding(kernel_size, dilation=1):
  16. return (kernel_size * dilation - dilation) // 2
  17. class ResBlock(torch.nn.Module):
  18. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
  19. super().__init__()
  20. self.convs1 = nn.ModuleList(
  21. [
  22. weight_norm(
  23. Conv1d(
  24. channels,
  25. channels,
  26. kernel_size,
  27. 1,
  28. dilation=dilation[0],
  29. padding=get_padding(kernel_size, dilation[0]),
  30. )
  31. ),
  32. weight_norm(
  33. Conv1d(
  34. channels,
  35. channels,
  36. kernel_size,
  37. 1,
  38. dilation=dilation[1],
  39. padding=get_padding(kernel_size, dilation[1]),
  40. )
  41. ),
  42. weight_norm(
  43. Conv1d(
  44. channels,
  45. channels,
  46. kernel_size,
  47. 1,
  48. dilation=dilation[2],
  49. padding=get_padding(kernel_size, dilation[2]),
  50. )
  51. ),
  52. ]
  53. )
  54. self.convs1.apply(init_weights)
  55. self.convs2 = nn.ModuleList(
  56. [
  57. weight_norm(
  58. Conv1d(
  59. channels,
  60. channels,
  61. kernel_size,
  62. 1,
  63. dilation=1,
  64. padding=get_padding(kernel_size, 1),
  65. )
  66. ),
  67. weight_norm(
  68. Conv1d(
  69. channels,
  70. channels,
  71. kernel_size,
  72. 1,
  73. dilation=1,
  74. padding=get_padding(kernel_size, 1),
  75. )
  76. ),
  77. weight_norm(
  78. Conv1d(
  79. channels,
  80. channels,
  81. kernel_size,
  82. 1,
  83. dilation=1,
  84. padding=get_padding(kernel_size, 1),
  85. )
  86. ),
  87. ]
  88. )
  89. self.convs2.apply(init_weights)
  90. def forward(self, x):
  91. for c1, c2 in zip(self.convs1, self.convs2):
  92. xt = F.silu(x)
  93. xt = c1(xt)
  94. xt = F.silu(xt)
  95. xt = c2(xt)
  96. x = xt + x
  97. return x
  98. def remove_parametrizations(self):
  99. for conv in self.convs1:
  100. remove_parametrizations(conv)
  101. for conv in self.convs2:
  102. remove_parametrizations(conv)
  103. class ParralelBlock(nn.Module):
  104. def __init__(
  105. self,
  106. channels: int,
  107. kernel_sizes: tuple[int] = (3, 7, 11),
  108. dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  109. ):
  110. super().__init__()
  111. assert len(kernel_sizes) == len(dilation_sizes)
  112. self.blocks = nn.ModuleList()
  113. for k, d in zip(kernel_sizes, dilation_sizes):
  114. self.blocks.append(ResBlock(channels, k, d))
  115. def forward(self, x):
  116. xs = [block(x) for block in self.blocks]
  117. return torch.stack(xs, dim=0).mean(dim=0)
  118. class HiFiGANGenerator(nn.Module):
  119. def __init__(
  120. self,
  121. *,
  122. hop_length: int = 512,
  123. upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
  124. upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
  125. resblock_kernel_sizes: tuple[int] = (3, 7, 11),
  126. resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  127. num_mels: int = 128,
  128. upsample_initial_channel: int = 512,
  129. use_template: bool = True,
  130. pre_conv_kernel_size: int = 7,
  131. post_conv_kernel_size: int = 7,
  132. post_activation: Callable = partial(nn.SiLU, inplace=True),
  133. checkpointing: bool = False,
  134. ):
  135. super().__init__()
  136. assert (
  137. prod(upsample_rates) == hop_length
  138. ), f"hop_length must be {prod(upsample_rates)}"
  139. self.conv_pre = weight_norm(
  140. nn.Conv1d(
  141. num_mels,
  142. upsample_initial_channel,
  143. pre_conv_kernel_size,
  144. 1,
  145. padding=get_padding(pre_conv_kernel_size),
  146. )
  147. )
  148. self.hop_length = hop_length
  149. self.num_upsamples = len(upsample_rates)
  150. self.num_kernels = len(resblock_kernel_sizes)
  151. self.noise_convs = nn.ModuleList()
  152. self.use_template = use_template
  153. self.ups = nn.ModuleList()
  154. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  155. c_cur = upsample_initial_channel // (2 ** (i + 1))
  156. self.ups.append(
  157. weight_norm(
  158. nn.ConvTranspose1d(
  159. upsample_initial_channel // (2**i),
  160. upsample_initial_channel // (2 ** (i + 1)),
  161. k,
  162. u,
  163. padding=(k - u) // 2,
  164. )
  165. )
  166. )
  167. if not use_template:
  168. continue
  169. if i + 1 < len(upsample_rates):
  170. stride_f0 = np.prod(upsample_rates[i + 1 :])
  171. self.noise_convs.append(
  172. Conv1d(
  173. 1,
  174. c_cur,
  175. kernel_size=stride_f0 * 2,
  176. stride=stride_f0,
  177. padding=stride_f0 // 2,
  178. )
  179. )
  180. else:
  181. self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
  182. self.resblocks = nn.ModuleList()
  183. for i in range(len(self.ups)):
  184. ch = upsample_initial_channel // (2 ** (i + 1))
  185. self.resblocks.append(
  186. ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
  187. )
  188. self.activation_post = post_activation()
  189. self.conv_post = weight_norm(
  190. nn.Conv1d(
  191. ch,
  192. 1,
  193. post_conv_kernel_size,
  194. 1,
  195. padding=get_padding(post_conv_kernel_size),
  196. )
  197. )
  198. self.ups.apply(init_weights)
  199. self.conv_post.apply(init_weights)
  200. # Gradient checkpointing
  201. self.checkpointing = checkpointing
  202. def forward(self, x, template=None):
  203. if self.use_template and template is None:
  204. length = x.shape[-1] * self.hop_length
  205. template = (
  206. torch.randn(x.shape[0], 1, length, device=x.device, dtype=x.dtype)
  207. * 0.003
  208. )
  209. x = self.conv_pre(x)
  210. for i in range(self.num_upsamples):
  211. x = F.silu(x, inplace=True)
  212. x = self.ups[i](x)
  213. if self.use_template:
  214. x = x + self.noise_convs[i](template)
  215. if self.training and self.checkpointing:
  216. x = torch.utils.checkpoint.checkpoint(
  217. self.resblocks[i],
  218. x,
  219. use_reentrant=False,
  220. )
  221. else:
  222. x = self.resblocks[i](x)
  223. x = self.activation_post(x)
  224. x = self.conv_post(x)
  225. x = torch.tanh(x)
  226. return x
  227. def remove_parametrizations(self):
  228. for up in self.ups:
  229. remove_parametrizations(up)
  230. for block in self.resblocks:
  231. block.remove_parametrizations()
  232. remove_parametrizations(self.conv_pre)
  233. remove_parametrizations(self.conv_post)