rvq.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. import math
  2. import typing as tp
  3. from dataclasses import dataclass
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from dac.nn.quantize import ResidualVectorQuantize
  8. from torch.nn.utils.parametrizations import weight_norm
  9. from torch.nn.utils.parametrize import remove_parametrizations
  10. def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
  11. """Remove padding from x, handling properly zero padding. Only for 1d!"""
  12. padding_left, padding_right = paddings
  13. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  14. assert (padding_left + padding_right) <= x.shape[-1]
  15. end = x.shape[-1] - padding_right
  16. return x[..., padding_left:end]
  17. def get_extra_padding_for_conv1d(
  18. x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
  19. ) -> int:
  20. """See `pad_for_conv1d`."""
  21. length = x.shape[-1]
  22. n_frames = (length - kernel_size + padding_total) / stride + 1
  23. ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
  24. return ideal_length - length
  25. def pad1d(
  26. x: torch.Tensor,
  27. paddings: tp.Tuple[int, int],
  28. mode: str = "zeros",
  29. value: float = 0.0,
  30. ):
  31. """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
  32. If this is the case, we insert extra 0 padding to the right
  33. before the reflection happen.
  34. """
  35. length = x.shape[-1]
  36. padding_left, padding_right = paddings
  37. assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
  38. if mode == "reflect":
  39. max_pad = max(padding_left, padding_right)
  40. extra_pad = 0
  41. if length <= max_pad:
  42. extra_pad = max_pad - length + 1
  43. x = F.pad(x, (0, extra_pad))
  44. padded = F.pad(x, paddings, mode, value)
  45. end = padded.shape[-1] - extra_pad
  46. return padded[..., :end]
  47. else:
  48. return F.pad(x, paddings, mode, value)
  49. class CausalConvNet(nn.Module):
  50. def __init__(
  51. self,
  52. in_channels,
  53. out_channels,
  54. kernel_size,
  55. dilation=1,
  56. stride=1,
  57. groups=1,
  58. padding=None,
  59. ):
  60. super(CausalConvNet, self).__init__()
  61. self.conv = nn.Conv1d(
  62. in_channels,
  63. out_channels,
  64. kernel_size,
  65. stride=stride,
  66. dilation=dilation,
  67. groups=groups,
  68. )
  69. self.stride = stride
  70. self.kernel_size = (kernel_size - 1) * dilation + 1
  71. self.dilation = dilation
  72. self.padding = self.kernel_size - self.stride
  73. def forward(self, x):
  74. pad = self.padding
  75. extra_padding = get_extra_padding_for_conv1d(
  76. x, self.kernel_size, self.stride, pad
  77. )
  78. x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
  79. return self.conv(x).contiguous()
  80. def weight_norm(self, name="weight", dim=0):
  81. self.conv = weight_norm(self.conv, name=name, dim=dim)
  82. return self
  83. def remove_weight_norm(self):
  84. self.conv = remove_parametrizations(self.conv)
  85. return self
  86. class CausalTransConvNet(nn.Module):
  87. def __init__(
  88. self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
  89. ):
  90. super(CausalTransConvNet, self).__init__()
  91. self.conv = nn.ConvTranspose1d(
  92. in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
  93. )
  94. self.stride = stride
  95. self.kernel_size = kernel_size
  96. def forward(self, x):
  97. x = self.conv(x)
  98. pad = self.kernel_size - self.stride
  99. padding_right = math.ceil(pad)
  100. padding_left = pad - padding_right
  101. x = unpad1d(x, (padding_left, padding_right))
  102. return x.contiguous()
  103. def weight_norm(self, name="weight", dim=0):
  104. self.conv = weight_norm(self.conv, name=name, dim=dim)
  105. return self
  106. def remove_weight_norm(self):
  107. self.conv = remove_parametrizations(self.conv)
  108. return self
  109. # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
  110. class ConvNeXtBlock(nn.Module):
  111. r"""ConvNeXt Block. There are two equivalent implementations:
  112. (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
  113. (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
  114. We use (2) as we find it slightly faster in PyTorch
  115. Args:
  116. dim (int): Number of input channels.
  117. drop_path (float): Stochastic depth rate. Default: 0.0
  118. layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
  119. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
  120. kernel_size (int): Kernel size for depthwise conv. Default: 7.
  121. dilation (int): Dilation for depthwise conv. Default: 1.
  122. """ # noqa: E501
  123. def __init__(
  124. self,
  125. dim: int,
  126. layer_scale_init_value: float = 1e-6,
  127. mlp_ratio: float = 4.0,
  128. kernel_size: int = 7,
  129. dilation: int = 1,
  130. ):
  131. super().__init__()
  132. convnet_type = CausalConvNet
  133. self.dwconv = convnet_type(
  134. dim,
  135. dim,
  136. kernel_size=kernel_size,
  137. # padding=int(dilation * (kernel_size - 1) / 2),
  138. groups=dim,
  139. dilation=dilation,
  140. ) # depthwise conv
  141. self.norm = nn.LayerNorm(dim, eps=1e-6)
  142. self.pwconv1 = nn.Linear(
  143. dim, int(mlp_ratio * dim)
  144. ) # pointwise/1x1 convs, implemented with linear layers
  145. self.act = nn.GELU()
  146. self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
  147. self.gamma = (
  148. nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  149. if layer_scale_init_value > 0
  150. else None
  151. )
  152. def forward(self, x, apply_residual: bool = True):
  153. input = x
  154. x = self.dwconv(x)
  155. x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
  156. x = self.norm(x)
  157. x = self.pwconv1(x)
  158. x = self.act(x)
  159. x = self.pwconv2(x)
  160. if self.gamma is not None:
  161. x = self.gamma * x
  162. x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
  163. if apply_residual:
  164. x = input + x
  165. return x
