firefly.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. import math
  2. from functools import partial
  3. from math import prod
  4. from typing import Callable
  5. import torch
  6. import torch.nn.functional as F
  7. from torch import nn
  8. from torch.nn.utils.parametrizations import weight_norm
  9. from torch.nn.utils.parametrize import remove_parametrizations
  10. from torch.utils.checkpoint import checkpoint
  11. def sequence_mask(length, max_length=None):
  12. if max_length is None:
  13. max_length = length.max()
  14. x = torch.arange(max_length, dtype=length.dtype, device=length.device)
  15. return x.unsqueeze(0) < length.unsqueeze(1)
def init_weights(m, mean=0.0, std=0.01):
    """Apply-style initializer: normal-init the weight of matching conv modules.

    NOTE(review): the match string "Conv1D" is case-sensitive, but
    ``nn.Conv1d``'s class name is "Conv1d" and the wrappers in this file are
    named "FishConvNet"/"FishTransConvNet" — so this branch appears to never
    fire and the function is likely a no-op. Confirm the intended target
    before changing it: making it match would alter weight initialization
    (and on weight-normed convs ``m.weight`` is a parametrized property, so
    mutating its ``.data`` may not affect the underlying parameters).
    """
    classname = m.__class__.__name__
    if classname.find("Conv1D") != -1:
        m.weight.data.normal_(mean, std)
  20. def get_padding(kernel_size, dilation=1):
  21. return (kernel_size * dilation - dilation) // 2
  22. def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
  23. """Remove padding from x, handling properly zero padding. Only for 1d!"""
  24. padding_left, padding_right = paddings
  25. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  26. assert (padding_left + padding_right) <= x.shape[-1]
  27. end = x.shape[-1] - padding_right
  28. return x[..., padding_left:end]
  29. def get_extra_padding_for_conv1d(
  30. x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
  31. ) -> int:
  32. """See `pad_for_conv1d`."""
  33. length = x.shape[-1]
  34. n_frames = (length - kernel_size + padding_total) / stride + 1
  35. # for tracer, math.ceil will make onnx graph become constant
  36. if isinstance(n_frames, torch.Tensor):
  37. ideal_length = (torch.ceil(n_frames).long() - 1) * stride + (
  38. kernel_size - padding_total
  39. )
  40. else:
  41. ideal_length = (math.ceil(n_frames) - 1) * stride + (
  42. kernel_size - padding_total
  43. )
  44. return ideal_length - length
  45. def pad1d(
  46. x: torch.Tensor,
  47. paddings: tuple[int, int],
  48. mode: str = "zeros",
  49. value: float = 0.0,
  50. ):
  51. """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
  52. If this is the case, we insert extra 0 padding to the right
  53. before the reflection happen.
  54. """
  55. length = x.shape[-1]
  56. padding_left, padding_right = paddings
  57. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  58. if mode == "reflect":
  59. max_pad = max(padding_left, padding_right)
  60. extra_pad = 0
  61. if length <= max_pad:
  62. extra_pad = max_pad - length + 1
  63. x = F.pad(x, (0, extra_pad))
  64. padded = F.pad(x, paddings, mode, value)
  65. end = padded.shape[-1] - extra_pad
  66. return padded[..., :end]
  67. else:
  68. return F.pad(x, paddings, mode, value)
  69. class FishConvNet(nn.Module):
  70. def __init__(
  71. self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
  72. ):
  73. super(FishConvNet, self).__init__()
  74. self.conv = nn.Conv1d(
  75. in_channels,
  76. out_channels,
  77. kernel_size,
  78. stride=stride,
  79. dilation=dilation,
  80. groups=groups,
  81. )
  82. self.stride = stride
  83. self.kernel_size = (kernel_size - 1) * dilation + 1
  84. self.dilation = dilation
  85. def forward(self, x):
  86. pad = self.kernel_size - self.stride
  87. extra_padding = get_extra_padding_for_conv1d(
  88. x, self.kernel_size, self.stride, pad
  89. )
  90. x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
  91. return self.conv(x).contiguous()
  92. def weight_norm(self, name="weight", dim=0):
  93. self.conv = weight_norm(self.conv, name=name, dim=dim)
  94. return self
  95. def remove_parametrizations(self, name="weight"):
  96. self.conv = remove_parametrizations(self.conv, name)
  97. return self
  98. class FishTransConvNet(nn.Module):
  99. def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
  100. super(FishTransConvNet, self).__init__()
  101. self.conv = nn.ConvTranspose1d(
  102. in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
  103. )
  104. self.stride = stride
  105. self.kernel_size = kernel_size
  106. def forward(self, x):
  107. x = self.conv(x)
  108. pad = self.kernel_size - self.stride
  109. padding_right = math.ceil(pad)
  110. padding_left = pad - padding_right
  111. x = unpad1d(x, (padding_left, padding_right))
  112. return x.contiguous()
  113. def weight_norm(self, name="weight", dim=0):
  114. self.conv = weight_norm(self.conv, name=name, dim=dim)
  115. return self
  116. def remove_parametrizations(self, name="weight"):
  117. self.conv = remove_parametrizations(self.conv, name)
  118. return self
  119. class ResBlock1(torch.nn.Module):
  120. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
  121. super().__init__()
  122. self.convs1 = nn.ModuleList(
  123. [
  124. FishConvNet(
  125. channels, channels, kernel_size, stride=1, dilation=dilation[0]
  126. ).weight_norm(),
  127. FishConvNet(
  128. channels, channels, kernel_size, stride=1, dilation=dilation[1]
  129. ).weight_norm(),
  130. FishConvNet(
  131. channels, channels, kernel_size, stride=1, dilation=dilation[2]
  132. ).weight_norm(),
  133. ]
  134. )
  135. self.convs1.apply(init_weights)
  136. self.convs2 = nn.ModuleList(
  137. [
  138. FishConvNet(
  139. channels, channels, kernel_size, stride=1, dilation=dilation[0]
  140. ).weight_norm(),
  141. FishConvNet(
  142. channels, channels, kernel_size, stride=1, dilation=dilation[1]
  143. ).weight_norm(),
  144. FishConvNet(
  145. channels, channels, kernel_size, stride=1, dilation=dilation[2]
  146. ).weight_norm(),
  147. ]
  148. )
  149. self.convs2.apply(init_weights)
  150. def forward(self, x):
  151. for c1, c2 in zip(self.convs1, self.convs2):
  152. xt = F.silu(x)
  153. xt = c1(xt)
  154. xt = F.silu(xt)
  155. xt = c2(xt)
  156. x = xt + x
  157. return x
  158. def remove_parametrizations(self):
  159. for conv in self.convs1:
  160. conv.remove_parametrizations()
  161. for conv in self.convs2:
  162. conv.remove_parametrizations()
  163. class ParallelBlock(nn.Module):
  164. def __init__(
  165. self,
  166. channels: int,
  167. kernel_sizes: tuple[int] = (3, 7, 11),
  168. dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  169. ):
  170. super().__init__()
  171. assert len(kernel_sizes) == len(dilation_sizes)
  172. self.blocks = nn.ModuleList()
  173. for k, d in zip(kernel_sizes, dilation_sizes):
  174. self.blocks.append(ResBlock1(channels, k, d))
  175. def forward(self, x):
  176. return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
  177. def remove_parametrizations(self):
  178. for block in self.blocks:
  179. block.remove_parametrizations()
  180. class HiFiGANGenerator(nn.Module):
  181. def __init__(
  182. self,
  183. *,
  184. hop_length: int = 512,
  185. upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
  186. upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
  187. resblock_kernel_sizes: tuple[int] = (3, 7, 11),
  188. resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  189. num_mels: int = 128,
  190. upsample_initial_channel: int = 512,
  191. pre_conv_kernel_size: int = 7,
  192. post_conv_kernel_size: int = 7,
  193. post_activation: Callable = partial(nn.SiLU, inplace=True),
  194. ):
  195. super().__init__()
  196. assert (
  197. prod(upsample_rates) == hop_length
  198. ), f"hop_length must be {prod(upsample_rates)}"
  199. self.conv_pre = FishConvNet(
  200. num_mels,
  201. upsample_initial_channel,
  202. pre_conv_kernel_size,
  203. stride=1,
  204. ).weight_norm()
  205. self.num_upsamples = len(upsample_rates)
  206. self.num_kernels = len(resblock_kernel_sizes)
  207. self.noise_convs = nn.ModuleList()
  208. self.ups = nn.ModuleList()
  209. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  210. self.ups.append(
  211. FishTransConvNet(
  212. upsample_initial_channel // (2**i),
  213. upsample_initial_channel // (2 ** (i + 1)),
  214. k,
  215. stride=u,
  216. ).weight_norm()
  217. )
  218. self.resblocks = nn.ModuleList()
  219. for i in range(len(self.ups)):
  220. ch = upsample_initial_channel // (2 ** (i + 1))
  221. self.resblocks.append(
  222. ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
  223. )
  224. self.activation_post = post_activation()
  225. self.conv_post = FishConvNet(
  226. ch, 1, post_conv_kernel_size, stride=1
  227. ).weight_norm()
  228. self.ups.apply(init_weights)
  229. self.conv_post.apply(init_weights)
  230. def forward(self, x):
  231. x = self.conv_pre(x)
  232. for i in range(self.num_upsamples):
  233. x = F.silu(x, inplace=True)
  234. x = self.ups[i](x)
  235. if self.training and self.checkpointing:
  236. x = checkpoint(
  237. self.resblocks[i],
  238. x,
  239. use_reentrant=False,
  240. )
  241. else:
  242. x = self.resblocks[i](x)
  243. x = self.activation_post(x)
  244. x = self.conv_post(x)
  245. x = torch.tanh(x)
  246. return x
  247. def remove_parametrizations(self):
  248. for up in self.ups:
  249. up.remove_parametrizations()
  250. for block in self.resblocks:
  251. block.remove_parametrizations()
  252. self.conv_pre.remove_parametrizations()
  253. self.conv_post.remove_parametrizations()
  254. # DropPath copied from timm library
  255. def drop_path(
  256. x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
  257. ):
  258. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  259. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  260. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  261. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  262. changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
  263. 'survival rate' as the argument.
  264. """ # noqa: E501
  265. if drop_prob == 0.0 or not training:
  266. return x
  267. keep_prob = 1 - drop_prob
  268. shape = (x.shape[0],) + (1,) * (
  269. x.ndim - 1
  270. ) # work with diff dim tensors, not just 2D ConvNets
  271. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  272. if keep_prob > 0.0 and scale_by_keep:
  273. random_tensor.div_(keep_prob)
  274. return x * random_tensor
  275. class DropPath(nn.Module):
  276. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
  277. def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
  278. super(DropPath, self).__init__()
  279. self.drop_prob = drop_prob
  280. self.scale_by_keep = scale_by_keep
  281. def forward(self, x):
  282. return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
  283. def extra_repr(self):
  284. return f"drop_prob={round(self.drop_prob,3):0.3f}"
  285. class LayerNorm(nn.Module):
  286. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  287. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
  288. shape (batch_size, height, width, channels) while channels_first corresponds to inputs
  289. with shape (batch_size, channels, height, width).
  290. """ # noqa: E501
  291. def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
  292. super().__init__()
  293. self.weight = nn.Parameter(torch.ones(normalized_shape))
  294. self.bias = nn.Parameter(torch.zeros(normalized_shape))
  295. self.eps = eps
  296. self.data_format = data_format
  297. if self.data_format not in ["channels_last", "channels_first"]:
  298. raise NotImplementedError
  299. self.normalized_shape = (normalized_shape,)
  300. def forward(self, x):
  301. if self.data_format == "channels_last":
  302. return F.layer_norm(
  303. x, self.normalized_shape, self.weight, self.bias, self.eps
  304. )
  305. elif self.data_format == "channels_first":
  306. u = x.mean(1, keepdim=True)
  307. s = (x - u).pow(2).mean(1, keepdim=True)
  308. x = (x - u) / torch.sqrt(s + self.eps)
  309. x = self.weight[:, None] * x + self.bias[:, None]
  310. return x
  311. # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
  312. class ConvNeXtBlock(nn.Module):
  313. r"""ConvNeXt Block. There are two equivalent implementations:
  314. (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
  315. (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
  316. We use (2) as we find it slightly faster in PyTorch
  317. Args:
  318. dim (int): Number of input channels.
  319. drop_path (float): Stochastic depth rate. Default: 0.0
  320. layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
  321. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
  322. kernel_size (int): Kernel size for depthwise conv. Default: 7.
  323. dilation (int): Dilation for depthwise conv. Default: 1.
  324. """ # noqa: E501
  325. def __init__(
  326. self,
  327. dim: int,
  328. drop_path: float = 0.0,
  329. layer_scale_init_value: float = 1e-6,
  330. mlp_ratio: float = 4.0,
  331. kernel_size: int = 7,
  332. dilation: int = 1,
  333. ):
  334. super().__init__()
  335. self.dwconv = FishConvNet(
  336. dim,
  337. dim,
  338. kernel_size=kernel_size,
  339. # padding=int(dilation * (kernel_size - 1) / 2),
  340. groups=dim,
  341. ) # depthwise conv
  342. self.norm = LayerNorm(dim, eps=1e-6)
  343. self.pwconv1 = nn.Linear(
  344. dim, int(mlp_ratio * dim)
  345. ) # pointwise/1x1 convs, implemented with linear layers
  346. self.act = nn.GELU()
  347. self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
  348. self.gamma = (
  349. nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  350. if layer_scale_init_value > 0
  351. else None
  352. )
  353. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  354. def forward(self, x, apply_residual: bool = True):
  355. input = x
  356. x = self.dwconv(x)
  357. x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
  358. x = self.norm(x)
  359. x = self.pwconv1(x)
  360. x = self.act(x)
  361. x = self.pwconv2(x)
  362. if self.gamma is not None:
  363. x = self.gamma * x
  364. x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
  365. x = self.drop_path(x)
  366. if apply_residual:
  367. x = input + x
  368. return x
  369. class ConvNeXtEncoder(nn.Module):
  370. def __init__(
  371. self,
  372. input_channels: int = 3,
  373. depths: list[int] = [3, 3, 9, 3],
  374. dims: list[int] = [96, 192, 384, 768],
  375. drop_path_rate: float = 0.0,
  376. layer_scale_init_value: float = 1e-6,
  377. kernel_size: int = 7,
  378. ):
  379. super().__init__()
  380. assert len(depths) == len(dims)
  381. self.downsample_layers = nn.ModuleList()
  382. stem = nn.Sequential(
  383. FishConvNet(
  384. input_channels,
  385. dims[0],
  386. kernel_size=7,
  387. # padding=3,
  388. # padding_mode="replicate",
  389. # padding_mode="zeros",
  390. ),
  391. LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
  392. )
  393. self.downsample_layers.append(stem)
  394. for i in range(len(depths) - 1):
  395. mid_layer = nn.Sequential(
  396. LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
  397. nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
  398. )
  399. self.downsample_layers.append(mid_layer)
  400. self.stages = nn.ModuleList()
  401. dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
  402. cur = 0
  403. for i in range(len(depths)):
  404. stage = nn.Sequential(
  405. *[
  406. ConvNeXtBlock(
  407. dim=dims[i],
  408. drop_path=dp_rates[cur + j],
  409. layer_scale_init_value=layer_scale_init_value,
  410. kernel_size=kernel_size,
  411. )
  412. for j in range(depths[i])
  413. ]
  414. )
  415. self.stages.append(stage)
  416. cur += depths[i]
  417. self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
  418. self.apply(self._init_weights)
  419. def _init_weights(self, m):
  420. if isinstance(m, (nn.Conv1d, nn.Linear)):
  421. nn.init.trunc_normal_(m.weight, std=0.02)
  422. nn.init.constant_(m.bias, 0)
  423. def forward(
  424. self,
  425. x: torch.Tensor,
  426. ) -> torch.Tensor:
  427. for i in range(len(self.downsample_layers)):
  428. x = self.downsample_layers[i](x)
  429. x = self.stages[i](x)
  430. return self.norm(x)
  431. class FireflyArchitecture(nn.Module):
  432. def __init__(
  433. self,
  434. backbone: nn.Module,
  435. head: nn.Module,
  436. quantizer: nn.Module,
  437. spec_transform: nn.Module,
  438. ):
  439. super().__init__()
  440. self.backbone = backbone
  441. self.head = head
  442. self.quantizer = quantizer
  443. self.spec_transform = spec_transform
  444. self.downsample_factor = math.prod(self.quantizer.downsample_factor)
  445. def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
  446. if self.spec_transform is not None:
  447. x = self.spec_transform(x)
  448. x = self.backbone(x)
  449. if mask is not None:
  450. x = x * mask
  451. if self.quantizer is not None:
  452. vq_result = self.quantizer(x)
  453. x = vq_result.z
  454. if mask is not None:
  455. x = x * mask
  456. x = self.head(x, template=template)
  457. if x.ndim == 2:
  458. x = x[:, None, :]
  459. if self.vq is not None:
  460. return x, vq_result
  461. return x
  462. def encode(self, audios, audio_lengths):
  463. audios = audios.float()
  464. mels = self.spec_transform(audios)
  465. mel_lengths = audio_lengths // self.spec_transform.hop_length
  466. mel_masks = sequence_mask(mel_lengths, mels.shape[2])
  467. mel_masks_float_conv = mel_masks[:, None, :].float()
  468. mels = mels * mel_masks_float_conv
  469. # Encode
  470. encoded_features = self.backbone(mels) * mel_masks_float_conv
  471. feature_lengths = mel_lengths // self.downsample_factor
  472. return self.quantizer.encode(encoded_features), feature_lengths
  473. def decode(self, indices, feature_lengths) -> torch.Tensor:
  474. mel_masks = sequence_mask(
  475. feature_lengths * self.downsample_factor,
  476. indices.shape[2] * self.downsample_factor,
  477. )
  478. mel_masks_float_conv = mel_masks[:, None, :].float()
  479. audio_lengths = (
  480. feature_lengths * self.downsample_factor * self.spec_transform.hop_length
  481. )
  482. audio_masks = sequence_mask(
  483. audio_lengths,
  484. indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
  485. )
  486. audio_masks_float_conv = audio_masks[:, None, :].float()
  487. z = self.quantizer.decode(indices) * mel_masks_float_conv
  488. x = self.head(z) * audio_masks_float_conv
  489. return x, audio_lengths
  490. def remove_parametrizations(self):
  491. if hasattr(self.backbone, "remove_parametrizations"):
  492. self.backbone.remove_parametrizations()
  493. if hasattr(self.head, "remove_parametrizations"):
  494. self.head.remove_parametrizations()
  495. @property
  496. def device(self):
  497. return next(self.parameters()).device