# firefly.py
# An inference-only version of the FireflyGAN model
from functools import partial
from math import prod
from typing import Callable

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Conv1d
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from torch.utils.checkpoint import checkpoint
  13. def init_weights(m, mean=0.0, std=0.01):
  14. classname = m.__class__.__name__
  15. if classname.find("Conv") != -1:
  16. m.weight.data.normal_(mean, std)
  17. def get_padding(kernel_size, dilation=1):
  18. return (kernel_size * dilation - dilation) // 2
  19. class ResBlock1(torch.nn.Module):
  20. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
  21. super().__init__()
  22. self.convs1 = nn.ModuleList(
  23. [
  24. weight_norm(
  25. Conv1d(
  26. channels,
  27. channels,
  28. kernel_size,
  29. 1,
  30. dilation=dilation[0],
  31. padding=get_padding(kernel_size, dilation[0]),
  32. )
  33. ),
  34. weight_norm(
  35. Conv1d(
  36. channels,
  37. channels,
  38. kernel_size,
  39. 1,
  40. dilation=dilation[1],
  41. padding=get_padding(kernel_size, dilation[1]),
  42. )
  43. ),
  44. weight_norm(
  45. Conv1d(
  46. channels,
  47. channels,
  48. kernel_size,
  49. 1,
  50. dilation=dilation[2],
  51. padding=get_padding(kernel_size, dilation[2]),
  52. )
  53. ),
  54. ]
  55. )
  56. self.convs1.apply(init_weights)
  57. self.convs2 = nn.ModuleList(
  58. [
  59. weight_norm(
  60. Conv1d(
  61. channels,
  62. channels,
  63. kernel_size,
  64. 1,
  65. dilation=1,
  66. padding=get_padding(kernel_size, 1),
  67. )
  68. ),
  69. weight_norm(
  70. Conv1d(
  71. channels,
  72. channels,
  73. kernel_size,
  74. 1,
  75. dilation=1,
  76. padding=get_padding(kernel_size, 1),
  77. )
  78. ),
  79. weight_norm(
  80. Conv1d(
  81. channels,
  82. channels,
  83. kernel_size,
  84. 1,
  85. dilation=1,
  86. padding=get_padding(kernel_size, 1),
  87. )
  88. ),
  89. ]
  90. )
  91. self.convs2.apply(init_weights)
  92. def forward(self, x):
  93. for c1, c2 in zip(self.convs1, self.convs2):
  94. xt = F.silu(x)
  95. xt = c1(xt)
  96. xt = F.silu(xt)
  97. xt = c2(xt)
  98. x = xt + x
  99. return x
  100. def remove_parametrizations(self):
  101. for conv in self.convs1:
  102. remove_parametrizations(conv, tensor_name="weight")
  103. for conv in self.convs2:
  104. remove_parametrizations(conv, tensor_name="weight")
  105. class ParralelBlock(nn.Module):
  106. def __init__(
  107. self,
  108. channels: int,
  109. kernel_sizes: tuple[int] = (3, 7, 11),
  110. dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  111. ):
  112. super().__init__()
  113. assert len(kernel_sizes) == len(dilation_sizes)
  114. self.blocks = nn.ModuleList()
  115. for k, d in zip(kernel_sizes, dilation_sizes):
  116. self.blocks.append(ResBlock1(channels, k, d))
  117. def forward(self, x):
  118. return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
  119. def remove_parametrizations(self):
  120. for block in self.blocks:
  121. block.remove_parametrizations()
  122. class HiFiGANGenerator(nn.Module):
  123. def __init__(
  124. self,
  125. *,
  126. hop_length: int = 512,
  127. upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
  128. upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
  129. resblock_kernel_sizes: tuple[int] = (3, 7, 11),
  130. resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  131. num_mels: int = 128,
  132. upsample_initial_channel: int = 512,
  133. use_template: bool = True,
  134. pre_conv_kernel_size: int = 7,
  135. post_conv_kernel_size: int = 7,
  136. post_activation: Callable = partial(nn.SiLU, inplace=True),
  137. ):
  138. super().__init__()
  139. assert (
  140. prod(upsample_rates) == hop_length
  141. ), f"hop_length must be {prod(upsample_rates)}"
  142. self.conv_pre = weight_norm(
  143. nn.Conv1d(
  144. num_mels,
  145. upsample_initial_channel,
  146. pre_conv_kernel_size,
  147. 1,
  148. padding=get_padding(pre_conv_kernel_size),
  149. )
  150. )
  151. self.num_upsamples = len(upsample_rates)
  152. self.num_kernels = len(resblock_kernel_sizes)
  153. self.noise_convs = nn.ModuleList()
  154. self.use_template = use_template
  155. self.ups = nn.ModuleList()
  156. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  157. c_cur = upsample_initial_channel // (2 ** (i + 1))
  158. self.ups.append(
  159. weight_norm(
  160. nn.ConvTranspose1d(
  161. upsample_initial_channel // (2**i),
  162. upsample_initial_channel // (2 ** (i + 1)),
  163. k,
  164. u,
  165. padding=(k - u) // 2,
  166. )
  167. )
  168. )
  169. if not use_template:
  170. continue
  171. if i + 1 < len(upsample_rates):
  172. stride_f0 = np.prod(upsample_rates[i + 1 :])
  173. self.noise_convs.append(
  174. Conv1d(
  175. 1,
  176. c_cur,
  177. kernel_size=stride_f0 * 2,
  178. stride=stride_f0,
  179. padding=stride_f0 // 2,
  180. )
  181. )
  182. else:
  183. self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
  184. self.resblocks = nn.ModuleList()
  185. for i in range(len(self.ups)):
  186. ch = upsample_initial_channel // (2 ** (i + 1))
  187. self.resblocks.append(
  188. ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
  189. )
  190. self.activation_post = post_activation()
  191. self.conv_post = weight_norm(
  192. nn.Conv1d(
  193. ch,
  194. 1,
  195. post_conv_kernel_size,
  196. 1,
  197. padding=get_padding(post_conv_kernel_size),
  198. )
  199. )
  200. self.ups.apply(init_weights)
  201. self.conv_post.apply(init_weights)
  202. def forward(self, x, template=None):
  203. x = self.conv_pre(x)
  204. for i in range(self.num_upsamples):
  205. x = F.silu(x, inplace=True)
  206. x = self.ups[i](x)
  207. if self.use_template:
  208. x = x + self.noise_convs[i](template)
  209. if self.training and self.checkpointing:
  210. x = checkpoint(
  211. self.resblocks[i],
  212. x,
  213. use_reentrant=False,
  214. )
  215. else:
  216. x = self.resblocks[i](x)
  217. x = self.activation_post(x)
  218. x = self.conv_post(x)
  219. x = torch.tanh(x)
  220. return x
  221. def remove_parametrizations(self):
  222. for up in self.ups:
  223. remove_parametrizations(up, tensor_name="weight")
  224. for block in self.resblocks:
  225. block.remove_parametrizations()
  226. remove_parametrizations(self.conv_pre, tensor_name="weight")
  227. remove_parametrizations(self.conv_post, tensor_name="weight")
  228. # DropPath copied from timm library
  229. def drop_path(
  230. x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
  231. ):
  232. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  233. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  234. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  235. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  236. changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
  237. 'survival rate' as the argument.
  238. """ # noqa: E501
  239. if drop_prob == 0.0 or not training:
  240. return x
  241. keep_prob = 1 - drop_prob
  242. shape = (x.shape[0],) + (1,) * (
  243. x.ndim - 1
  244. ) # work with diff dim tensors, not just 2D ConvNets
  245. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  246. if keep_prob > 0.0 and scale_by_keep:
  247. random_tensor.div_(keep_prob)
  248. return x * random_tensor
  249. class DropPath(nn.Module):
  250. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
  251. def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
  252. super(DropPath, self).__init__()
  253. self.drop_prob = drop_prob
  254. self.scale_by_keep = scale_by_keep
  255. def forward(self, x):
  256. return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
  257. def extra_repr(self):
  258. return f"drop_prob={round(self.drop_prob,3):0.3f}"
  259. class LayerNorm(nn.Module):
  260. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  261. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
  262. shape (batch_size, height, width, channels) while channels_first corresponds to inputs
  263. with shape (batch_size, channels, height, width).
  264. """ # noqa: E501
  265. def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
  266. super().__init__()
  267. self.weight = nn.Parameter(torch.ones(normalized_shape))
  268. self.bias = nn.Parameter(torch.zeros(normalized_shape))
  269. self.eps = eps
  270. self.data_format = data_format
  271. if self.data_format not in ["channels_last", "channels_first"]:
  272. raise NotImplementedError
  273. self.normalized_shape = (normalized_shape,)
  274. def forward(self, x):
  275. if self.data_format == "channels_last":
  276. return F.layer_norm(
  277. x, self.normalized_shape, self.weight, self.bias, self.eps
  278. )
  279. elif self.data_format == "channels_first":
  280. u = x.mean(1, keepdim=True)
  281. s = (x - u).pow(2).mean(1, keepdim=True)
  282. x = (x - u) / torch.sqrt(s + self.eps)
  283. x = self.weight[:, None] * x + self.bias[:, None]
  284. return x
  285. # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
  286. class ConvNeXtBlock(nn.Module):
  287. r"""ConvNeXt Block. There are two equivalent implementations:
  288. (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
  289. (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
  290. We use (2) as we find it slightly faster in PyTorch
  291. Args:
  292. dim (int): Number of input channels.
  293. drop_path (float): Stochastic depth rate. Default: 0.0
  294. layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
  295. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
  296. kernel_size (int): Kernel size for depthwise conv. Default: 7.
  297. dilation (int): Dilation for depthwise conv. Default: 1.
  298. """ # noqa: E501
  299. def __init__(
  300. self,
  301. dim: int,
  302. drop_path: float = 0.0,
  303. layer_scale_init_value: float = 1e-6,
  304. mlp_ratio: float = 4.0,
  305. kernel_size: int = 7,
  306. dilation: int = 1,
  307. ):
  308. super().__init__()
  309. self.dwconv = nn.Conv1d(
  310. dim,
  311. dim,
  312. kernel_size=kernel_size,
  313. padding=int(dilation * (kernel_size - 1) / 2),
  314. groups=dim,
  315. ) # depthwise conv
  316. self.norm = LayerNorm(dim, eps=1e-6)
  317. self.pwconv1 = nn.Linear(
  318. dim, int(mlp_ratio * dim)
  319. ) # pointwise/1x1 convs, implemented with linear layers
  320. self.act = nn.GELU()
  321. self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
  322. self.gamma = (
  323. nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  324. if layer_scale_init_value > 0
  325. else None
  326. )
  327. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  328. def forward(self, x, apply_residual: bool = True):
  329. input = x
  330. x = self.dwconv(x)
  331. x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
  332. x = self.norm(x)
  333. x = self.pwconv1(x)
  334. x = self.act(x)
  335. x = self.pwconv2(x)
  336. if self.gamma is not None:
  337. x = self.gamma * x
  338. x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
  339. x = self.drop_path(x)
  340. if apply_residual:
  341. x = input + x
  342. return x
  343. class ConvNeXtEncoder(nn.Module):
  344. def __init__(
  345. self,
  346. input_channels: int = 3,
  347. depths: list[int] = [3, 3, 9, 3],
  348. dims: list[int] = [96, 192, 384, 768],
  349. drop_path_rate: float = 0.0,
  350. layer_scale_init_value: float = 1e-6,
  351. kernel_size: int = 7,
  352. ):
  353. super().__init__()
  354. assert len(depths) == len(dims)
  355. self.downsample_layers = nn.ModuleList()
  356. stem = nn.Sequential(
  357. nn.Conv1d(
  358. input_channels,
  359. dims[0],
  360. kernel_size=kernel_size,
  361. padding=kernel_size // 2,
  362. padding_mode="zeros",
  363. ),
  364. LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
  365. )
  366. self.downsample_layers.append(stem)
  367. for i in range(len(depths) - 1):
  368. mid_layer = nn.Sequential(
  369. LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
  370. nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
  371. )
  372. self.downsample_layers.append(mid_layer)
  373. self.stages = nn.ModuleList()
  374. dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
  375. cur = 0
  376. for i in range(len(depths)):
  377. stage = nn.Sequential(
  378. *[
  379. ConvNeXtBlock(
  380. dim=dims[i],
  381. drop_path=dp_rates[cur + j],
  382. layer_scale_init_value=layer_scale_init_value,
  383. kernel_size=kernel_size,
  384. )
  385. for j in range(depths[i])
  386. ]
  387. )
  388. self.stages.append(stage)
  389. cur += depths[i]
  390. self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
  391. self.apply(self._init_weights)
  392. def _init_weights(self, m):
  393. if isinstance(m, (nn.Conv1d, nn.Linear)):
  394. nn.init.trunc_normal_(m.weight, std=0.02)
  395. nn.init.constant_(m.bias, 0)
  396. def forward(
  397. self,
  398. x: torch.Tensor,
  399. ) -> torch.Tensor:
  400. for i in range(len(self.downsample_layers)):
  401. x = self.downsample_layers[i](x)
  402. x = self.stages[i](x)
  403. return self.norm(x)
