fsq.py

from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm
from vector_quantize_pytorch import GroupedResidualFSQ

from .firefly import ConvNeXtBlock


@dataclass
class FSQResult:
    z: torch.Tensor
    codes: torch.Tensor
    latents: torch.Tensor
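# Wraps GroupedResidualFSQ with learned temporal downsampling and upsampling:
# the input features are shortened by `downsample_factor` before quantization
# and restored to their original length afterwards.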
class DownsampleFiniteScalarQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        n_groups: int = 1,
        levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
        downsample_factor: tuple[int] = (2, 2),
        downsample_dims: tuple[int] | None = None,
    ):
        super().__init__()

        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        all_dims = (input_dim,) + tuple(downsample_dims)

        self.residual_fsq = GroupedResidualFSQ(
            dim=all_dims[-1],
            levels=levels,
            num_quantizers=n_codebooks,
            groups=n_groups,
        )

        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims

        # Each stage: a strided Conv1d that shrinks the time axis by `factor`,
        # followed by a ConvNeXt block at the new resolution.
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv1d(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )

        # Mirror of the downsample stack, applied in reverse order.
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    nn.ConvTranspose1d(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)
    def forward(self, z) -> FSQResult:
        original_shape = z.shape
        z = self.downsample(z)

        # GroupedResidualFSQ expects channels-last input, hence the .mT transposes.
        quantized, indices = self.residual_fsq(z.mT)
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=z,
        )
        result.z = self.upsample(result.z)

        # Pad or crop z to match the original shape
        diff = original_shape[-1] - result.z.shape[-1]
        left = diff // 2
        right = diff - left

        if diff > 0:
            result.z = F.pad(result.z, (left, right))
        elif diff < 0:
            result.z = result.z[..., left:-right]

        return result
    # def from_codes(self, codes: torch.Tensor):
    #     z_q, z_p, codes = self.residual_fsq.get_output_from_indices(codes)
    #     z_q = self.upsample(z_q)
    #     return z_q, z_p, codes

    # def from_latents(self, latents: torch.Tensor):
    #     z_q, z_p, codes = super().from_latents(latents)
    #     z_q = self.upsample(z_q)
    #     return z_q, z_p, codes
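# Quick sanity check of the round-trip shapes when this module is run directly.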
if __name__ == "__main__":
    rvq = DownsampleFiniteScalarQuantize(
        n_codebooks=1,
        downsample_factor=(2, 2),
    )
    x = torch.randn(16, 512, 80)

    result = rvq(x)
    print(rvq)
    print(result.latents.shape, result.codes.shape, result.z.shape)

    # y = rvq.from_codes(result.codes)
    # print(y[0].shape)
    # y = rvq.from_latents(result.latents)
    # print(y[0].shape)
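    # With downsample_factor=(2, 2), the 80-frame input is reduced to 20 frames
    # before quantization and padded/cropped back afterwards, so latents should be
    # (16, 512, 20) and z (16, 512, 80); the exact codes shape depends on how
    # GroupedResidualFSQ lays out its group and codebook axes.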