fsq.py

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ

from .firefly import ConvNeXtBlock


@dataclass
class FSQResult:
    z: torch.Tensor        # quantized features (upsampled back to the input rate in forward)
    codes: torch.Tensor    # codebook indices
    latents: torch.Tensor  # downsampled features that were fed to the quantizer


class DownsampleFiniteScalarQuantize(nn.Module):
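    """Finite scalar quantization with temporal down/upsampling.

    The input is downsampled along time by strided Conv1d + ConvNeXt stages,
    quantized with a grouped residual FSQ, and upsampled back with mirrored
    transposed convolutions to (approximately) the original length.
    """
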
    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        n_groups: int = 1,
        levels: tuple[int, ...] = (8, 5, 5, 5),  # 8 * 5 * 5 * 5 = 1000 codes, approximately 2**10
        downsample_factor: tuple[int, ...] = (2, 2),
        downsample_dims: tuple[int, ...] | None = None,
    ):
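        """
        Args:
            input_dim: channel width of the incoming features.
            n_codebooks: number of residual quantizers per group.
            n_groups: FSQ groups; the channel dimension is split across them.
            levels: FSQ levels per latent dimension (their product is the codebook size).
            downsample_factor: temporal stride of each stage; the total stride is their product.
            downsample_dims: channel widths after each stage; defaults to input_dim everywhere.
        """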
        super().__init__()

        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        all_dims = (input_dim,) + tuple(downsample_dims)

        self.residual_fsq = GroupedResidualFSQ(
            dim=all_dims[-1],
            levels=levels,
            num_quantizers=n_codebooks,
            groups=n_groups,
        )

        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims
        # Each stage shrinks the time axis by `factor` (kernel_size == stride),
        # then refines the result with a ConvNeXt block.
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv1d(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )
        # Mirror of `downsample`: transposed convolutions applied in reverse
        # stage order so the channel widths line up.
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    nn.ConvTranspose1d(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:  # guard: layers may be built without a bias
                nn.init.constant_(m.bias, 0)

    def forward(self, z: torch.Tensor) -> FSQResult:
        # z: (batch, dim, time)
        original_shape = z.shape
        z = self.downsample(z)
        # GroupedResidualFSQ expects (batch, time, dim), hence the .mT transposes.
        quantized, indices = self.residual_fsq(z.mT)
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=z,
        )
        result.z = self.upsample(result.z)

        # Pad or crop z to match the original length, which the conv stages may
        # not reproduce exactly.
        diff = original_shape[-1] - result.z.shape[-1]
        left = diff // 2
        right = diff - left

        if diff > 0:
            result.z = F.pad(result.z, (left, right))
        elif diff < 0:
            result.z = result.z[..., left:-right]

        return result

    def encode(self, z: torch.Tensor) -> torch.Tensor:
        z = self.downsample(z)
        _, indices = self.residual_fsq(z.mT)
        # (groups, batch, time, quantizers) -> (batch, groups * quantizers, time)
        indices = rearrange(indices, "g b l r -> b (g r) l")
        return indices

    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
        z_q = self.residual_fsq.get_output_from_indices(indices)
        z_q = self.upsample(z_q.mT)
        return z_q

    # def from_latents(self, latents: torch.Tensor):
    #     z_q, z_p, codes = super().from_latents(latents)
    #     z_q = self.upsample(z_q)
    #     return z_q, z_p, codes


if __name__ == "__main__":
    rvq = DownsampleFiniteScalarQuantize(
        n_codebooks=1,
        downsample_factor=(2, 2),
    )
    x = torch.randn(16, 512, 80)

    result = rvq(x)
    print(rvq)
    print(result.latents.shape, result.codes.shape, result.z.shape)
    # y = rvq.from_codes(result.codes)
    # print(y[0].shape)
    # y = rvq.from_latents(result.latents)
    # print(y[0].shape)
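
    # Round-trip sketch (added for illustration, not part of the original test):
    # encode to integer codes, then decode back to continuous features.
    codes = rvq.encode(x)           # (16, 1, 20): (batch, groups * n_codebooks, time / 4)
    z_restored = rvq.decode(codes)  # (16, 512, 80): features back at the input rate
    print(codes.shape, z_restored.shape)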