# firefly.py
  1. # A inference only version of the FireflyGAN model
  2. import math
  3. from functools import partial
  4. from math import prod
  5. from typing import Callable
  6. import numpy as np
  7. import torch
  8. import torch.nn.functional as F
  9. from torch import nn
  10. from torch.nn import Conv1d
  11. from torch.nn.utils.parametrizations import weight_norm
  12. from torch.nn.utils.parametrize import remove_parametrizations
  13. from torch.utils.checkpoint import checkpoint
  14. from fish_speech.models.vqgan.utils import sequence_mask
  15. def init_weights(m, mean=0.0, std=0.01):
  16. classname = m.__class__.__name__
  17. if classname.find("Conv") != -1:
  18. m.weight.data.normal_(mean, std)
  19. def get_padding(kernel_size, dilation=1):
  20. return (kernel_size * dilation - dilation) // 2
  21. class ResBlock1(torch.nn.Module):
  22. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
  23. super().__init__()
  24. self.convs1 = nn.ModuleList(
  25. [
  26. weight_norm(
  27. Conv1d(
  28. channels,
  29. channels,
  30. kernel_size,
  31. 1,
  32. dilation=dilation[0],
  33. padding=get_padding(kernel_size, dilation[0]),
  34. )
  35. ),
  36. weight_norm(
  37. Conv1d(
  38. channels,
  39. channels,
  40. kernel_size,
  41. 1,
  42. dilation=dilation[1],
  43. padding=get_padding(kernel_size, dilation[1]),
  44. )
  45. ),
  46. weight_norm(
  47. Conv1d(
  48. channels,
  49. channels,
  50. kernel_size,
  51. 1,
  52. dilation=dilation[2],
  53. padding=get_padding(kernel_size, dilation[2]),
  54. )
  55. ),
  56. ]
  57. )
  58. self.convs1.apply(init_weights)
  59. self.convs2 = nn.ModuleList(
  60. [
  61. weight_norm(
  62. Conv1d(
  63. channels,
  64. channels,
  65. kernel_size,
  66. 1,
  67. dilation=1,
  68. padding=get_padding(kernel_size, 1),
  69. )
  70. ),
  71. weight_norm(
  72. Conv1d(
  73. channels,
  74. channels,
  75. kernel_size,
  76. 1,
  77. dilation=1,
  78. padding=get_padding(kernel_size, 1),
  79. )
  80. ),
  81. weight_norm(
  82. Conv1d(
  83. channels,
  84. channels,
  85. kernel_size,
  86. 1,
  87. dilation=1,
  88. padding=get_padding(kernel_size, 1),
  89. )
  90. ),
  91. ]
  92. )
  93. self.convs2.apply(init_weights)
  94. def forward(self, x):
  95. for c1, c2 in zip(self.convs1, self.convs2):
  96. xt = F.silu(x)
  97. xt = c1(xt)
  98. xt = F.silu(xt)
  99. xt = c2(xt)
  100. x = xt + x
  101. return x
  102. def remove_parametrizations(self):
  103. for conv in self.convs1:
  104. remove_parametrizations(conv, tensor_name="weight")
  105. for conv in self.convs2:
  106. remove_parametrizations(conv, tensor_name="weight")
  107. class ParralelBlock(nn.Module):
  108. def __init__(
  109. self,
  110. channels: int,
  111. kernel_sizes: tuple[int] = (3, 7, 11),
  112. dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  113. ):
  114. super().__init__()
  115. assert len(kernel_sizes) == len(dilation_sizes)
  116. self.blocks = nn.ModuleList()
  117. for k, d in zip(kernel_sizes, dilation_sizes):
  118. self.blocks.append(ResBlock1(channels, k, d))
  119. def forward(self, x):
  120. return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
  121. def remove_parametrizations(self):
  122. for block in self.blocks:
  123. block.remove_parametrizations()
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN style generator: upsamples frame-level features to a waveform.

    Pipeline: pre-conv -> N x [silu -> transposed-conv upsample (+ optional
    template branch) -> ParralelBlock] -> post-activation -> post-conv -> tanh.
    The total upsampling factor is prod(upsample_rates) == hop_length.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,  # samples produced per input frame
        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,  # input feature channels
        upsample_initial_channel: int = 512,
        use_template: bool = True,  # if True, forward() requires `template`
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()
        # The stages can only ever upsample by prod(upsample_rates).
        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"
        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )
        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)
        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()
        # One transposed conv per stage; channel count halves at each stage.
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
            if not use_template:
                continue
            # Template branch: strided conv downsamples the template to the
            # temporal resolution of stage i's output; the final stage is
            # already at full resolution so a 1x1 conv suffices.
            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1 :])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
        # One multi-receptive-field residual block per upsampling stage.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.resblocks.append(
                ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
            )
        self.activation_post = post_activation()
        # `ch` carries over from the loop: channel count of the last stage.
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        """Map features (B, num_mels, T) to a waveform in [-1, 1].

        `template` is presumably a (B, 1, T * hop_length) excitation signal
        when use_template is set — TODO confirm against callers.
        """
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)
            if self.use_template:
                x = x + self.noise_convs[i](template)
            # Gradient-checkpoint the resblocks during training to save memory.
            if self.training:
                x = checkpoint(
                    self.resblocks[i],
                    x,
                    use_reentrant=False,
                )
            else:
                x = self.resblocks[i](x)
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_parametrizations(self):
        """Fold weight_norm into plain weights on every conv (inference/export)."""
        for up in self.ups:
            remove_parametrizations(up, tensor_name="weight")
        for block in self.resblocks:
            block.remove_parametrizations()
        remove_parametrizations(self.conv_pre, tensor_name="weight")
        remove_parametrizations(self.conv_post, tensor_name="weight")
  230. # DropPath copied from timm library
  231. def drop_path(
  232. x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
  233. ):
  234. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  235. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  236. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  237. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  238. changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
  239. 'survival rate' as the argument.
  240. """ # noqa: E501
  241. if drop_prob == 0.0 or not training:
  242. return x
  243. keep_prob = 1 - drop_prob
  244. shape = (x.shape[0],) + (1,) * (
  245. x.ndim - 1
  246. ) # work with diff dim tensors, not just 2D ConvNets
  247. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  248. if keep_prob > 0.0 and scale_by_keep:
  249. random_tensor.div_(keep_prob)
  250. return x * random_tensor
  251. class DropPath(nn.Module):
  252. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
  253. def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
  254. super(DropPath, self).__init__()
  255. self.drop_prob = drop_prob
  256. self.scale_by_keep = scale_by_keep
  257. def forward(self, x):
  258. return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
  259. def extra_repr(self):
  260. return f"drop_prob={round(self.drop_prob,3):0.3f}"
  261. class LayerNorm(nn.Module):
  262. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  263. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
  264. shape (batch_size, height, width, channels) while channels_first corresponds to inputs
  265. with shape (batch_size, channels, height, width).
  266. """ # noqa: E501
  267. def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
  268. super().__init__()
  269. self.weight = nn.Parameter(torch.ones(normalized_shape))
  270. self.bias = nn.Parameter(torch.zeros(normalized_shape))
  271. self.eps = eps
  272. self.data_format = data_format
  273. if self.data_format not in ["channels_last", "channels_first"]:
  274. raise NotImplementedError
  275. self.normalized_shape = (normalized_shape,)
  276. def forward(self, x):
  277. if self.data_format == "channels_last":
  278. return F.layer_norm(
  279. x, self.normalized_shape, self.weight, self.bias, self.eps
  280. )
  281. elif self.data_format == "channels_first":
  282. u = x.mean(1, keepdim=True)
  283. s = (x - u).pow(2).mean(1, keepdim=True)
  284. x = (x - u) / torch.sqrt(s + self.eps)
  285. x = self.weight[:, None] * x + self.bias[:, None]
  286. return x
  287. # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
  288. class ConvNeXtBlock(nn.Module):
  289. r"""ConvNeXt Block. There are two equivalent implementations:
  290. (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
  291. (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
  292. We use (2) as we find it slightly faster in PyTorch
  293. Args:
  294. dim (int): Number of input channels.
  295. drop_path (float): Stochastic depth rate. Default: 0.0
  296. layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
  297. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
  298. kernel_size (int): Kernel size for depthwise conv. Default: 7.
  299. dilation (int): Dilation for depthwise conv. Default: 1.
  300. """ # noqa: E501
  301. def __init__(
  302. self,
  303. dim: int,
  304. drop_path: float = 0.0,
  305. layer_scale_init_value: float = 1e-6,
  306. mlp_ratio: float = 4.0,
  307. kernel_size: int = 7,
  308. dilation: int = 1,
  309. ):
  310. super().__init__()
  311. self.dwconv = nn.Conv1d(
  312. dim,
  313. dim,
  314. kernel_size=kernel_size,
  315. padding=int(dilation * (kernel_size - 1) / 2),
  316. groups=dim,
  317. ) # depthwise conv
  318. self.norm = LayerNorm(dim, eps=1e-6)
  319. self.pwconv1 = nn.Linear(
  320. dim, int(mlp_ratio * dim)
  321. ) # pointwise/1x1 convs, implemented with linear layers
  322. self.act = nn.GELU()
  323. self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
  324. self.gamma = (
  325. nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  326. if layer_scale_init_value > 0
  327. else None
  328. )
  329. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  330. def forward(self, x, apply_residual: bool = True):
  331. input = x
  332. x = self.dwconv(x)
  333. x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
  334. x = self.norm(x)
  335. x = self.pwconv1(x)
  336. x = self.act(x)
  337. x = self.pwconv2(x)
  338. if self.gamma is not None:
  339. x = self.gamma * x
  340. x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
  341. x = self.drop_path(x)
  342. if apply_residual:
  343. x = input + x
  344. return x
  345. class ConvNeXtEncoder(nn.Module):
  346. def __init__(
  347. self,
  348. input_channels: int = 3,
  349. depths: list[int] = [3, 3, 9, 3],
  350. dims: list[int] = [96, 192, 384, 768],
  351. drop_path_rate: float = 0.0,
  352. layer_scale_init_value: float = 1e-6,
  353. kernel_size: int = 7,
  354. ):
  355. super().__init__()
  356. assert len(depths) == len(dims)
  357. self.downsample_layers = nn.ModuleList()
  358. stem = nn.Sequential(
  359. nn.Conv1d(
  360. input_channels,
  361. dims[0],
  362. kernel_size=kernel_size,
  363. padding=kernel_size // 2,
  364. padding_mode="zeros",
  365. ),
  366. LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
  367. )
  368. self.downsample_layers.append(stem)
  369. for i in range(len(depths) - 1):
  370. mid_layer = nn.Sequential(
  371. LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
  372. nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
  373. )
  374. self.downsample_layers.append(mid_layer)
  375. self.stages = nn.ModuleList()
  376. dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
  377. cur = 0
  378. for i in range(len(depths)):
  379. stage = nn.Sequential(
  380. *[
  381. ConvNeXtBlock(
  382. dim=dims[i],
  383. drop_path=dp_rates[cur + j],
  384. layer_scale_init_value=layer_scale_init_value,
  385. kernel_size=kernel_size,
  386. )
  387. for j in range(depths[i])
  388. ]
  389. )
  390. self.stages.append(stage)
  391. cur += depths[i]
  392. self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
  393. self.apply(self._init_weights)
  394. def _init_weights(self, m):
  395. if isinstance(m, (nn.Conv1d, nn.Linear)):
  396. nn.init.trunc_normal_(m.weight, std=0.02)
  397. nn.init.constant_(m.bias, 0)
  398. def forward(
  399. self,
  400. x: torch.Tensor,
  401. ) -> torch.Tensor:
  402. for i in range(len(self.downsample_layers)):
  403. x = self.downsample_layers[i](x)
  404. x = self.stages[i](x)
  405. return self.norm(x)
