# decoder_v2.py
from functools import partial
from math import prod
from typing import Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
  11. def init_weights(m, mean=0.0, std=0.01):
  12. classname = m.__class__.__name__
  13. if classname.find("Conv") != -1:
  14. m.weight.data.normal_(mean, std)
  15. def get_padding(kernel_size, dilation=1):
  16. return (kernel_size * dilation - dilation) // 2
  17. class ResBlock(torch.nn.Module):
  18. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
  19. super().__init__()
  20. self.convs1 = nn.ModuleList(
  21. [
  22. weight_norm(
  23. Conv1d(
  24. channels,
  25. channels,
  26. kernel_size,
  27. 1,
  28. dilation=dilation[0],
  29. padding=get_padding(kernel_size, dilation[0]),
  30. )
  31. ),
  32. weight_norm(
  33. Conv1d(
  34. channels,
  35. channels,
  36. kernel_size,
  37. 1,
  38. dilation=dilation[1],
  39. padding=get_padding(kernel_size, dilation[1]),
  40. )
  41. ),
  42. weight_norm(
  43. Conv1d(
  44. channels,
  45. channels,
  46. kernel_size,
  47. 1,
  48. dilation=dilation[2],
  49. padding=get_padding(kernel_size, dilation[2]),
  50. )
  51. ),
  52. ]
  53. )
  54. self.convs1.apply(init_weights)
  55. self.convs2 = nn.ModuleList(
  56. [
  57. weight_norm(
  58. Conv1d(
  59. channels,
  60. channels,
  61. kernel_size,
  62. 1,
  63. dilation=1,
  64. padding=get_padding(kernel_size, 1),
  65. )
  66. ),
  67. weight_norm(
  68. Conv1d(
  69. channels,
  70. channels,
  71. kernel_size,
  72. 1,
  73. dilation=1,
  74. padding=get_padding(kernel_size, 1),
  75. )
  76. ),
  77. weight_norm(
  78. Conv1d(
  79. channels,
  80. channels,
  81. kernel_size,
  82. 1,
  83. dilation=1,
  84. padding=get_padding(kernel_size, 1),
  85. )
  86. ),
  87. ]
  88. )
  89. self.convs2.apply(init_weights)
  90. def forward(self, x):
  91. for c1, c2 in zip(self.convs1, self.convs2):
  92. xt = F.silu(x)
  93. xt = c1(xt)
  94. xt = F.silu(xt)
  95. xt = c2(xt)
  96. x = xt + x
  97. return x
  98. def remove_parametrizations(self):
  99. for conv in self.convs1:
  100. remove_parametrizations(conv)
  101. for conv in self.convs2:
  102. remove_parametrizations(conv)
  103. class ParralelBlock(nn.Module):
  104. def __init__(
  105. self,
  106. channels: int,
  107. kernel_sizes: tuple[int] = (3, 7, 11),
  108. dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  109. ):
  110. super().__init__()
  111. assert len(kernel_sizes) == len(dilation_sizes)
  112. self.blocks = nn.ModuleList()
  113. for k, d in zip(kernel_sizes, dilation_sizes):
  114. self.blocks.append(ResBlock(channels, k, d))
  115. def forward(self, x):
  116. xs = [block(x) for block in self.blocks]
  117. return torch.stack(xs, dim=0).mean(dim=0)
  118. class HiFiGANGenerator(nn.Module):
  119. def __init__(
  120. self,
  121. *,
  122. hop_length: int = 512,
  123. upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
  124. upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
  125. resblock_kernel_sizes: tuple[int] = (3, 7, 11),
  126. resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  127. num_mels: int = 128,
  128. upsample_initial_channel: int = 512,
  129. use_template: bool = True,
  130. pre_conv_kernel_size: int = 7,
  131. post_conv_kernel_size: int = 7,
  132. post_activation: Callable = partial(nn.SiLU, inplace=True),
  133. checkpointing: bool = False,
  134. ckpt_path: str = None,
  135. ):
  136. super().__init__()
  137. assert (
  138. prod(upsample_rates) == hop_length
  139. ), f"hop_length must be {prod(upsample_rates)}"
  140. self.conv_pre = weight_norm(
  141. nn.Conv1d(
  142. num_mels,
  143. upsample_initial_channel,
  144. pre_conv_kernel_size,
  145. 1,
  146. padding=get_padding(pre_conv_kernel_size),
  147. )
  148. )
  149. self.hop_length = hop_length
  150. self.num_upsamples = len(upsample_rates)
  151. self.num_kernels = len(resblock_kernel_sizes)
  152. self.noise_convs = nn.ModuleList()
  153. self.use_template = use_template
  154. self.ups = nn.ModuleList()
  155. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  156. c_cur = upsample_initial_channel // (2 ** (i + 1))
  157. self.ups.append(
  158. weight_norm(
  159. nn.ConvTranspose1d(
  160. upsample_initial_channel // (2**i),
  161. upsample_initial_channel // (2 ** (i + 1)),
  162. k,
  163. u,
  164. padding=(k - u) // 2,
  165. )
  166. )
  167. )
  168. if not use_template:
  169. continue
  170. if i + 1 < len(upsample_rates):
  171. stride_f0 = np.prod(upsample_rates[i + 1 :])
  172. self.noise_convs.append(
  173. Conv1d(
  174. 1,
  175. c_cur,
  176. kernel_size=stride_f0 * 2,
  177. stride=stride_f0,
  178. padding=stride_f0 // 2,
  179. )
  180. )
  181. else:
  182. self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
  183. self.resblocks = nn.ModuleList()
  184. for i in range(len(self.ups)):
  185. ch = upsample_initial_channel // (2 ** (i + 1))
  186. self.resblocks.append(
  187. ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
  188. )
  189. self.activation_post = post_activation()
  190. self.conv_post = weight_norm(
  191. nn.Conv1d(
  192. ch,
  193. 1,
  194. post_conv_kernel_size,
  195. 1,
  196. padding=get_padding(post_conv_kernel_size),
  197. )
  198. )
  199. self.ups.apply(init_weights)
  200. self.conv_post.apply(init_weights)
  201. # Gradient checkpointing
  202. self.checkpointing = checkpointing
  203. if ckpt_path is not None:
  204. states = torch.load(ckpt_path, map_location="cpu")
  205. if "state_dict" in states:
  206. states = states["state_dict"]
  207. states = {
  208. k.replace("generator.", ""): v
  209. for k, v in states.items()
  210. if k.startswith("generator")
  211. }
  212. self.load_state_dict(states, strict=True)
  213. def forward(self, x, template=None):
  214. if self.use_template and template is None:
  215. length = x.shape[-1] * self.hop_length
  216. template = (
  217. torch.randn(x.shape[0], 1, length, device=x.device, dtype=x.dtype)
  218. * 0.003
  219. )
  220. x = self.conv_pre(x)
  221. for i in range(self.num_upsamples):
  222. x = F.silu(x, inplace=True)
  223. x = self.ups[i](x)
  224. if self.use_template:
  225. x = x + self.noise_convs[i](template)
  226. if self.training and self.checkpointing:
  227. x = torch.utils.checkpoint.checkpoint(
  228. self.resblocks[i],
  229. x,
  230. use_reentrant=False,
  231. )
  232. else:
  233. x = self.resblocks[i](x)
  234. x = self.activation_post(x)
  235. x = self.conv_post(x)
  236. x = torch.tanh(x)
  237. return x
  238. def remove_parametrizations(self):
  239. for up in self.ups:
  240. remove_parametrizations(up)
  241. for block in self.resblocks:
  242. block.remove_parametrizations()
  243. remove_parametrizations(self.conv_pre)
  244. remove_parametrizations(self.conv_post)
  245. if __name__ == "__main__":
  246. import torchaudio
  247. from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
  248. spec = LogMelSpectrogram(n_mels=160)
  249. audio, sr = torchaudio.load("test.wav")
  250. audio = audio[None, :]
  251. spec = spec(audio, sample_rate=sr)
  252. model = HiFiGANGenerator(
  253. hop_length=512,
  254. upsample_rates=(8, 8, 2, 2, 2),
  255. upsample_kernel_sizes=(16, 16, 4, 4, 4),
  256. resblock_kernel_sizes=(3, 7, 11),
  257. resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  258. num_mels=160,
  259. upsample_initial_channel=512,
  260. use_template=True,
  261. pre_conv_kernel_size=7,
  262. post_conv_kernel_size=7,
  263. post_activation=partial(nn.SiLU, inplace=True),
  264. ckpt_path="checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt",
  265. )
  266. print(model)
  267. out = model(spec)
  268. print(out.shape)
  269. torchaudio.save("out.wav", out[0], 44100)