| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399 |
- import math
- import typing as tp
- from dataclasses import dataclass
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from dac.nn.quantize import ResidualVectorQuantize
- from torch.nn.utils.parametrizations import weight_norm
- from torch.nn.utils.parametrize import remove_parametrizations
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Drop `paddings = (left, right)` samples from the last axis of `x`.

    The stop index is computed from the axis length rather than written as a
    negative slice, so a right padding of zero keeps the whole tail.
    Only for 1d!
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total = x.shape[-1]
    assert (left + right) <= total
    return x[..., left : total - right]
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Return how many extra samples must be appended to the last axis of `x`
    so that a convolution with the given `kernel_size`, `stride` and total
    padding `padding_total` covers the input with a whole number of frames.
    """
    current = x.shape[-1]
    # Possibly fractional frame count for the current length; round it up.
    frames = math.ceil((current - kernel_size + padding_total) / stride + 1)
    # Shortest length that yields exactly `frames` complete windows.
    needed = (frames - 1) * stride + (kernel_size - padding_total)
    return needed - current
def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.

    If the input is too short for the requested reflection, extra 0 padding is
    inserted on the right before the reflection happens and removed afterwards.

    Args:
        x: Tensor padded along its last axis.
        paddings: `(left, right)` amounts, both non-negative.
        mode: Any mode accepted by `F.pad`, or `"zeros"` (the default), which
            is treated as constant padding with 0.
        value: Fill value forwarded to `F.pad`.

    Returns:
        The padded tensor.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "zeros":
        # F.pad has no "zeros" mode (that name comes from nn.Conv1d's
        # padding_mode); without this mapping the default argument raises.
        mode, value = "constant", 0.0
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            # Reflection needs the input to be longer than the pad on each side.
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        # Drop the temporary zeros that were only added to enable reflection.
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    return F.pad(x, paddings, mode, value)
- class CausalConvNet(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- dilation=1,
- stride=1,
- groups=1,
- padding=None,
- ):
- super(CausalConvNet, self).__init__()
- self.conv = nn.Conv1d(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- dilation=dilation,
- groups=groups,
- )
- self.stride = stride
- self.kernel_size = (kernel_size - 1) * dilation + 1
- self.dilation = dilation
- self.padding = self.kernel_size - self.stride
- def forward(self, x):
- pad = self.padding
- extra_padding = get_extra_padding_for_conv1d(
- x, self.kernel_size, self.stride, pad
- )
- x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
- return self.conv(x).contiguous()
- def weight_norm(self, name="weight", dim=0):
- self.conv = weight_norm(self.conv, name=name, dim=dim)
- return self
- def remove_weight_norm(self):
- self.conv = remove_parametrizations(self.conv)
- return self
- class CausalTransConvNet(nn.Module):
- def __init__(
- self, in_channels, out_channels, kernel_size, dilation=1, stride=1, padding=None
- ):
- super(CausalTransConvNet, self).__init__()
- self.conv = nn.ConvTranspose1d(
- in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
- )
- self.stride = stride
- self.kernel_size = kernel_size
- def forward(self, x):
- x = self.conv(x)
- pad = self.kernel_size - self.stride
- padding_right = math.ceil(pad)
- padding_left = pad - padding_right
- x = unpad1d(x, (padding_left, padding_right))
- return x.contiguous()
- def weight_norm(self, name="weight", dim=0):
- self.conv = weight_norm(self.conv, name=name, dim=dim)
- return self
- def remove_weight_norm(self):
- self.conv = remove_parametrizations(self.conv)
- return self
- # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt block operating on (N, C, L) sequences.

    Pipeline: depthwise CausalConvNet -> permute to (N, L, C) -> LayerNorm ->
    Linear -> GELU -> Linear -> optional learnable per-channel scale ->
    permute back to (N, C, L) -> optional residual add.

    Args:
        dim (int): Number of input channels.
        layer_scale_init_value (float): Init value for Layer Scale; the scale
            parameter is only created when this is > 0. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()
        hidden_dim = int(mlp_ratio * dim)

        # Depthwise conv: one group per channel.
        self.dwconv = CausalConvNet(
            dim,
            dim,
            kernel_size=kernel_size,
            groups=dim,
            dilation=dilation,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convolutions expressed as linear layers.
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(self, x, apply_residual: bool = True):
        """Run the block; add the block input back when `apply_residual` is set."""
        shortcut = x
        h = self.dwconv(x)
        h = h.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        h = self.pwconv2(self.act(self.pwconv1(self.norm(h))))
        if self.gamma is not None:
            h = self.gamma * h
        h = h.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        if not apply_residual:
            return h
        return shortcut + h
- @dataclass
- class VQResult:
- z: torch.Tensor
- codes: torch.Tensor
- latents: torch.Tensor
- codebook_loss: torch.Tensor
- commitment_loss: torch.Tensor
- semantic_distill_z: torch.Tensor | None = None
- class DownsampleResidualVectorQuantize(nn.Module):
- def __init__(
- self,
- input_dim: int = 1024,
- n_codebooks: int = 9,
- codebook_dim: int = 8,
- quantizer_dropout: float = 0.5,
- codebook_size: int = 1024,
- semantic_codebook_size: int = 4096,
- downsample_factor: tuple[int] = (2, 2),
- downsample_dims: tuple[int] | None = None,
- pre_module: nn.Module | None = None,
- post_module: nn.Module | None = None,
- semantic_predictor_module: nn.Module | None = None,
- ):
- super().__init__()
- if downsample_dims is None:
- downsample_dims = [input_dim for _ in range(len(downsample_factor))]
- all_dims = (input_dim,) + tuple(downsample_dims)
- self.semantic_quantizer = ResidualVectorQuantize(
- input_dim=input_dim,
- n_codebooks=1,
- codebook_size=semantic_codebook_size,
- codebook_dim=codebook_dim,
- quantizer_dropout=0.0,
- )
- self.quantizer = ResidualVectorQuantize(
- input_dim=input_dim,
- n_codebooks=n_codebooks,
- codebook_size=codebook_size,
- codebook_dim=codebook_dim,
- quantizer_dropout=quantizer_dropout,
- )
- self.downsample_factor = downsample_factor
- self.downsample_dims = downsample_dims
- convnet_type = CausalConvNet
- transconvnet_type = CausalTransConvNet
- self.downsample = nn.Sequential(
- *[
- nn.Sequential(
- convnet_type(
- all_dims[idx],
- all_dims[idx + 1],
- kernel_size=factor,
- stride=factor,
- ),
- ConvNeXtBlock(dim=all_dims[idx + 1]),
- )
- for idx, factor in enumerate(downsample_factor)
- ]
- )
- self.upsample = nn.Sequential(
- *[
- nn.Sequential(
- transconvnet_type(
- all_dims[idx + 1],
- all_dims[idx],
- kernel_size=factor,
- stride=factor,
- ),
- ConvNeXtBlock(dim=all_dims[idx]),
- )
- for idx, factor in reversed(list(enumerate(downsample_factor)))
- ]
- )
- self.apply(self._init_weights)
- self.pre_module = (
- pre_module if pre_module is not None else nn.Identity()
- ) # leave for transformer, LSTM or Mamba or something else
- self.post_module = post_module if post_module is not None else nn.Identity()
- self.semantic_predictor_module = (
- semantic_predictor_module
- if semantic_predictor_module is not None
- else nn.Identity()
- )
- def _init_weights(self, m):
- if isinstance(m, (nn.Conv1d, nn.Linear)):
- nn.init.trunc_normal_(m.weight, std=0.02)
- nn.init.constant_(m.bias, 0)
- def forward(
- self, z, n_quantizers: int = None, semantic_len: torch.Tensor = None, **kwargs
- ):
- # z: (B, D, T)
- original_shape = z.shape
- if semantic_len is None:
- semantic_len = torch.LongTensor([z.shape[-1]])
- z = self.downsample(z)
- z = self.pre_module(z) # B, T, D
- (
- semantic_z,
- semantic_codes,
- semantic_latents,
- semantic_commitment_loss,
- semantic_codebook_loss,
- ) = self.semantic_quantizer(z)
- residual_z = z - semantic_z
- residual_z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
- residual_z, n_quantizers=n_quantizers
- )
- z = semantic_z + residual_z
- commitment_loss = commitment_loss + semantic_commitment_loss
- codebook_loss = codebook_loss + semantic_codebook_loss
- codes = torch.cat([semantic_codes, codes], dim=1)
- latents = torch.cat([semantic_latents, latents], dim=1)
- z = self.post_module(z)
- z = self.upsample(z)
- # z: (B, D, T)
- # semantic distillation (disabled here since only used in training)
- # semantic_distill_z = self.semantic_predictor_module(semantic_z, semantic_len).mT # wav2vec target is B, T, D
- # Pad or crop z to match original shape
- diff = original_shape[-1] - z.shape[-1]
- right = 0
- left = abs(diff) - right
- if diff > 0:
- z = F.pad(z, (left, right))
- elif diff < 0:
- z = z[..., left:]
- results = VQResult(
- z=z,
- codes=codes,
- latents=latents,
- commitment_loss=commitment_loss,
- codebook_loss=codebook_loss,
- )
- return results
- # def encode(self, z):
- # z = self.downsample(z)
- # z = self.pre_module(z)
- # _, indices, _, _, _ = self.quantizer(z.mT)
- # indices = rearrange(indices, "g b l r -> b (g r) l")
- # return indices
- #
- def decode(self, indices: torch.Tensor):
- # indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
- indices[:, 0] = torch.clamp(
- indices[:, 0], max=self.semantic_quantizer.codebook_size - 1
- )
- indices[:, 1:] = torch.clamp(
- indices[:, 1:], max=self.quantizer.codebook_size - 1
- )
- z_q_semantic = self.semantic_quantizer.from_codes(indices[:, :1])[0]
- z_q_residual = self.quantizer.from_codes(indices[:, 1:])[0]
- z_q = z_q_semantic + z_q_residual
- z_q = self.post_module(z_q)
- z_q = self.upsample(z_q)
- return z_q
- # def from_latents(self, latents: torch.Tensor):
- # z_q, z_p, codes = super().from_latents(latents)
- # z_q = self.upsample(z_q)
- # return z_q, z_p, codes
if __name__ == "__main__":
    # Smoke test: quantize a random batch, then check that quantizing only the
    # first frames of the same batch gives the same output for those frames.
    model = DownsampleResidualVectorQuantize(
        input_dim=512,
        n_codebooks=8,
        codebook_dim=8,
        codebook_size=1024,
        quantizer_dropout=0.5,
        downsample_factor=[2, 2],
    )
    model.eval()

    full_input = torch.randn(2, 512, 442)
    full_out = model(full_input)
    print(model)
    print(full_out.latents.shape, full_out.codes.shape, full_out.z.shape)

    prefix_len = 40
    prefix_out = model(full_input[:, :, :prefix_len])
    print(prefix_out.latents.shape, prefix_out.codes.shape, prefix_out.z.shape)

    assert torch.allclose(full_out.z[:, :, :prefix_len], prefix_out.z, atol=1e-8)
    print("Success")
|