class FireflyArchitecture(nn.Module):
    """Full codec pipeline: spec_transform -> backbone -> quantizer -> head.

    All four sub-modules are injected by the caller; this class only adds
    masking and length bookkeeping around them.
    """

    def __init__(
        self,
        backbone: nn.Module,
        head: nn.Module,
        quantizer: nn.Module,
        spec_transform: nn.Module,
    ):
        super().__init__()
        self.backbone = backbone  # feature encoder
        self.head = head  # vocoder / decoder
        self.quantizer = quantizer  # optional VQ module (may be None)
        self.spec_transform = spec_transform  # optional audio->spectrogram (may be None)

    def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
        """Run the full pipeline.

        Returns (audio, vq_result) when a quantizer is present, else audio.
        `mask` (if given) is multiplied onto the frame-level features.
        """
        if self.spec_transform is not None:
            x = self.spec_transform(x)
        x = self.backbone(x)
        if mask is not None:
            x = x * mask
        if self.quantizer is not None:
            vq_result = self.quantizer(x)
            x = vq_result.z
            if mask is not None:
                x = x * mask
        x = self.head(x, template=template)
        # Guarantee a channel axis: (B, T) -> (B, 1, T).
        if x.ndim == 2:
            x = x[:, None, :]
        if self.quantizer is not None:
            return x, vq_result
        return x

    def encode(self, audios, audio_lengths):
        """Encode raw audio into discrete codes plus per-item code lengths.

        Returns (quantizer.encode(...) result, feature_lengths).
        """
        audios = audios.float()
        mels = self.spec_transform(audios)
        # Frame-level lengths; assumes spec_transform frames at hop_length
        # with no extra padding frames — TODO confirm.
        mel_lengths = audio_lengths // self.spec_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv
        # Encode
        encoded_features = self.backbone(mels) * mel_masks_float_conv
        # The quantizer further downsamples time by prod(downsample_factor).
        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
        return self.quantizer.encode(encoded_features), feature_lengths

    def decode(self, indices, feature_lengths) -> torch.Tensor:
        """Decode quantizer indices back to audio, zeroing padded positions.

        `indices` is indexed as (..., ..., T_codes) on dim 2 — presumably
        (B, codebooks, T_codes); verify against the quantizer.
        """
        factor = math.prod(self.quantizer.downsample_factor)
        # Masks at mel-frame resolution and at raw-sample resolution.
        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
        mel_masks_float_conv = mel_masks[:, None, :].float()
        audio_masks = sequence_mask(
            feature_lengths * factor * self.spec_transform.hop_length,
            indices.shape[2] * factor * self.spec_transform.hop_length,
        )
        audio_masks_float_conv = audio_masks[:, None, :].float()
        z = self.quantizer.decode(indices) * mel_masks_float_conv
        x = self.head(z) * audio_masks_float_conv
        return x

    def remove_parametrizations(self):
        """Strip weight_norm from sub-modules that support it (inference)."""
        if hasattr(self.backbone, "remove_parametrizations"):
            self.backbone.remove_parametrizations()
        if hasattr(self.head, "remove_parametrizations"):
            self.head.remove_parametrizations()

    @property
    def device(self):
        # Device of the first parameter; assumes all parameters share it.
        return next(self.parameters()).device
  467. class FireflyBase(nn.Module):
  468. def __init__(self, ckpt_path: str = None, pretrained: bool = True):
  469. super().__init__()
  470. self.backbone = ConvNeXtEncoder(
  471. input_channels=128,
  472. depths=[3, 3, 9, 3],
  473. dims=[128, 256, 384, 512],
  474. drop_path_rate=0.2,
  475. kernel_size=7,
  476. )
  477. self.head = HiFiGANGenerator(
  478. hop_length=512,
  479. upsample_rates=[8, 8, 2, 2, 2],
  480. upsample_kernel_sizes=[16, 16, 4, 4, 4],
  481. resblock_kernel_sizes=[3, 7, 11],
  482. resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
  483. num_mels=512,
  484. upsample_initial_channel=512,
  485. use_template=False,
  486. pre_conv_kernel_size=13,
  487. post_conv_kernel_size=13,
  488. )
  489. if ckpt_path is not None:
  490. state_dict = torch.load(ckpt_path, map_location="cpu")
  491. elif pretrained:
  492. state_dict = torch.hub.load_state_dict_from_url(
  493. "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
  494. map_location="cpu",
  495. model_dir="checkpoints",
  496. )
  497. if "state_dict" in state_dict:
  498. state_dict = state_dict["state_dict"]
  499. if any("generator." in k for k in state_dict):
  500. state_dict = {
  501. k.replace("generator.", ""): v
  502. for k, v in state_dict.items()
  503. if "generator." in k
  504. }
  505. self.load_state_dict(state_dict, strict=True)
  506. self.head.remove_parametrizations()
  507. @torch.no_grad()
  508. def forward(self, x: torch.Tensor) -> torch.Tensor:
  509. x = self.backbone(x)
  510. x = self.head(x)
  511. if x.ndim == 2:
  512. x = x[:, None, :]
  513. return x
if __name__ == "__main__":
    # Smoke test. NOTE: instantiating FireflyBase with defaults downloads the
    # pretrained generator checkpoint from GitHub into ./checkpoints.
    model = FireflyBase()
    model.eval()
    # (batch=1, feature channels=128, frames=128) — matches the backbone's
    # input_channels=128.
    x = torch.randn(1, 128, 128)
    with torch.no_grad():
        y = model(x)
    print(y.shape)