Optimize vqgan & sft

Lengyue 2 years ago
parent
commit
158b0c82c0

+ 13 - 6
fish_speech/configs/text2semantic_sft_medium.yaml

@@ -2,9 +2,11 @@ defaults:
   - base
   - _self_
 
-project: text2semantic_sft_medium
+project: text2semantic_sft_medium_delay
 max_length: 4096
-use_delay_pattern: true
+use_delay_pattern: false
+ckpt_path: results/text2semantic_pretrain_medium_4_in_8_codebooks/checkpoints/step_000100000.ckpt
+resume_weights_only: true
 
 # Lightning Trainer
 trainer:
@@ -14,7 +16,7 @@ trainer:
   max_steps: 10_000
   precision: bf16-true
   limit_val_batches: 10
-  val_check_interval: 1000
+  val_check_interval: 500
 
 # Dataset Configuration
 tokenizer:
@@ -54,7 +56,7 @@ data:
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 32
+  batch_size: 16
   tokenizer: ${tokenizer}
   max_length: ${max_length}
 
@@ -83,7 +85,7 @@ model:
   optimizer:
     _target_: bitsandbytes.optim.AdamW8bit
     _partial_: true
-    lr: 1e-4
+    lr: 4e-5
     weight_decay: 0
     betas: [0.9, 0.95]
     eps: 1e-5
@@ -94,6 +96,11 @@ model:
     lr_lambda:
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _partial_: true
-      num_warmup_steps: 200
+      num_warmup_steps: 100
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0
+
+callbacks:
+  model_checkpoint:
+    every_n_train_steps: 1000
+    save_top_k: 10

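The SFT config now warms up for 100 steps and decays along a cosine curve toward a final LR ratio of 0, on top of the lowered base `lr: 4e-5`. A minimal sketch of such an `lr_lambda`, using the values from this config (illustrative only; the actual `fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda` may differ in detail):

```python
import math

# Sketch of a cosine-with-warmup LR multiplier matching the config above
# (num_warmup_steps: 100, num_training_steps: 10_000, final_lr_ratio: 0).
def lr_lambda(step, num_warmup_steps=100, num_training_steps=10_000, final_lr_ratio=0.0):
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)  # linear warmup from 0 to 1
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return final_lr_ratio + (1.0 - final_lr_ratio) * cosine  # decays to final_lr_ratio
```

The returned multiplier scales the optimizer's base learning rate at every step.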
+ 13 - 9
fish_speech/configs/vqgan_pretrain.yaml

@@ -2,7 +2,7 @@ defaults:
   - base
   - _self_
 
-project: vqgan_pretrain
+project: vqgan_pretrain_lfq
 ckpt_path: checkpoints/gpt_sovits_488k.pth
 resume_weights_only: true
 
@@ -13,7 +13,7 @@ trainer:
   strategy: ddp_find_unused_parameters_true
   precision: 32
   max_steps: 1_000_000
-  val_check_interval: 5000
+  val_check_interval: 2000
 
 sample_rate: 32000
 hop_length: 640
@@ -48,22 +48,23 @@ model:
   _target_: fish_speech.models.vqgan.VQGAN
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
-  freeze_discriminator: true
+  freeze_discriminator: false
 
-  weight_mel: 45
+  weight_mel: 45.0
   weight_kl: 0.1
   weight_vq: 1.0
+  weight_aux_mel: 1.0
 
   generator:
     _target_: fish_speech.models.vqgan.modules.models.SynthesizerTrn
     spec_channels: 1025
     segment_size: 32
     inter_channels: 192
-    hidden_channels: 192
-    filter_channels: 768
-    n_heads: 2
-    n_layers: 6
-    kernel_size: 3
+    prior_hidden_channels: 192
+    posterior_hidden_channels: 192
+    prior_n_layers: 16
+    posterior_n_layers: 16
+    kernel_size: 5
     p_dropout: 0.1
     resblock: "1"
     resblock_kernel_sizes: [3, 7, 11]
@@ -73,8 +74,11 @@ model:
     upsample_kernel_sizes: [16, 16, 8, 2, 2]
     gin_channels: 512
     freeze_quantizer: false
+    freeze_decoder: false
+    freeze_posterior_encoder: false
     codebook_size: 1024
     num_codebooks: 2
+    aux_spec_channels: ${num_mels}
 
   discriminator:
     _target_: fish_speech.models.vqgan.modules.models.EnsembledDiscriminator

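This config resumes from `checkpoints/gpt_sovits_488k.pth` (the checkpoint assembled by the `__main__` block in models.py below) with `resume_weights_only: true`. A minimal sketch of what weights-only resumption implies, i.e. restoring parameters but not optimizer state or step counters; the function name and strictness policy here are assumptions, not the actual fish_speech entrypoint:

```python
import torch

# Hedged sketch: load only model weights from a checkpoint, tolerating
# renamed/added modules (strict=False), so training restarts at step 0.
def load_weights_only(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)  # Lightning ckpts nest weights
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing={len(missing)} unexpected={len(unexpected)}")
    return model
```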
+ 21 - 2
fish_speech/models/vqgan/lit_module.py

@@ -50,6 +50,7 @@ class VQGAN(L.LightningModule):
         weight_mel: float = 45,
         weight_kl: float = 0.1,
         weight_vq: float = 1.0,
+        weight_aux_mel: float = 20.0,
     ):
         super().__init__()
 
@@ -68,6 +69,7 @@ class VQGAN(L.LightningModule):
         self.weight_mel = weight_mel
         self.weight_kl = weight_kl
         self.weight_vq = weight_vq
+        self.weight_aux_mel = weight_aux_mel
 
         # Other parameters
         self.hop_length = hop_length
@@ -131,10 +133,14 @@ class VQGAN(L.LightningModule):
             y_mask,
             y_mask,
             (z, z_p, m_p, logs_p, m_q, logs_q),
-            quantized,
+            loss_vq,
+            decoded_aux_mels,
         ) = self.generator(gt_specs, spec_lengths)
 
         gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
+        decoded_aux_mels = slice_segments(
+            decoded_aux_mels, ids_slice, self.generator.segment_size
+        )
         spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
         audios = slice_segments(
             audios,
@@ -205,6 +211,9 @@ class VQGAN(L.LightningModule):
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
             loss_mel = F.l1_loss(gt_mels * spec_masks, fake_mels * spec_masks)
+            loss_aux_mel = F.l1_loss(
+                gt_mels * spec_masks, decoded_aux_mels * spec_masks
+            )
 
         self.log(
             "train/generator/loss_mel",
@@ -216,7 +225,16 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
-        loss_vq = quantized.commitment_loss + quantized.codebook_loss
+        self.log(
+            "train/generator/loss_aux_mel",
+            loss_aux_mel,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
         self.log(
             "train/generator/loss_vq",
             loss_vq,
@@ -241,6 +259,7 @@ class VQGAN(L.LightningModule):
 
         loss = (
             loss_mel * self.weight_mel
+            + loss_aux_mel * self.weight_aux_mel
             + loss_vq * self.weight_vq
             + loss_kl * self.weight_kl
             + loss_adv

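Assembled from the fragments above, the generator objective after this commit reads as follows (a restatement, not new code; weights come from vqgan_pretrain.yaml, and `loss_adv`/`loss_kl` are computed elsewhere in the module, unchanged by this diff):

```python
# The weighted sum from training_step with the config values plugged in.
loss = (
    loss_mel * 45.0        # weight_mel: L1 mel loss on the vocoder output
    + loss_aux_mel * 1.0   # weight_aux_mel: L1 mel loss on the new aux decoder head
    + loss_vq * 1.0        # weight_vq: LFQ auxiliary loss from the quantizer
    + loss_kl * 0.1        # weight_kl: VITS prior/posterior KL
    + loss_adv             # adversarial term, unweighted
)
```

Note that the module's default `weight_aux_mel` is 20.0, while vqgan_pretrain.yaml overrides it to 1.0.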
+ 92 - 67
fish_speech/models/vqgan/modules/models.py

@@ -19,54 +19,58 @@ class FeatureEncoder(nn.Module):
         spec_channels,
         out_channels,
         hidden_channels,
-        filter_channels,
-        n_heads,
         n_layers,
         kernel_size,
         p_dropout,
         codebook_size=1024,
         num_codebooks=2,
         gin_channels=0,
+        aux_spec_channels=None,
     ):
         super().__init__()
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
         self.n_layers = n_layers
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout
 
+        if aux_spec_channels is None:
+            aux_spec_channels = spec_channels
+
         self.spec_proj = nn.Conv1d(spec_channels, hidden_channels, 1)
 
-        self.encoder = attentions.Encoder(
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers // 2,
-            kernel_size,
-            p_dropout,
+        self.encoder = modules.WN(
+            hidden_channels=hidden_channels,
+            kernel_size=kernel_size,
+            dilation_rate=1,
+            n_layers=n_layers // 2,
         )
 
         self.vq = DownsampleResidualVectorQuantizer(
             input_dim=hidden_channels,
             n_codebooks=num_codebooks,
             codebook_size=codebook_size,
+            codebook_dim=hidden_channels,
             min_quantizers=num_codebooks,
             downsample_factor=(2,),
         )
 
-        self.decoder = attentions.Encoder(
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers // 2,
-            kernel_size,
-            p_dropout,
-            isflow=True,
+        self.decoder = modules.WN(
+            hidden_channels=hidden_channels,
+            kernel_size=kernel_size,
+            dilation_rate=1,
+            n_layers=n_layers // 2,
             gin_channels=gin_channels,
         )
 
+        self.aux_decoder = modules.WN(
+            hidden_channels=hidden_channels,
+            kernel_size=kernel_size,
+            dilation_rate=1,
+            n_layers=4,
+        )
+        self.aux_proj = nn.Conv1d(hidden_channels, aux_spec_channels, 1)
+
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
     def forward(self, y, y_lengths, ge):
@@ -75,13 +79,15 @@ class FeatureEncoder(nn.Module):
         )
 
         y = self.spec_proj(y * y_mask) * y_mask
-        y = self.encoder(y * y_mask, y_mask)
-        quantized = self.vq(y)
-        y = self.decoder(quantized.z * y_mask, y_mask, g=ge)
+        y = self.encoder(y, y_mask) * y_mask
+        z, indices, loss_vq = self.vq(y)
+        y = self.decoder(z, y_mask, g=ge) * y_mask
+        decoded_aux_mel = self.aux_decoder(y, y_mask)
+        decoded_aux_mel = self.aux_proj(decoded_aux_mel) * y_mask
 
         stats = self.proj(y) * y_mask
         m, logs = torch.split(stats, self.out_channels, dim=1)
-        return y, m, logs, y_mask, quantized
+        return y, m, logs, y_mask, loss_vq, decoded_aux_mel
 
 
 class ResidualCouplingBlock(nn.Module):
@@ -436,10 +442,10 @@ class SynthesizerTrn(nn.Module):
         spec_channels,
         segment_size,
         inter_channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
+        prior_hidden_channels,
+        prior_n_layers,
+        posterior_hidden_channels,
+        posterior_n_layers,
         kernel_size,
         p_dropout,
         resblock,
@@ -452,14 +458,17 @@ class SynthesizerTrn(nn.Module):
         freeze_quantizer=False,
         codebook_size=1024,
         num_codebooks=2,
+        freeze_decoder=False,
+        freeze_posterior_encoder=False,
+        aux_spec_channels=None,
     ):
         super().__init__()
         self.spec_channels = spec_channels
         self.inter_channels = inter_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
+        self.prior_hidden_channels = prior_hidden_channels
+        self.prior_n_layers = prior_n_layers
+        self.posterior_hidden_channels = posterior_hidden_channels
+        self.posterior_n_layers = posterior_n_layers
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout
         self.resblock = resblock
@@ -472,40 +481,44 @@ class SynthesizerTrn(nn.Module):
         self.gin_channels = gin_channels
 
         self.enc_p = FeatureEncoder(
-            spec_channels,
-            inter_channels,
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers,
-            kernel_size,
-            p_dropout,
+            spec_channels=spec_channels,
+            out_channels=inter_channels,
+            hidden_channels=prior_hidden_channels,
+            n_layers=prior_n_layers,
+            kernel_size=kernel_size,
+            p_dropout=p_dropout,
             codebook_size=codebook_size,
             num_codebooks=num_codebooks,
             gin_channels=gin_channels,
+            aux_spec_channels=aux_spec_channels,
         )
         self.dec = Generator(
-            inter_channels,
-            resblock,
-            resblock_kernel_sizes,
-            resblock_dilation_sizes,
-            upsample_rates,
-            upsample_initial_channel,
-            upsample_kernel_sizes,
+            initial_channel=inter_channels,
+            resblock=resblock,
+            resblock_kernel_sizes=resblock_kernel_sizes,
+            resblock_dilation_sizes=resblock_dilation_sizes,
+            upsample_rates=upsample_rates,
+            upsample_initial_channel=upsample_initial_channel,
+            upsample_kernel_sizes=upsample_kernel_sizes,
             gin_channels=gin_channels,
         )
         self.enc_q = PosteriorEncoder(
-            spec_channels,
+            in_channels=spec_channels,
+            out_channels=inter_channels,
+            hidden_channels=posterior_hidden_channels,
+            kernel_size=5,
+            dilation_rate=1,
+            n_layers=posterior_n_layers,
+            gin_channels=gin_channels,
+        )
+        self.flow = ResidualCouplingBlock(
             inter_channels,
-            hidden_channels,
+            posterior_hidden_channels,
             5,
             1,
-            16,
+            4,
             gin_channels=gin_channels,
         )
-        self.flow = ResidualCouplingBlock(
-            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
-        )
 
         self.ref_enc = modules.MelStyleEncoder(
             spec_channels, style_vector_dim=gin_channels
@@ -516,13 +529,21 @@ class SynthesizerTrn(nn.Module):
             self.enc_p.encoder.requires_grad_(False)
             self.enc_p.vq.requires_grad_(False)
 
+        if freeze_decoder:
+            self.dec.requires_grad_(False)
+
+        if freeze_posterior_encoder:
+            self.enc_q.requires_grad_(False)
+
     def forward(self, y, y_lengths):
         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
             y.dtype
         )
         ge = self.ref_enc(y * y_mask, y_mask)
 
-        x, m_p, logs_p, y_mask, quantized = self.enc_p(y, y_lengths, ge)
+        x, m_p, logs_p, y_mask, quantized, decoded_aux_mel = self.enc_p(
+            y, y_lengths, ge
+        )
         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
         z_p = self.flow(z, y_mask, g=ge)
 
@@ -538,6 +559,7 @@ class SynthesizerTrn(nn.Module):
             y_mask,
             (z, z_p, m_p, logs_p, m_q, logs_q),
             quantized,
+            decoded_aux_mel,
         )
 
     def infer(self, y, y_lengths, noise_scale=0.5):
@@ -545,7 +567,7 @@ class SynthesizerTrn(nn.Module):
             y.dtype
         )
         ge = self.ref_enc(y * y_mask, y_mask)
-        x, m_p, logs_p, y_mask, quantized = self.enc_p(y, y_lengths, ge)
+        x, m_p, logs_p, y_mask, _, _ = self.enc_p(y, y_lengths, ge)
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
 
         z = self.flow(z_p, y_mask, g=ge, reverse=True)
@@ -600,10 +622,10 @@ if __name__ == "__main__":
         spec_channels=1025,
         segment_size=20480,
         inter_channels=192,
-        hidden_channels=192,
-        filter_channels=768,
-        n_heads=2,
-        n_layers=6,
+        prior_hidden_channels=384,
+        posterior_hidden_channels=192,
+        prior_n_layers=16,
+        posterior_n_layers=16,
         kernel_size=3,
         p_dropout=0.1,
         resblock="1",
@@ -617,18 +639,21 @@ if __name__ == "__main__":
     )
 
     state_dict_g = torch.load("checkpoints/gpt_sovits_g_488k.pth", map_location="cpu")
-    # state_dict_d = torch.load("checkpoints/gpt_sovits_d_488k.pth", map_location="cpu")
-    # keys = set(model.state_dict().keys())
-    # state_dict_g = {k.replace("encoder2.", "decoder."): v for k, v in state_dict_g.items() if k in keys}
+    state_dict_d = torch.load("checkpoints/gpt_sovits_d_488k.pth", map_location="cpu")
+    keys = set(model.state_dict().keys())
+    state_dict_g = {
+        k: v for k, v in state_dict_g.items() if k in keys and "enc_p" not in k
+    }
 
-    # new_state = {}
-    # for k, v in state_dict_g.items():
-    #     new_state["generator." + k] = v
+    new_state = {}
+    for k, v in state_dict_g.items():
+        new_state["generator." + k] = v
 
-    # for k, v in state_dict_d.items():
-    #     new_state["discriminator." + k] = v
+    for k, v in state_dict_d.items():
+        new_state["discriminator." + k] = v
 
-    # torch.save(new_state, "checkpoints/gpt_sovits_488k.pth")
+    torch.save(new_state, "checkpoints/gpt_sovits_488k.pth")
+    exit()
 
     # print(EnsembledDiscriminator().load_state_dict(state_dict_d, strict=False))
     print(model.load_state_dict(state_dict_g, strict=False))

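Both the prior encoder and decoder in FeatureEncoder now use `modules.WN`, the WaveNet-style gated-convolution stack from VITS, in place of the transformer `attentions.Encoder` blocks. A minimal smoke test of the call convention used above; this is a sketch, and the import path is an assumption inferred from the `modules.` references in this file:

```python
import torch
from fish_speech.models.vqgan.modules import modules  # assumed import path

# Called as in the new FeatureEncoder.forward: channel-first tensors plus a
# frame-validity mask; an optional style embedding `g` conditions the decoder.
wn = modules.WN(hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=8)
x = torch.randn(2, 192, 100)        # (batch, hidden_channels, frames)
x_mask = torch.ones(2, 1, 100)      # 1 = valid frame, 0 = padding
y = wn(x, x_mask) * x_mask          # shape preserved: (2, 192, 100)
print(y.shape)
```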
+ 33 - 267
fish_speech/models/vqgan/modules/rvq.py

@@ -7,253 +7,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from torch.nn.utils import weight_norm
+from vector_quantize_pytorch import LFQ, ResidualVQ
 
 
-class VectorQuantize(nn.Module):
-    """
-    Implementation of VQ similar to Karpathy's repo:
-    https://github.com/karpathy/deep-vector-quantization
-    Additionally uses following tricks from Improved VQGAN
-    (https://arxiv.org/pdf/2110.04627.pdf):
-        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
-            for improved codebook usage
-        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
-            improves training stability
-    """
-
-    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
-        super().__init__()
-        self.codebook_size = codebook_size
-        self.codebook_dim = codebook_dim
-
-        self.in_proj = weight_norm(nn.Conv1d(input_dim, codebook_dim, kernel_size=1))
-        self.out_proj = weight_norm(nn.Conv1d(codebook_dim, input_dim, kernel_size=1))
-        self.codebook = nn.Embedding(codebook_size, codebook_dim)
-
-    def forward(self, z):
-        """Quantized the input tensor using a fixed codebook and returns
-        the corresponding codebook vectors
-
-        Parameters
-        ----------
-        z : Tensor[B x D x T]
-
-        Returns
-        -------
-        Tensor[B x D x T]
-            Quantized continuous representation of input
-        Tensor[1]
-            Commitment loss to train encoder to predict vectors closer to codebook
-            entries
-        Tensor[1]
-            Codebook loss to update the codebook
-        Tensor[B x T]
-            Codebook indices (quantized discrete representation of input)
-        Tensor[B x D x T]
-            Projected latents (continuous representation of input before quantization)
-        """
-
-        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
-        z_e = self.in_proj(z)  # z_e : (B x D x T)
-        z_q, indices = self.decode_latents(z_e)
-
-        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
-        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
-
-        z_q = (
-            z_e + (z_q - z_e).detach()
-        )  # noop in forward pass, straight-through gradient estimator in backward pass
-
-        z_q = self.out_proj(z_q)
-
-        return z_q, commitment_loss, codebook_loss, indices, z_e
-
-    def embed_code(self, embed_id):
-        return F.embedding(embed_id, self.codebook.weight)
-
-    def decode_code(self, embed_id):
-        return self.embed_code(embed_id).transpose(1, 2)
-
-    def decode_latents(self, latents):
-        encodings = rearrange(latents, "b d t -> (b t) d")
-        codebook = self.codebook.weight  # codebook: (N x D)
-
-        # L2 normalize encodings and codebook (ViT-VQGAN)
-        encodings = F.normalize(encodings)
-        codebook = F.normalize(codebook)
-
-        # Compute euclidean distance with codebook
-        dist = (
-            encodings.pow(2).sum(1, keepdim=True)
-            - 2 * encodings @ codebook.t()
-            + codebook.pow(2).sum(1, keepdim=True).t()
-        )
-        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
-        z_q = self.decode_code(indices)
-        return z_q, indices
-
-
-@dataclass
-class VQResult:
-    z: torch.Tensor
-    codes: torch.Tensor
-    latents: torch.Tensor
-    commitment_loss: torch.Tensor
-    codebook_loss: torch.Tensor
-
-
-class ResidualVectorQuantize(nn.Module):
-    """
-    Introduced in SoundStream: An end2end neural audio codec
-    https://arxiv.org/abs/2107.03312
-    """
-
-    def __init__(
-        self,
-        input_dim: int = 512,
-        n_codebooks: int = 9,
-        codebook_size: int = 1024,
-        codebook_dim: Union[int, list] = 8,
-        quantizer_dropout: float = 0.0,
-        min_quantizers: int = 4,
-    ):
-        super().__init__()
-        if isinstance(codebook_dim, int):
-            codebook_dim = [codebook_dim for _ in range(n_codebooks)]
-
-        self.n_codebooks = n_codebooks
-        self.codebook_dim = codebook_dim
-        self.codebook_size = codebook_size
-
-        self.quantizers = nn.ModuleList(
-            [
-                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
-                for i in range(n_codebooks)
-            ]
-        )
-        self.quantizer_dropout = quantizer_dropout
-        self.min_quantizers = min_quantizers
-
-    def forward(self, z, n_quantizers: int = None) -> VQResult:
-        """Quantized the input tensor using a fixed set of `n` codebooks and returns
-        the corresponding codebook vectors
-        Parameters
-        ----------
-        z : Tensor[B x D x T]
-        n_quantizers : int, optional
-            No. of quantizers to use
-            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
-            Note: if `self.quantizer_dropout` is True, this argument is ignored
-                when in training mode, and a random number of quantizers is used.
-        Returns
-        -------
-        """
-        z_q = 0
-        residual = z
-        commitment_loss = 0
-        codebook_loss = 0
-
-        codebook_indices = []
-        latents = []
-
-        if n_quantizers is None:
-            n_quantizers = self.n_codebooks
-
-        if self.training:
-            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
-            dropout = torch.randint(
-                self.min_quantizers, self.n_codebooks + 1, (z.shape[0],)
-            )
-            n_dropout = int(z.shape[0] * self.quantizer_dropout)
-            n_quantizers[:n_dropout] = dropout[:n_dropout]
-            n_quantizers = n_quantizers.to(z.device)
-
-        for i, quantizer in enumerate(self.quantizers):
-            if self.training is False and i >= n_quantizers:
-                break
-
-            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
-                residual
-            )
-
-            # Create mask to apply quantizer dropout
-            mask = (
-                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
-            )
-            z_q = z_q + z_q_i * mask[:, None, None]
-            residual = residual - z_q_i
-
-            # Sum losses
-            commitment_loss += (commitment_loss_i * mask).mean()
-            codebook_loss += (codebook_loss_i * mask).mean()
-
-            codebook_indices.append(indices_i)
-            latents.append(z_e_i)
-
-        codes = torch.stack(codebook_indices, dim=1)
-        latents = torch.cat(latents, dim=1)
-
-        return VQResult(z_q, codes, latents, commitment_loss, codebook_loss)
-
-    def from_codes(self, codes: torch.Tensor):
-        """Given the quantized codes, reconstruct the continuous representation
-        Parameters
-        ----------
-        codes : Tensor[B x N x T]
-            Quantized discrete representation of input
-        Returns
-        -------
-        Tensor[B x D x T]
-            Quantized continuous representation of input
-        """
-        z_q = 0.0
-        z_p = []
-        n_codebooks = codes.shape[1]
-        for i in range(n_codebooks):
-            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
-            z_p.append(z_p_i)
-
-            z_q_i = self.quantizers[i].out_proj(z_p_i)
-            z_q = z_q + z_q_i
-        return z_q, torch.cat(z_p, dim=1), codes
-
-    def from_latents(self, latents: torch.Tensor):
-        """Given the unquantized latents, reconstruct the
-        continuous representation after quantization.
-
-        Parameters
-        ----------
-        latents : Tensor[B x N x T]
-            Continuous representation of input after projection
-
-        Returns
-        -------
-        Tensor[B x D x T]
-            Quantized representation of full-projected space
-        Tensor[B x D x T]
-            Quantized representation of latent space
-        """
-        z_q = 0
-        z_p = []
-        codes = []
-        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
-
-        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
-            0
-        ]
-        for i in range(n_codebooks):
-            j, k = dims[i], dims[i + 1]
-            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
-            z_p.append(z_p_i)
-            codes.append(codes_i)
-
-            z_q_i = self.quantizers[i].out_proj(z_p_i)
-            z_q = z_q + z_q_i
-
-        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
-
-
-class DownsampleResidualVectorQuantizer(ResidualVectorQuantize):
+class DownsampleResidualVectorQuantizer(nn.Module):
     """
     """
     Downsampled version of ResidualVectorQuantize
     Downsampled version of ResidualVectorQuantize
     """
     """
@@ -269,18 +26,26 @@ class DownsampleResidualVectorQuantizer(ResidualVectorQuantize):
         downsample_factor: tuple[int] = (2, 2),
         downsample_dims: tuple[int] | None = None,
     ):
+        super().__init__()
         if downsample_dims is None:
             downsample_dims = [input_dim for _ in range(len(downsample_factor))]
 
         all_dims = (input_dim,) + tuple(downsample_dims)
 
-        super().__init__(
-            all_dims[-1],
-            n_codebooks,
-            codebook_size,
-            codebook_dim,
-            quantizer_dropout,
-            min_quantizers,
+        # self.vq = ResidualVQ(
+        #     dim=all_dims[-1],
+        #     num_quantizers=n_codebooks,
+        #     codebook_dim=codebook_dim,
+        #     threshold_ema_dead_code=2,
+        #     codebook_size=codebook_size,
+        #     kmeans_init=False,
+        # )
+
+        self.vq = LFQ(
+            dim=all_dims[-1],
+            codebook_size=2**14,
+            entropy_loss_weight=0.1,
+            diversity_gamma=1.0,
         )
 
         self.downsample_factor = downsample_factor
@@ -310,33 +75,34 @@ class DownsampleResidualVectorQuantizer(ResidualVectorQuantize):
             ]
         )
 
-    def forward(self, z, n_quantizers: int = None) -> VQResult:
+    def forward(self, z):
         original_shape = z.shape
         z = self.downsample(z)
-        result = super().forward(z, n_quantizers)
-        result.z = self.upsample(result.z)
+        z, indices, loss = self.vq(z.mT)
+        z = self.upsample(z.mT)
+        loss = loss.mean()
 
         # Pad or crop z to match original shape
-        diff = original_shape[-1] - result.z.shape[-1]
+        diff = original_shape[-1] - z.shape[-1]
         left = diff // 2
         right = diff - left
 
         if diff > 0:
-            result.z = F.pad(result.z, (left, right))
+            z = F.pad(z, (left, right))
         elif diff < 0:
-            result.z = result.z[..., left:-right]
+            z = z[..., left:-right]
 
-        return result
+        return z, indices, loss
 
-    def from_codes(self, codes: torch.Tensor):
-        z_q, z_p, codes = super().from_codes(codes)
-        z_q = self.upsample(z_q)
-        return z_q, z_p, codes
+    # def from_codes(self, codes: torch.Tensor):
+    #     z_q, z_p, codes = super().from_codes(codes)
+    #     z_q = self.upsample(z_q)
+    #     return z_q, z_p, codes
 
-    def from_latents(self, latents: torch.Tensor):
-        z_q, z_p, codes = super().from_latents(latents)
-        z_q = self.upsample(z_q)
-        return z_q, z_p, codes
+    # def from_latents(self, latents: torch.Tensor):
+    #     z_q, z_p, codes = super().from_latents(latents)
+    #     z_q = self.upsample(z_q)
+    #     return z_q, z_p, codes
 
 
 if __name__ == "__main__":
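The hand-rolled residual VQ is replaced by `LFQ` (lookup-free quantization) from the `vector_quantize_pytorch` package. A minimal sketch of the call convention used in the new `forward`, with the constructor arguments copied from the diff; the `.mT` transposes exist because LFQ expects `(batch, time, dim)` while the model works channel-first:

```python
import torch
from vector_quantize_pytorch import LFQ

# Hyperparameters as in the diff above; batch/length values are arbitrary.
quantizer = LFQ(
    dim=192,                  # all_dims[-1] for the default input_dim
    codebook_size=2**14,      # 14-bit lookup-free codes
    entropy_loss_weight=0.1,
    diversity_gamma=1.0,
)

z = torch.randn(2, 192, 50)                      # (batch, dim, time), channel-first
quantized, indices, aux_loss = quantizer(z.mT)   # LFQ wants (batch, time, dim)
quantized = quantized.mT                         # back to (batch, dim, time)
print(quantized.shape, indices.shape, aux_loss.shape)
```

`aux_loss` is the entropy/commitment auxiliary term that the module now surfaces as `loss_vq`, replacing the old `commitment_loss + codebook_loss` sum.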