rvq.py

from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from vector_quantize_pytorch import LFQ, ResidualVQ  # ResidualVQ: see commented-out alternative below


class DownsampleResidualVectorQuantizer(nn.Module):
    """
    Downsampled vector quantizer: reduces the temporal resolution with
    strided convolutions, quantizes the result (currently with LFQ; a
    ResidualVQ alternative is left commented out), then upsamples back
    and pads/crops to the input length.
    """

    def __init__(
        self,
        input_dim: int = 512,
        # n_codebooks, codebook_size, codebook_dim, quantizer_dropout and
        # min_quantizers only apply to the commented-out ResidualVQ path.
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
        min_quantizers: int = 4,
        downsample_factor: tuple[int, ...] = (2, 2),
        downsample_dims: tuple[int, ...] | None = None,
    ):
        super().__init__()
        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]
        all_dims = (input_dim,) + tuple(downsample_dims)
        # Residual VQ alternative, kept for reference (this is where the
        # n_codebooks/codebook_size/codebook_dim arguments would be used):
        # self.vq = ResidualVQ(
        #     dim=all_dims[-1],
        #     num_quantizers=n_codebooks,
        #     codebook_dim=codebook_dim,
        #     threshold_ema_dead_code=2,
        #     codebook_size=codebook_size,
        #     kmeans_init=False,
        # )
        self.vq = LFQ(
            dim=all_dims[-1],
            codebook_size=2**14,  # 14-bit lookup-free codes
            entropy_loss_weight=0.1,
            diversity_gamma=1.0,
        )
        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims
        # Strided Conv1d stack: each stage divides the temporal length by
        # its factor.
        self.downsample = nn.Sequential(
            *[
                nn.Conv1d(
                    all_dims[idx],
                    all_dims[idx + 1],
                    kernel_size=factor,
                    stride=factor,
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )
        # Mirrored ConvTranspose1d stack restores the original rate.
        self.upsample = nn.Sequential(
            *[
                nn.ConvTranspose1d(
                    all_dims[idx + 1],
                    all_dims[idx],
                    kernel_size=factor,
                    stride=factor,
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )
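
    # Hypothetical convenience helper, not in the original: the total
    # temporal downsampling applied before quantization, e.g. (2, 2) -> 4.
    @property
    def total_downsample(self) -> int:
        total = 1
        for factor in self.downsample_factor:
            total *= factor
        return total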

    def forward(self, z):
        original_shape = z.shape
        z = self.downsample(z)
        # LFQ expects channels-last input: (B, C, T') -> (B, T', C).
        z, indices, loss = self.vq(z.mT)
        z = self.upsample(z.mT)
        loss = loss.mean()
        # Pad or crop z back to the input length: strided convs can change
        # T when it is not divisible by the product of the factors.
        diff = original_shape[-1] - z.shape[-1]
        if diff > 0:
            left = diff // 2
            right = diff - left
            z = F.pad(z, (left, right))
        elif diff < 0:
            crop = -diff
            left = crop // 2
            right = crop - left
            z = z[..., left : z.shape[-1] - right]
        return z, indices, loss

    # Decode helpers from the original ResidualVectorQuantize parent, kept
    # for reference; they call super().from_codes/from_latents, which
    # nn.Module does not provide, so they stay disabled here.
    # def from_codes(self, codes: torch.Tensor):
    #     z_q, z_p, codes = super().from_codes(codes)
    #     z_q = self.upsample(z_q)
    #     return z_q, z_p, codes
    #
    # def from_latents(self, latents: torch.Tensor):
    #     z_q, z_p, codes = super().from_latents(latents)
    #     z_q = self.upsample(z_q)
    #     return z_q, z_p, codes
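
    # A minimal decode sketch for the LFQ path, not part of the original.
    # It assumes vector_quantize_pytorch's LFQ.indices_to_codes(), which
    # maps indices back to channels-last latents in the model dimension.
    def from_indices(self, indices: torch.Tensor) -> torch.Tensor:
        z = self.vq.indices_to_codes(indices)  # (B, T', D)
        return self.upsample(z.mT)  # (B, D, ~T)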


if __name__ == "__main__":
    rvq = DownsampleResidualVectorQuantizer(
        quantizer_dropout=1.0,
        min_quantizers=1,
        codebook_size=256,
        downsample_factor=(2, 2),
    )
    x = torch.randn(16, 512, 80)
    # forward returns a plain (z, indices, loss) tuple, not an object with
    # .latents/.codes/.z attributes.
    z, indices, loss = rvq(x)
    print(z.shape, indices.shape, float(loss))
    # from_codes/from_latents are commented out above (they assume the
    # ResidualVectorQuantize parent), so their demo calls are disabled:
    # y = rvq.from_codes(codes)
    # print(y[0].shape)
    # y = rvq.from_latents(latents)
    # print(y[0].shape)
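    # Round-trip sanity check (an addition): pad/crop should restore the
    # input length even when T is not divisible by the total factor (4).
    x_odd = torch.randn(2, 512, 81)
    z_odd, _, _ = rvq(x_odd)
    assert z_odd.shape == x_odd.shape, (z_odd.shape, x_odd.shape)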