class FireflyBase(nn.Module):
    """ConvNeXt backbone + HiFiGAN generator head: inference-only vocoder."""

    def __init__(self, ckpt_path: str = None, pretrained: bool = True):
        super().__init__()
        # Feature encoder: 128 input channels -> 512 feature channels.
        self.backbone = ConvNeXtEncoder(
            input_channels=128,
            depths=[3, 3, 9, 3],
            dims=[128, 256, 384, 512],
            drop_path_rate=0.2,
            kernel_size=7,
        )
        # Waveform generator: 512 feature channels -> mono audio,
        # 512x temporal upsampling (8*8*2*2*2).
        self.head = HiFiGANGenerator(
            hop_length=512,
            upsample_rates=[8, 8, 2, 2, 2],
            upsample_kernel_sizes=[16, 16, 4, 4, 4],
            resblock_kernel_sizes=[3, 7, 11],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            num_mels=512,
            upsample_initial_channel=512,
            use_template=False,
            pre_conv_kernel_size=13,
            post_conv_kernel_size=13,
        )
        if ckpt_path is not None:
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # trusted checkpoint files.  Also, unlike the URL branch below,
            # this branch does not unwrap a nested "state_dict" or strip a
            # "generator." prefix — confirm local ckpts are plain state dicts.
            self.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
        elif pretrained:
            state_dict = torch.hub.load_state_dict_from_url(
                "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
                map_location="cpu",
            )
            # Trainer-style checkpoints nest the weights under "state_dict".
            if "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]
            # Keep only generator weights, stripping their "generator." prefix
            # so keys match this module's parameter names.
            if any("generator." in k for k in state_dict):
                state_dict = {
                    k.replace("generator.", ""): v
                    for k, v in state_dict.items()
                    if "generator." in k
                }
            self.load_state_dict(state_dict, strict=True)
        # Fuse weight-norm parametrizations; from here on the head is
        # inference-only (its parametrized weights no longer exist).
        self.head.remove_parametrizations()

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode features with the backbone, then decode to audio.

        A 2-D head output (batch, time) gets a singleton channel dimension
        inserted so the result is always 3-D.
        """
        x = self.backbone(x)
        x = self.head(x)
        if x.ndim == 2:
            x = x[:, None, :]
        return x
  450. if __name__ == "__main__":
  451. model = FireflyBase()
  452. model.eval()
  453. x = torch.randn(1, 128, 128)
  454. with torch.no_grad():
  455. y = model(x)
  456. print(y.shape)