|
|
@@ -7,253 +7,10 @@ import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from einops import rearrange
|
|
|
from torch.nn.utils import weight_norm
|
|
|
+from vector_quantize_pytorch import LFQ, ResidualVQ
|
|
|
|
|
|
|
|
|
-class VectorQuantize(nn.Module):
|
|
|
- """
|
|
|
- Implementation of VQ similar to Karpathy's repo:
|
|
|
- https://github.com/karpathy/deep-vector-quantization
|
|
|
- Additionally uses following tricks from Improved VQGAN
|
|
|
- (https://arxiv.org/pdf/2110.04627.pdf):
|
|
|
- 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
|
|
- for improved codebook usage
|
|
|
- 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
|
|
- improves training stability
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
|
|
|
- super().__init__()
|
|
|
- self.codebook_size = codebook_size
|
|
|
- self.codebook_dim = codebook_dim
|
|
|
-
|
|
|
- self.in_proj = weight_norm(nn.Conv1d(input_dim, codebook_dim, kernel_size=1))
|
|
|
- self.out_proj = weight_norm(nn.Conv1d(codebook_dim, input_dim, kernel_size=1))
|
|
|
- self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
|
|
-
|
|
|
- def forward(self, z):
|
|
|
- """Quantized the input tensor using a fixed codebook and returns
|
|
|
- the corresponding codebook vectors
|
|
|
-
|
|
|
- Parameters
|
|
|
- ----------
|
|
|
- z : Tensor[B x D x T]
|
|
|
-
|
|
|
- Returns
|
|
|
- -------
|
|
|
- Tensor[B x D x T]
|
|
|
- Quantized continuous representation of input
|
|
|
- Tensor[1]
|
|
|
- Commitment loss to train encoder to predict vectors closer to codebook
|
|
|
- entries
|
|
|
- Tensor[1]
|
|
|
- Codebook loss to update the codebook
|
|
|
- Tensor[B x T]
|
|
|
- Codebook indices (quantized discrete representation of input)
|
|
|
- Tensor[B x D x T]
|
|
|
- Projected latents (continuous representation of input before quantization)
|
|
|
- """
|
|
|
-
|
|
|
- # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
|
|
- z_e = self.in_proj(z) # z_e : (B x D x T)
|
|
|
- z_q, indices = self.decode_latents(z_e)
|
|
|
-
|
|
|
- commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
|
|
- codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
|
|
-
|
|
|
- z_q = (
|
|
|
- z_e + (z_q - z_e).detach()
|
|
|
- ) # noop in forward pass, straight-through gradient estimator in backward pass
|
|
|
-
|
|
|
- z_q = self.out_proj(z_q)
|
|
|
-
|
|
|
- return z_q, commitment_loss, codebook_loss, indices, z_e
|
|
|
-
|
|
|
- def embed_code(self, embed_id):
|
|
|
- return F.embedding(embed_id, self.codebook.weight)
|
|
|
-
|
|
|
- def decode_code(self, embed_id):
|
|
|
- return self.embed_code(embed_id).transpose(1, 2)
|
|
|
-
|
|
|
- def decode_latents(self, latents):
|
|
|
- encodings = rearrange(latents, "b d t -> (b t) d")
|
|
|
- codebook = self.codebook.weight # codebook: (N x D)
|
|
|
-
|
|
|
- # L2 normalize encodings and codebook (ViT-VQGAN)
|
|
|
- encodings = F.normalize(encodings)
|
|
|
- codebook = F.normalize(codebook)
|
|
|
-
|
|
|
- # Compute euclidean distance with codebook
|
|
|
- dist = (
|
|
|
- encodings.pow(2).sum(1, keepdim=True)
|
|
|
- - 2 * encodings @ codebook.t()
|
|
|
- + codebook.pow(2).sum(1, keepdim=True).t()
|
|
|
- )
|
|
|
- indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
|
|
- z_q = self.decode_code(indices)
|
|
|
- return z_q, indices
|
|
|
-
|
|
|
-
|
|
|
-@dataclass
|
|
|
-class VQResult:
|
|
|
- z: torch.Tensor
|
|
|
- codes: torch.Tensor
|
|
|
- latents: torch.Tensor
|
|
|
- commitment_loss: torch.Tensor
|
|
|
- codebook_loss: torch.Tensor
|
|
|
-
|
|
|
-
|
|
|
-class ResidualVectorQuantize(nn.Module):
|
|
|
- """
|
|
|
- Introduced in SoundStream: An end2end neural audio codec
|
|
|
- https://arxiv.org/abs/2107.03312
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(
|
|
|
- self,
|
|
|
- input_dim: int = 512,
|
|
|
- n_codebooks: int = 9,
|
|
|
- codebook_size: int = 1024,
|
|
|
- codebook_dim: Union[int, list] = 8,
|
|
|
- quantizer_dropout: float = 0.0,
|
|
|
- min_quantizers: int = 4,
|
|
|
- ):
|
|
|
- super().__init__()
|
|
|
- if isinstance(codebook_dim, int):
|
|
|
- codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
|
|
-
|
|
|
- self.n_codebooks = n_codebooks
|
|
|
- self.codebook_dim = codebook_dim
|
|
|
- self.codebook_size = codebook_size
|
|
|
-
|
|
|
- self.quantizers = nn.ModuleList(
|
|
|
- [
|
|
|
- VectorQuantize(input_dim, codebook_size, codebook_dim[i])
|
|
|
- for i in range(n_codebooks)
|
|
|
- ]
|
|
|
- )
|
|
|
- self.quantizer_dropout = quantizer_dropout
|
|
|
- self.min_quantizers = min_quantizers
|
|
|
-
|
|
|
- def forward(self, z, n_quantizers: int = None) -> VQResult:
|
|
|
- """Quantized the input tensor using a fixed set of `n` codebooks and returns
|
|
|
- the corresponding codebook vectors
|
|
|
- Parameters
|
|
|
- ----------
|
|
|
- z : Tensor[B x D x T]
|
|
|
- n_quantizers : int, optional
|
|
|
- No. of quantizers to use
|
|
|
- (n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
|
|
- Note: if `self.quantizer_dropout` is True, this argument is ignored
|
|
|
- when in training mode, and a random number of quantizers is used.
|
|
|
- Returns
|
|
|
- -------
|
|
|
- """
|
|
|
- z_q = 0
|
|
|
- residual = z
|
|
|
- commitment_loss = 0
|
|
|
- codebook_loss = 0
|
|
|
-
|
|
|
- codebook_indices = []
|
|
|
- latents = []
|
|
|
-
|
|
|
- if n_quantizers is None:
|
|
|
- n_quantizers = self.n_codebooks
|
|
|
-
|
|
|
- if self.training:
|
|
|
- n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
|
|
- dropout = torch.randint(
|
|
|
- self.min_quantizers, self.n_codebooks + 1, (z.shape[0],)
|
|
|
- )
|
|
|
- n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
|
|
- n_quantizers[:n_dropout] = dropout[:n_dropout]
|
|
|
- n_quantizers = n_quantizers.to(z.device)
|
|
|
-
|
|
|
- for i, quantizer in enumerate(self.quantizers):
|
|
|
- if self.training is False and i >= n_quantizers:
|
|
|
- break
|
|
|
-
|
|
|
- z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
|
|
- residual
|
|
|
- )
|
|
|
-
|
|
|
- # Create mask to apply quantizer dropout
|
|
|
- mask = (
|
|
|
- torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
|
|
- )
|
|
|
- z_q = z_q + z_q_i * mask[:, None, None]
|
|
|
- residual = residual - z_q_i
|
|
|
-
|
|
|
- # Sum losses
|
|
|
- commitment_loss += (commitment_loss_i * mask).mean()
|
|
|
- codebook_loss += (codebook_loss_i * mask).mean()
|
|
|
-
|
|
|
- codebook_indices.append(indices_i)
|
|
|
- latents.append(z_e_i)
|
|
|
-
|
|
|
- codes = torch.stack(codebook_indices, dim=1)
|
|
|
- latents = torch.cat(latents, dim=1)
|
|
|
-
|
|
|
- return VQResult(z_q, codes, latents, commitment_loss, codebook_loss)
|
|
|
-
|
|
|
- def from_codes(self, codes: torch.Tensor):
|
|
|
- """Given the quantized codes, reconstruct the continuous representation
|
|
|
- Parameters
|
|
|
- ----------
|
|
|
- codes : Tensor[B x N x T]
|
|
|
- Quantized discrete representation of input
|
|
|
- Returns
|
|
|
- -------
|
|
|
- Tensor[B x D x T]
|
|
|
- Quantized continuous representation of input
|
|
|
- """
|
|
|
- z_q = 0.0
|
|
|
- z_p = []
|
|
|
- n_codebooks = codes.shape[1]
|
|
|
- for i in range(n_codebooks):
|
|
|
- z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
|
|
- z_p.append(z_p_i)
|
|
|
-
|
|
|
- z_q_i = self.quantizers[i].out_proj(z_p_i)
|
|
|
- z_q = z_q + z_q_i
|
|
|
- return z_q, torch.cat(z_p, dim=1), codes
|
|
|
-
|
|
|
- def from_latents(self, latents: torch.Tensor):
|
|
|
- """Given the unquantized latents, reconstruct the
|
|
|
- continuous representation after quantization.
|
|
|
-
|
|
|
- Parameters
|
|
|
- ----------
|
|
|
- latents : Tensor[B x N x T]
|
|
|
- Continuous representation of input after projection
|
|
|
-
|
|
|
- Returns
|
|
|
- -------
|
|
|
- Tensor[B x D x T]
|
|
|
- Quantized representation of full-projected space
|
|
|
- Tensor[B x D x T]
|
|
|
- Quantized representation of latent space
|
|
|
- """
|
|
|
- z_q = 0
|
|
|
- z_p = []
|
|
|
- codes = []
|
|
|
- dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
|
|
-
|
|
|
- n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
|
|
- 0
|
|
|
- ]
|
|
|
- for i in range(n_codebooks):
|
|
|
- j, k = dims[i], dims[i + 1]
|
|
|
- z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
|
|
- z_p.append(z_p_i)
|
|
|
- codes.append(codes_i)
|
|
|
-
|
|
|
- z_q_i = self.quantizers[i].out_proj(z_p_i)
|
|
|
- z_q = z_q + z_q_i
|
|
|
-
|
|
|
- return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
|
|
-
|
|
|
-
|
|
|
-class DownsampleResidualVectorQuantizer(ResidualVectorQuantize):
|
|
|
+class DownsampleResidualVectorQuantizer(nn.Module):
|
|
|
"""
|
|
|
Downsampled version of ResidualVectorQuantize
|
|
|
"""
|
|
|
@@ -269,18 +26,26 @@ class DownsampleResidualVectorQuantizer(ResidualVectorQuantize):
|
|
|
downsample_factor: tuple[int] = (2, 2),
|
|
|
downsample_dims: tuple[int] | None = None,
|
|
|
):
|
|
|
+ super().__init__()
|
|
|
if downsample_dims is None:
|
|
|
downsample_dims = [input_dim for _ in range(len(downsample_factor))]
|
|
|
|
|
|
all_dims = (input_dim,) + tuple(downsample_dims)
|
|
|
|
|
|
- super().__init__(
|
|
|
- all_dims[-1],
|
|
|
- n_codebooks,
|
|
|
- codebook_size,
|
|
|
- codebook_dim,
|
|
|
- quantizer_dropout,
|
|
|
- min_quantizers,
|
|
|
+ # self.vq = ResidualVQ(
|
|
|
+ # dim=all_dims[-1],
|
|
|
+ # num_quantizers=n_codebooks,
|
|
|
+ # codebook_dim=codebook_dim,
|
|
|
+ # threshold_ema_dead_code=2,
|
|
|
+ # codebook_size=codebook_size,
|
|
|
+ # kmeans_init=False,
|
|
|
+ # )
|
|
|
+
|
|
|
+ self.vq = LFQ(
|
|
|
+ dim=all_dims[-1],
|
|
|
+ codebook_size=2**14,
|
|
|
+ entropy_loss_weight=0.1,
|
|
|
+ diversity_gamma=1.0,
|
|
|
)
|
|
|
|
|
|
self.downsample_factor = downsample_factor
|
|
|
@@ -310,33 +75,34 @@ class DownsampleResidualVectorQuantizer(ResidualVectorQuantize):
|
|
|
]
|
|
|
)
|
|
|
|
|
|
- def forward(self, z, n_quantizers: int = None) -> VQResult:
|
|
|
+ def forward(self, z):
|
|
|
original_shape = z.shape
|
|
|
z = self.downsample(z)
|
|
|
- result = super().forward(z, n_quantizers)
|
|
|
- result.z = self.upsample(result.z)
|
|
|
+ z, indices, loss = self.vq(z.mT)
|
|
|
+ z = self.upsample(z.mT)
|
|
|
+ loss = loss.mean()
|
|
|
|
|
|
# Pad or crop z to match original shape
|
|
|
- diff = original_shape[-1] - result.z.shape[-1]
|
|
|
+ diff = original_shape[-1] - z.shape[-1]
|
|
|
left = diff // 2
|
|
|
right = diff - left
|
|
|
|
|
|
if diff > 0:
|
|
|
- result.z = F.pad(result.z, (left, right))
|
|
|
+ z = F.pad(z, (left, right))
|
|
|
elif diff < 0:
|
|
|
- result.z = result.z[..., left:-right]
|
|
|
+ z = z[..., left:-right]
|
|
|
|
|
|
- return result
|
|
|
+ return z, indices, loss
|
|
|
|
|
|
- def from_codes(self, codes: torch.Tensor):
|
|
|
- z_q, z_p, codes = super().from_codes(codes)
|
|
|
- z_q = self.upsample(z_q)
|
|
|
- return z_q, z_p, codes
|
|
|
+ # def from_codes(self, codes: torch.Tensor):
|
|
|
+ # z_q, z_p, codes = super().from_codes(codes)
|
|
|
+ # z_q = self.upsample(z_q)
|
|
|
+ # return z_q, z_p, codes
|
|
|
|
|
|
- def from_latents(self, latents: torch.Tensor):
|
|
|
- z_q, z_p, codes = super().from_latents(latents)
|
|
|
- z_q = self.upsample(z_q)
|
|
|
- return z_q, z_p, codes
|
|
|
+ # def from_latents(self, latents: torch.Tensor):
|
|
|
+ # z_q, z_p, codes = super().from_latents(latents)
|
|
|
+ # z_q = self.upsample(z_q)
|
|
|
+ # return z_q, z_p, codes
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|