| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- from dataclasses import dataclass
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from einops import rearrange
- from vector_quantize_pytorch import GroupedResidualFSQ
- from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
@dataclass
class FSQResult:
    """Bundle of tensors returned by ``DownsampleFiniteScalarQuantize.forward``.

    The fields are declared in the order ``z``, ``codes``, ``latents``, so
    both keyword and positional construction are supported.
    """

    # Quantized signal after the upsampling stack (length-matched in forward).
    z: torch.Tensor
    # Code indices emitted by the residual FSQ quantizer.
    codes: torch.Tensor
    # Downsampled features as they were before quantization.
    latents: torch.Tensor
- class DownsampleFiniteScalarQuantize(nn.Module):
- def __init__(
- self,
- input_dim: int = 512,
- n_codebooks: int = 9,
- n_groups: int = 1,
- levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
- downsample_factor: tuple[int] = (2, 2),
- downsample_dims: tuple[int] | None = None,
- ):
- super().__init__()
- if downsample_dims is None:
- downsample_dims = [input_dim for _ in range(len(downsample_factor))]
- all_dims = (input_dim,) + tuple(downsample_dims)
- self.residual_fsq = GroupedResidualFSQ(
- dim=all_dims[-1],
- levels=levels,
- num_quantizers=n_codebooks,
- groups=n_groups,
- )
- self.downsample_factor = downsample_factor
- self.downsample_dims = downsample_dims
- self.downsample = nn.Sequential(
- *[
- nn.Sequential(
- FishConvNet(
- all_dims[idx],
- all_dims[idx + 1],
- kernel_size=factor,
- stride=factor,
- ),
- ConvNeXtBlock(dim=all_dims[idx + 1]),
- )
- for idx, factor in enumerate(downsample_factor)
- ]
- )
- self.upsample = nn.Sequential(
- *[
- nn.Sequential(
- FishTransConvNet(
- all_dims[idx + 1],
- all_dims[idx],
- kernel_size=factor,
- stride=factor,
- ),
- ConvNeXtBlock(dim=all_dims[idx]),
- )
- for idx, factor in reversed(list(enumerate(downsample_factor)))
- ]
- )
- self.apply(self._init_weights)
- def _init_weights(self, m):
- if isinstance(m, (nn.Conv1d, nn.Linear)):
- nn.init.trunc_normal_(m.weight, std=0.02)
- nn.init.constant_(m.bias, 0)
- def forward(self, z) -> FSQResult:
- original_shape = z.shape
- z = self.downsample(z)
- quantized, indices = self.residual_fsq(z.mT)
- result = FSQResult(
- z=quantized.mT,
- codes=indices.mT,
- latents=z,
- )
- result.z = self.upsample(result.z)
- # Pad or crop z to match original shape
- diff = original_shape[-1] - result.z.shape[-1]
- left = diff // 2
- right = diff - left
- if diff > 0:
- result.z = F.pad(result.z, (left, right))
- elif diff < 0:
- result.z = result.z[..., -left:right]
- return result
- def encode(self, z):
- z = self.downsample(z)
- _, indices = self.residual_fsq(z.mT)
- indices = rearrange(indices, "g b l r -> b (g r) l")
- return indices
- def decode(self, indices: torch.Tensor):
- indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
- z_q = self.residual_fsq.get_output_from_indices(indices)
- z_q = self.upsample(z_q.mT)
- return z_q
|