# decoder.py — HiFiGAN-style vocoder generator (mel spectrogram -> waveform).
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.nn.utils.parametrizations import weight_norm
  5. from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
  6. from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
  7. from fish_speech.models.vqgan.utils import get_padding, init_weights
  8. class Generator(nn.Module):
  9. def __init__(
  10. self,
  11. initial_channel,
  12. resblock,
  13. resblock_kernel_sizes,
  14. resblock_dilation_sizes,
  15. upsample_rates,
  16. upsample_initial_channel,
  17. upsample_kernel_sizes,
  18. gin_channels=0,
  19. ckpt_path=None,
  20. ):
  21. super(Generator, self).__init__()
  22. self.num_kernels = len(resblock_kernel_sizes)
  23. self.num_upsamples = len(upsample_rates)
  24. self.conv_pre = weight_norm(
  25. nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
  26. )
  27. resblock = ResBlock1 if resblock == "1" else ResBlock2
  28. self.ups = nn.ModuleList()
  29. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  30. self.ups.append(
  31. weight_norm(
  32. nn.ConvTranspose1d(
  33. upsample_initial_channel // (2**i),
  34. upsample_initial_channel // (2 ** (i + 1)),
  35. k,
  36. u,
  37. padding=(k - u) // 2,
  38. )
  39. )
  40. )
  41. self.resblocks = nn.ModuleList()
  42. for i in range(len(self.ups)):
  43. ch = upsample_initial_channel // (2 ** (i + 1))
  44. for j, (k, d) in enumerate(
  45. zip(resblock_kernel_sizes, resblock_dilation_sizes)
  46. ):
  47. self.resblocks.append(resblock(ch, k, d))
  48. self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
  49. self.ups.apply(init_weights)
  50. if gin_channels != 0:
  51. self.cond = nn.Linear(gin_channels, upsample_initial_channel)
  52. if ckpt_path is not None:
  53. self.load_state_dict(torch.load(ckpt_path)["generator"], strict=True)
  54. def forward(self, x, g=None):
  55. x = self.conv_pre(x)
  56. if g is not None:
  57. x = x + self.cond(g.mT).mT
  58. for i in range(self.num_upsamples):
  59. x = F.leaky_relu(x, LRELU_SLOPE)
  60. x = self.ups[i](x)
  61. xs = None
  62. for j in range(self.num_kernels):
  63. if xs is None:
  64. xs = self.resblocks[i * self.num_kernels + j](x)
  65. else:
  66. xs += self.resblocks[i * self.num_kernels + j](x)
  67. x = xs / self.num_kernels
  68. x = F.leaky_relu(x)
  69. x = self.conv_post(x)
  70. x = torch.tanh(x)
  71. return x
  72. def remove_weight_norm(self):
  73. print("Removing weight norm...")
  74. for l in self.ups:
  75. remove_weight_norm(l)
  76. for l in self.resblocks:
  77. l.remove_weight_norm()
  78. class ResBlock1(nn.Module):
  79. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
  80. super(ResBlock1, self).__init__()
  81. self.convs1 = nn.ModuleList(
  82. [
  83. weight_norm(
  84. nn.Conv1d(
  85. channels,
  86. channels,
  87. kernel_size,
  88. 1,
  89. dilation=dilation[0],
  90. padding=get_padding(kernel_size, dilation[0]),
  91. )
  92. ),
  93. weight_norm(
  94. nn.Conv1d(
  95. channels,
  96. channels,
  97. kernel_size,
  98. 1,
  99. dilation=dilation[1],
  100. padding=get_padding(kernel_size, dilation[1]),
  101. )
  102. ),
  103. weight_norm(
  104. nn.Conv1d(
  105. channels,
  106. channels,
  107. kernel_size,
  108. 1,
  109. dilation=dilation[2],
  110. padding=get_padding(kernel_size, dilation[2]),
  111. )
  112. ),
  113. ]
  114. )
  115. self.convs1.apply(init_weights)
  116. self.convs2 = nn.ModuleList(
  117. [
  118. weight_norm(
  119. nn.Conv1d(
  120. channels,
  121. channels,
  122. kernel_size,
  123. 1,
  124. dilation=1,
  125. padding=get_padding(kernel_size, 1),
  126. )
  127. ),
  128. weight_norm(
  129. nn.Conv1d(
  130. channels,
  131. channels,
  132. kernel_size,
  133. 1,
  134. dilation=1,
  135. padding=get_padding(kernel_size, 1),
  136. )
  137. ),
  138. weight_norm(
  139. nn.Conv1d(
  140. channels,
  141. channels,
  142. kernel_size,
  143. 1,
  144. dilation=1,
  145. padding=get_padding(kernel_size, 1),
  146. )
  147. ),
  148. ]
  149. )
  150. self.convs2.apply(init_weights)
  151. def forward(self, x, x_mask=None):
  152. for c1, c2 in zip(self.convs1, self.convs2):
  153. xt = F.leaky_relu(x, LRELU_SLOPE)
  154. if x_mask is not None:
  155. xt = xt * x_mask
  156. xt = c1(xt)
  157. xt = F.leaky_relu(xt, LRELU_SLOPE)
  158. if x_mask is not None:
  159. xt = xt * x_mask
  160. xt = c2(xt)
  161. x = xt + x
  162. if x_mask is not None:
  163. x = x * x_mask
  164. return x
  165. def remove_weight_norm(self):
  166. for l in self.convs1:
  167. remove_weight_norm(l)
  168. for l in self.convs2:
  169. remove_weight_norm(l)
  170. class ResBlock2(nn.Module):
  171. def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
  172. super(ResBlock2, self).__init__()
  173. self.convs = nn.ModuleList(
  174. [
  175. weight_norm(
  176. nn.Conv1d(
  177. channels,
  178. channels,
  179. kernel_size,
  180. 1,
  181. dilation=dilation[0],
  182. padding=get_padding(kernel_size, dilation[0]),
  183. )
  184. ),
  185. weight_norm(
  186. nn.Conv1d(
  187. channels,
  188. channels,
  189. kernel_size,
  190. 1,
  191. dilation=dilation[1],
  192. padding=get_padding(kernel_size, dilation[1]),
  193. )
  194. ),
  195. ]
  196. )
  197. self.convs.apply(init_weights)
  198. def forward(self, x, x_mask=None):
  199. for c in self.convs:
  200. xt = F.leaky_relu(x, LRELU_SLOPE)
  201. if x_mask is not None:
  202. xt = xt * x_mask
  203. xt = c(xt)
  204. x = xt + x
  205. if x_mask is not None:
  206. x = x * x_mask
  207. return x
  208. def remove_weight_norm(self):
  209. for l in self.convs:
  210. remove_weight_norm(l)
  211. if __name__ == "__main__":
  212. import librosa
  213. import soundfile as sf
  214. from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
  215. gen = Generator(
  216. 80,
  217. "1",
  218. [3, 7, 11],
  219. [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
  220. [8, 8, 2, 2],
  221. 512,
  222. [16, 16, 4, 4],
  223. ckpt_path="checkpoints/hifigan-v1-universal-22050/g_02500000",
  224. )
  225. spec = LogMelSpectrogram(
  226. sample_rate=22050,
  227. n_fft=1024,
  228. win_length=1024,
  229. hop_length=256,
  230. n_mels=80,
  231. f_min=0.0,
  232. f_max=8000.0,
  233. )
  234. audio = librosa.load("data/StarRail/Chinese/符玄/archive_fuxuan_9.wav", sr=22050)[0]
  235. audio = torch.from_numpy(audio).unsqueeze(0)
  236. spec = spec(audio)
  237. print(spec.shape)
  238. audio = gen(spec)
  239. print(audio.shape)
  240. sf.write("test.wav", audio.detach().squeeze().numpy(), 22050)