# firefly.py
  1. import math
  2. from functools import partial
  3. from math import prod
  4. from typing import Callable
  5. import torch
  6. import torch.nn.functional as F
  7. from torch import nn
  8. from torch.nn.utils.parametrizations import weight_norm
  9. from torch.nn.utils.parametrize import remove_parametrizations
  10. from torch.utils.checkpoint import checkpoint
  11. def sequence_mask(length, max_length=None):
  12. if max_length is None:
  13. max_length = length.max()
  14. x = torch.arange(max_length, dtype=length.dtype, device=length.device)
  15. return x.unsqueeze(0) < length.unsqueeze(1)
  16. def init_weights(m, mean=0.0, std=0.01):
  17. classname = m.__class__.__name__
  18. if classname.find("Conv1D") != -1:
  19. m.weight.data.normal_(mean, std)
  20. def get_padding(kernel_size, dilation=1):
  21. return (kernel_size * dilation - dilation) // 2
  22. def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
  23. """Remove padding from x, handling properly zero padding. Only for 1d!"""
  24. padding_left, padding_right = paddings
  25. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  26. assert (padding_left + padding_right) <= x.shape[-1]
  27. end = x.shape[-1] - padding_right
  28. return x[..., padding_left:end]
  29. def get_extra_padding_for_conv1d(
  30. x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
  31. ) -> int:
  32. """See `pad_for_conv1d`."""
  33. length = x.shape[-1]
  34. n_frames = (length - kernel_size + padding_total) / stride + 1
  35. ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
  36. return ideal_length - length
  37. def pad1d(
  38. x: torch.Tensor,
  39. paddings: tuple[int, int],
  40. mode: str = "zeros",
  41. value: float = 0.0,
  42. ):
  43. """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
  44. If this is the case, we insert extra 0 padding to the right
  45. before the reflection happen.
  46. """
  47. length = x.shape[-1]
  48. padding_left, padding_right = paddings
  49. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  50. if mode == "reflect":
  51. max_pad = max(padding_left, padding_right)
  52. extra_pad = 0
  53. if length <= max_pad:
  54. extra_pad = max_pad - length + 1
  55. x = F.pad(x, (0, extra_pad))
  56. padded = F.pad(x, paddings, mode, value)
  57. end = padded.shape[-1] - extra_pad
  58. return padded[..., :end]
  59. else:
  60. return F.pad(x, paddings, mode, value)
  61. class FishConvNet(nn.Module):
  62. def __init__(
  63. self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
  64. ):
  65. super(FishConvNet, self).__init__()
  66. self.conv = nn.Conv1d(
  67. in_channels,
  68. out_channels,
  69. kernel_size,
  70. stride=stride,
  71. dilation=dilation,
  72. groups=groups,
  73. )
  74. self.stride = stride
  75. self.kernel_size = (kernel_size - 1) * dilation + 1
  76. self.dilation = dilation
  77. def forward(self, x):
  78. pad = self.kernel_size - self.stride
  79. extra_padding = get_extra_padding_for_conv1d(
  80. x, self.kernel_size, self.stride, pad
  81. )
  82. x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
  83. return self.conv(x).contiguous()
  84. def weight_norm(self, name="weight", dim=0):
  85. self.conv = weight_norm(self.conv, name=name, dim=dim)
  86. return self
  87. def remove_parametrizations(self, name="weight"):
  88. self.conv = remove_parametrizations(self.conv, name)
  89. return self
  90. class FishTransConvNet(nn.Module):
  91. def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
  92. super(FishTransConvNet, self).__init__()
  93. self.conv = nn.ConvTranspose1d(
  94. in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
  95. )
  96. self.stride = stride
  97. self.kernel_size = kernel_size
  98. def forward(self, x):
  99. x = self.conv(x)
  100. pad = self.kernel_size - self.stride
  101. padding_right = math.ceil(pad)
  102. padding_left = pad - padding_right
  103. x = unpad1d(x, (padding_left, padding_right))
  104. return x.contiguous()
  105. def weight_norm(self, name="weight", dim=0):
  106. self.conv = weight_norm(self.conv, name=name, dim=dim)
  107. return self
  108. def remove_parametrizations(self, name="weight"):
  109. self.conv = remove_parametrizations(self.conv, name)
  110. return self
  111. class ResBlock1(torch.nn.Module):
  112. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
  113. super().__init__()
  114. self.convs1 = nn.ModuleList(
  115. [
  116. FishConvNet(
  117. channels, channels, kernel_size, stride=1, dilation=dilation[0]
  118. ).weight_norm(),
  119. FishConvNet(
  120. channels, channels, kernel_size, stride=1, dilation=dilation[1]
  121. ).weight_norm(),
  122. FishConvNet(
  123. channels, channels, kernel_size, stride=1, dilation=dilation[2]
  124. ).weight_norm(),
  125. ]
  126. )
  127. self.convs1.apply(init_weights)
  128. self.convs2 = nn.ModuleList(
  129. [
  130. FishConvNet(
  131. channels, channels, kernel_size, stride=1, dilation=dilation[0]
  132. ).weight_norm(),
  133. FishConvNet(
  134. channels, channels, kernel_size, stride=1, dilation=dilation[1]
  135. ).weight_norm(),
  136. FishConvNet(
  137. channels, channels, kernel_size, stride=1, dilation=dilation[2]
  138. ).weight_norm(),
  139. ]
  140. )
  141. self.convs2.apply(init_weights)
  142. def forward(self, x):
  143. for c1, c2 in zip(self.convs1, self.convs2):
  144. xt = F.silu(x)
  145. xt = c1(xt)
  146. xt = F.silu(xt)
  147. xt = c2(xt)
  148. x = xt + x
  149. return x
  150. def remove_parametrizations(self):
  151. for conv in self.convs1:
  152. conv.remove_parametrizations()
  153. for conv in self.convs2:
  154. conv.remove_parametrizations()
  155. class ParallelBlock(nn.Module):
  156. def __init__(
  157. self,
  158. channels: int,
  159. kernel_sizes: tuple[int] = (3, 7, 11),
  160. dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  161. ):
  162. super().__init__()
  163. assert len(kernel_sizes) == len(dilation_sizes)
  164. self.blocks = nn.ModuleList()
  165. for k, d in zip(kernel_sizes, dilation_sizes):
  166. self.blocks.append(ResBlock1(channels, k, d))
  167. def forward(self, x):
  168. return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
  169. def remove_parametrizations(self):
  170. for block in self.blocks:
  171. block.remove_parametrizations()
  172. class HiFiGANGenerator(nn.Module):
  173. def __init__(
  174. self,
  175. *,
  176. hop_length: int = 512,
  177. upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
  178. upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
  179. resblock_kernel_sizes: tuple[int] = (3, 7, 11),
  180. resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
  181. num_mels: int = 128,
  182. upsample_initial_channel: int = 512,
  183. pre_conv_kernel_size: int = 7,
  184. post_conv_kernel_size: int = 7,
  185. post_activation: Callable = partial(nn.SiLU, inplace=True),
  186. ):
  187. super().__init__()
  188. assert (
  189. prod(upsample_rates) == hop_length
  190. ), f"hop_length must be {prod(upsample_rates)}"
  191. self.conv_pre = FishConvNet(
  192. num_mels,
  193. upsample_initial_channel,
  194. pre_conv_kernel_size,
  195. stride=1,
  196. ).weight_norm()
  197. self.num_upsamples = len(upsample_rates)
  198. self.num_kernels = len(resblock_kernel_sizes)
  199. self.noise_convs = nn.ModuleList()
  200. self.ups = nn.ModuleList()
  201. for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
  202. self.ups.append(
  203. FishTransConvNet(
  204. upsample_initial_channel // (2**i),
  205. upsample_initial_channel // (2 ** (i + 1)),
  206. k,
  207. stride=u,
  208. ).weight_norm()
  209. )
  210. self.resblocks = nn.ModuleList()
  211. for i in range(len(self.ups)):
  212. ch = upsample_initial_channel // (2 ** (i + 1))
  213. self.resblocks.append(
  214. ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
  215. )
  216. self.activation_post = post_activation()
  217. self.conv_post = FishConvNet(
  218. ch, 1, post_conv_kernel_size, stride=1
  219. ).weight_norm()
  220. self.ups.apply(init_weights)
  221. self.conv_post.apply(init_weights)
  222. def forward(self, x):
  223. x = self.conv_pre(x)
  224. for i in range(self.num_upsamples):
  225. x = F.silu(x, inplace=True)
  226. x = self.ups[i](x)
  227. if self.training and self.checkpointing:
  228. x = checkpoint(
  229. self.resblocks[i],
  230. x,
  231. use_reentrant=False,
  232. )
  233. else:
  234. x = self.resblocks[i](x)
  235. x = self.activation_post(x)
  236. x = self.conv_post(x)
  237. x = torch.tanh(x)
  238. return x
  239. def remove_parametrizations(self):
  240. for up in self.ups:
  241. up.remove_parametrizations()
  242. for block in self.resblocks:
  243. block.remove_parametrizations()
  244. self.conv_pre.remove_parametrizations()
  245. self.conv_post.remove_parametrizations()
  246. # DropPath copied from timm library
  247. def drop_path(
  248. x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
  249. ):
  250. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  251. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  252. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  253. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  254. changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
  255. 'survival rate' as the argument.
  256. """ # noqa: E501
  257. if drop_prob == 0.0 or not training:
  258. return x
  259. keep_prob = 1 - drop_prob
  260. shape = (x.shape[0],) + (1,) * (
  261. x.ndim - 1
  262. ) # work with diff dim tensors, not just 2D ConvNets
  263. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  264. if keep_prob > 0.0 and scale_by_keep:
  265. random_tensor.div_(keep_prob)
  266. return x * random_tensor
  267. class DropPath(nn.Module):
  268. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
  269. def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
  270. super(DropPath, self).__init__()
  271. self.drop_prob = drop_prob
  272. self.scale_by_keep = scale_by_keep
  273. def forward(self, x):
  274. return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
  275. def extra_repr(self):
  276. return f"drop_prob={round(self.drop_prob,3):0.3f}"
  277. class LayerNorm(nn.Module):
  278. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  279. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
  280. shape (batch_size, height, width, channels) while channels_first corresponds to inputs
  281. with shape (batch_size, channels, height, width).
  282. """ # noqa: E501
  283. def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
  284. super().__init__()
  285. self.weight = nn.Parameter(torch.ones(normalized_shape))
  286. self.bias = nn.Parameter(torch.zeros(normalized_shape))
  287. self.eps = eps
  288. self.data_format = data_format
  289. if self.data_format not in ["channels_last", "channels_first"]:
  290. raise NotImplementedError
  291. self.normalized_shape = (normalized_shape,)
  292. def forward(self, x):
  293. if self.data_format == "channels_last":
  294. return F.layer_norm(
  295. x, self.normalized_shape, self.weight, self.bias, self.eps
  296. )
  297. elif self.data_format == "channels_first":
  298. u = x.mean(1, keepdim=True)
  299. s = (x - u).pow(2).mean(1, keepdim=True)
  300. x = (x - u) / torch.sqrt(s + self.eps)
  301. x = self.weight[:, None] * x + self.bias[:, None]
  302. return x
  303. # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
  304. class ConvNeXtBlock(nn.Module):
  305. r"""ConvNeXt Block. There are two equivalent implementations:
  306. (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
  307. (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
  308. We use (2) as we find it slightly faster in PyTorch
  309. Args:
  310. dim (int): Number of input channels.
  311. drop_path (float): Stochastic depth rate. Default: 0.0
  312. layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
  313. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
  314. kernel_size (int): Kernel size for depthwise conv. Default: 7.
  315. dilation (int): Dilation for depthwise conv. Default: 1.
  316. """ # noqa: E501
  317. def __init__(
  318. self,
  319. dim: int,
  320. drop_path: float = 0.0,
  321. layer_scale_init_value: float = 1e-6,
  322. mlp_ratio: float = 4.0,
  323. kernel_size: int = 7,
  324. dilation: int = 1,
  325. ):
  326. super().__init__()
  327. self.dwconv = FishConvNet(
  328. dim,
  329. dim,
  330. kernel_size=kernel_size,
  331. # padding=int(dilation * (kernel_size - 1) / 2),
  332. groups=dim,
  333. ) # depthwise conv
  334. self.norm = LayerNorm(dim, eps=1e-6)
  335. self.pwconv1 = nn.Linear(
  336. dim, int(mlp_ratio * dim)
  337. ) # pointwise/1x1 convs, implemented with linear layers
  338. self.act = nn.GELU()
  339. self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
  340. self.gamma = (
  341. nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  342. if layer_scale_init_value > 0
  343. else None
  344. )
  345. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  346. def forward(self, x, apply_residual: bool = True):
  347. input = x
  348. x = self.dwconv(x)
  349. x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
  350. x = self.norm(x)
  351. x = self.pwconv1(x)
  352. x = self.act(x)
  353. x = self.pwconv2(x)
  354. if self.gamma is not None:
  355. x = self.gamma * x
  356. x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
  357. x = self.drop_path(x)
  358. if apply_residual:
  359. x = input + x
  360. return x
  361. class ConvNeXtEncoder(nn.Module):
  362. def __init__(
  363. self,
  364. input_channels: int = 3,
  365. depths: list[int] = [3, 3, 9, 3],
  366. dims: list[int] = [96, 192, 384, 768],
  367. drop_path_rate: float = 0.0,
  368. layer_scale_init_value: float = 1e-6,
  369. kernel_size: int = 7,
  370. ):
  371. super().__init__()
  372. assert len(depths) == len(dims)
  373. self.downsample_layers = nn.ModuleList()
  374. stem = nn.Sequential(
  375. FishConvNet(
  376. input_channels,
  377. dims[0],
  378. kernel_size=7,
  379. # padding=3,
  380. # padding_mode="replicate",
  381. # padding_mode="zeros",
  382. ),
  383. LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
  384. )
  385. self.downsample_layers.append(stem)
  386. for i in range(len(depths) - 1):
  387. mid_layer = nn.Sequential(
  388. LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
  389. nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
  390. )
  391. self.downsample_layers.append(mid_layer)
  392. self.stages = nn.ModuleList()
  393. dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
  394. cur = 0
  395. for i in range(len(depths)):
  396. stage = nn.Sequential(
  397. *[
  398. ConvNeXtBlock(
  399. dim=dims[i],
  400. drop_path=dp_rates[cur + j],
  401. layer_scale_init_value=layer_scale_init_value,
  402. kernel_size=kernel_size,
  403. )
  404. for j in range(depths[i])
  405. ]
  406. )
  407. self.stages.append(stage)
  408. cur += depths[i]
  409. self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
  410. self.apply(self._init_weights)
  411. def _init_weights(self, m):
  412. if isinstance(m, (nn.Conv1d, nn.Linear)):
  413. nn.init.trunc_normal_(m.weight, std=0.02)
  414. nn.init.constant_(m.bias, 0)
  415. def forward(
  416. self,
  417. x: torch.Tensor,
  418. ) -> torch.Tensor:
  419. for i in range(len(self.downsample_layers)):
  420. x = self.downsample_layers[i](x)
  421. x = self.stages[i](x)
  422. return self.norm(x)
  423. class FireflyArchitecture(nn.Module):
  424. def __init__(
  425. self,
  426. backbone: nn.Module,
  427. head: nn.Module,
  428. quantizer: nn.Module,
  429. spec_transform: nn.Module,
  430. ):
  431. super().__init__()
  432. self.backbone = backbone
  433. self.head = head
  434. self.quantizer = quantizer
  435. self.spec_transform = spec_transform
  436. self.downsample_factor = math.prod(self.quantizer.downsample_factor)
  437. def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
  438. if self.spec_transform is not None:
  439. x = self.spec_transform(x)
  440. x = self.backbone(x)
  441. if mask is not None:
  442. x = x * mask
  443. if self.quantizer is not None:
  444. vq_result = self.quantizer(x)
  445. x = vq_result.z
  446. if mask is not None:
  447. x = x * mask
  448. x = self.head(x, template=template)
  449. if x.ndim == 2:
  450. x = x[:, None, :]
  451. if self.vq is not None:
  452. return x, vq_result
  453. return x
  454. def encode(self, audios, audio_lengths):
  455. audios = audios.float()
  456. mels = self.spec_transform(audios)
  457. mel_lengths = audio_lengths // self.spec_transform.hop_length
  458. mel_masks = sequence_mask(mel_lengths, mels.shape[2])
  459. mel_masks_float_conv = mel_masks[:, None, :].float()
  460. mels = mels * mel_masks_float_conv
  461. # Encode
  462. encoded_features = self.backbone(mels) * mel_masks_float_conv
  463. feature_lengths = mel_lengths // self.downsample_factor
  464. return self.quantizer.encode(encoded_features), feature_lengths
  465. def decode(self, indices, feature_lengths) -> torch.Tensor:
  466. mel_masks = sequence_mask(
  467. feature_lengths * self.downsample_factor,
  468. indices.shape[2] * self.downsample_factor,
  469. )
  470. mel_masks_float_conv = mel_masks[:, None, :].float()
  471. audio_lengths = (
  472. feature_lengths * self.downsample_factor * self.spec_transform.hop_length
  473. )
  474. audio_masks = sequence_mask(
  475. audio_lengths,
  476. indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
  477. )
  478. audio_masks_float_conv = audio_masks[:, None, :].float()
  479. z = self.quantizer.decode(indices) * mel_masks_float_conv
  480. x = self.head(z) * audio_masks_float_conv
  481. return x, audio_lengths
  482. def remove_parametrizations(self):
  483. if hasattr(self.backbone, "remove_parametrizations"):
  484. self.backbone.remove_parametrizations()
  485. if hasattr(self.head, "remove_parametrizations"):
  486. self.head.remove_parametrizations()
  487. @property
  488. def device(self):
  489. return next(self.parameters()).device