@dataclass
class VQResult:
    """Outputs of DownsampleResidualVectorQuantize.forward."""

    # Quantized latent, upsampled/cropped back to the input time length (B, D, T).
    z: torch.Tensor
    # Codebook indices; channel 0 is the semantic codebook, the rest residual.
    codes: torch.Tensor
    # Concatenated pre-projection latents from both quantizers.
    latents: torch.Tensor
    codebook_loss: torch.Tensor
    commitment_loss: torch.Tensor
    # Only populated by the (currently disabled) semantic distillation path.
    semantic_distill_z: torch.Tensor | None = None
class DownsampleResidualVectorQuantize(nn.Module):
    """Residual vector quantization at a temporally downsampled rate.

    Input latents (B, D, T) are downsampled in time by the product of
    `downsample_factor` via causal strided convs, quantized by a
    single-codebook "semantic" RVQ plus an `n_codebooks` residual RVQ on the
    remainder, then upsampled with transposed causal convs and padded/cropped
    back to the original time length.
    """

    def __init__(
        self,
        input_dim: int = 1024,
        n_codebooks: int = 9,
        codebook_dim: int = 8,
        quantizer_dropout: float = 0.5,
        codebook_size: int = 1024,
        semantic_codebook_size: int = 4096,
        downsample_factor: tuple[int] = (2, 2),
        downsample_dims: tuple[int] | None = None,
        pre_module: nn.Module | None = None,
        post_module: nn.Module | None = None,
        semantic_predictor_module: nn.Module | None = None,
    ):
        super().__init__()
        if downsample_dims is None:
            # Default: keep the channel width constant through every stage.
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]
        all_dims = (input_dim,) + tuple(downsample_dims)
        # Single always-active codebook for the semantic layer (no dropout).
        self.semantic_quantizer = ResidualVectorQuantize(
            input_dim=input_dim,
            n_codebooks=1,
            codebook_size=semantic_codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=0.0,
        )
        # Residual codebooks quantize what the semantic layer does not capture.
        self.quantizer = ResidualVectorQuantize(
            input_dim=input_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )
        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims
        convnet_type = CausalConvNet
        transconvnet_type = CausalTransConvNet
        # One (strided causal conv -> ConvNeXt block) stage per factor.
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    convnet_type(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )
        # Mirror of `downsample`, applied in reverse stage order.
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    transconvnet_type(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )
        self.apply(self._init_weights)
        self.pre_module = (
            pre_module if pre_module is not None else nn.Identity()
        )  # leave for transformer, LSTM or Mamba or something else
        self.post_module = post_module if post_module is not None else nn.Identity()
        self.semantic_predictor_module = (
            semantic_predictor_module
            if semantic_predictor_module is not None
            else nn.Identity()
        )

    def _init_weights(self, m):
        # Truncated-normal init for conv/linear weights, zero bias.
        # NOTE(review): assumes every matched Conv1d/Linear has a bias (true
        # for the modules constructed above); would fail on bias=None.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        z,
        n_quantizers: int | None = None,
        semantic_len: torch.Tensor | None = None,
        **kwargs,
    ):
        """Quantize latents `z`.

        Args:
            z: input latents of shape (B, D, T).
            n_quantizers: optional cap on how many residual codebooks are
                used (forwarded to the residual quantizer).
            semantic_len: per-item lengths for the semantic distillation
                path (currently disabled below); defaults to full length.

        Returns:
            VQResult with `z` restored to the input time length.
        """
        # z: (B, D, T)
        original_shape = z.shape
        if semantic_len is None:
            semantic_len = torch.LongTensor([z.shape[-1]])
        z = self.downsample(z)
        # NOTE(review): original comment says "B, T, D" but z is (B, D, T)
        # at this point; confirm expected layout before using a
        # non-Identity pre_module.
        z = self.pre_module(z)  # B, T, D
        # First pass: single semantic codebook.
        (
            semantic_z,
            semantic_codes,
            semantic_latents,
            semantic_commitment_loss,
            semantic_codebook_loss,
        ) = self.semantic_quantizer(z)
        # Second pass: residual codebooks on what remains.
        residual_z = z - semantic_z
        residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            residual_z, n_quantizers=n_quantizers
        )
        z = semantic_z + residual_z
        commitment_loss = commitment_loss + semantic_commitment_loss
        codebook_loss = codebook_loss + semantic_codebook_loss
        # Semantic codes/latents come first along the codebook dimension.
        codes = torch.cat([semantic_codes, codes], dim=1)
        latents = torch.cat([semantic_latents, latents], dim=1)
        z = self.post_module(z)
        z = self.upsample(z)
        # z: (B, D, T)
        # semantic distillation (disabled here since only used in training)
        # semantic_distill_z = self.semantic_predictor_module(semantic_z, semantic_len).mT # wav2vec target is B, T, D
        # Pad or crop z to match original shape: pad on the left if the
        # round trip shortened the sequence, crop the left if it grew.
        diff = original_shape[-1] - z.shape[-1]
        right = 0
        left = abs(diff) - right
        if diff > 0:
            z = F.pad(z, (left, right))
        elif diff < 0:
            z = z[..., left:]
        results = VQResult(
            z=z,
            codes=codes,
            latents=latents,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
        )
        return results

    # def encode(self, z):
    #     z = self.downsample(z)
    #     z = self.pre_module(z)
    #     _, indices, _, _, _ = self.quantizer(z.mT)
    #     indices = rearrange(indices, "g b l r -> b (g r) l")
    #     return indices
    #
    def decode(self, indices: torch.Tensor):
        """Decode codebook indices back to upsampled latents.

        Args:
            indices: (B, 1 + n_codebooks, T') — channel 0 is the semantic
                codebook, the rest residual. Clamped IN PLACE to each
                codebook's valid range.

        Returns:
            Reconstructed latent of shape (B, D, T).
        """
        # indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
        indices[:, 0] = torch.clamp(
            indices[:, 0], max=self.semantic_quantizer.codebook_size - 1
        )
        indices[:, 1:] = torch.clamp(
            indices[:, 1:], max=self.quantizer.codebook_size - 1
        )
        z_q_semantic = self.semantic_quantizer.from_codes(indices[:, :1])[0]
        z_q_residual = self.quantizer.from_codes(indices[:, 1:])[0]
        z_q = z_q_semantic + z_q_residual
        z_q = self.post_module(z_q)
        z_q = self.upsample(z_q)
        return z_q

    # def from_latents(self, latents: torch.Tensor):
    #     z_q, z_p, codes = super().from_latents(latents)
    #     z_q = self.upsample(z_q)
    #     return z_q, z_p, codes
  324. if __name__ == "__main__":
  325. rvq = DownsampleResidualVectorQuantize(
  326. input_dim=512,
  327. n_codebooks=8,
  328. codebook_dim=8,
  329. codebook_size=1024,
  330. quantizer_dropout=0.5,
  331. downsample_factor=[2, 2],
  332. )
  333. rvq.eval()
  334. x = torch.randn(2, 512, 442)
  335. result = rvq(x)
  336. print(rvq)
  337. print(result.latents.shape, result.codes.shape, result.z.shape)
  338. # y = rvq.from_codes(result.codes)
  339. # print(y[0].shape)
  340. # y = rvq.from_latents(
  341. result1 = rvq(x[:, :, :40])
  342. print(result1.latents.shape, result1.codes.shape, result1.z.shape)
  343. assert torch.allclose(result.z[:, :, :40], result1.z, atol=1e-8)
  344. print("Success")