Browse Source

update vq baseline

Lengyue 2 years ago
parent
commit
223d0b8f81

+ 12 - 22
fish_speech/configs/vqgan_pretrain.yaml

@@ -2,7 +2,7 @@ defaults:
   - base
   - _self_
 
-project: vq_reflow_shallow_group_fsq_8x1024_wavenet
+project: vq-group-fsq-8x1024-wn-20x768-cond
 
 # Lightning Trainer
 trainer:
@@ -10,8 +10,8 @@ trainer:
   devices: 1
   precision: bf16-mixed
   max_steps: 1_000_000
-  val_check_interval: 1000
-  strategy: ddp
+  val_check_interval: 5000
+  strategy: ddp_find_unused_parameters_true
 
 sample_rate: 44100
 hop_length: 512
@@ -38,8 +38,8 @@ data:
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 128
-  val_batch_size: 4
+  batch_size: 64
+  val_batch_size: 64
 
 # Model Configuration
 model:
@@ -58,33 +58,24 @@ model:
   encoder:
     _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
     input_channels: ${num_mels}
-    residual_channels: 512
+    residual_channels: 768
     residual_layers: 20
     dilation_cycle: 4
   
   quantizer:
     _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
-    input_dim: 512
-    n_codebooks: 8
-    n_groups: 1
+    input_dim: 768
+    n_codebooks: 1
+    n_groups: 8
     levels: [8, 5, 5, 5]
-  
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    output_channels: ${num_mels}
-    residual_channels: 512
-    residual_layers: 20
-    dilation_cycle: 4
 
-  reflow:
+  decoder:
     _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    input_channels: ${num_mels}
     output_channels: ${num_mels}
-    residual_channels: 512
-    condition_channels: 512
+    residual_channels: 768
     residual_layers: 20
     dilation_cycle: 4
-    is_diffusion: true
+    condition_channels: 768
 
   vocoder:
     _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
@@ -129,4 +120,3 @@ callbacks:
       - encoder
       - decoder
       - quantizer
-      - reflow

+ 1 - 1
fish_speech/datasets/vqgan.py

@@ -122,7 +122,7 @@ class VQGANDataModule(LightningDataModule):
     def val_dataloader(self):
         return DataLoader(
             self.val_dataset,
-            batch_size=self.batch_size,
+            batch_size=self.val_batch_size,
             collate_fn=VQGANCollator(),
             num_workers=self.num_workers,
         )

+ 59 - 109
fish_speech/models/vqgan/lit_module.py

@@ -1,6 +1,5 @@
-import itertools
-from dataclasses import dataclass
-from typing import Any, Callable, Literal, Optional
+import math
+from typing import Any, Callable
 
 import lightning as L
 import torch
@@ -22,7 +21,7 @@ class VQGAN(L.LightningModule):
         encoder: WaveNet,
         quantizer: nn.Module,
         decoder: WaveNet,
-        reflow: nn.Module,
+        # reflow: nn.Module,
         vocoder: nn.Module,
         mel_transform: nn.Module,
         weight_reflow: float = 1.0,
@@ -44,7 +43,7 @@ class VQGAN(L.LightningModule):
         self.quantizer = quantizer
         self.decoder = decoder
         self.vocoder = vocoder
-        self.reflow = reflow
+        # self.reflow = reflow
         self.mel_transform = mel_transform
 
         # Freeze vocoder
@@ -122,51 +121,21 @@ class VQGAN(L.LightningModule):
         vq_recon_features = vq_result.z * mel_masks_float_conv
 
         # VQ Decode
-        gen_mel = self.decoder(vq_recon_features) * mel_masks_float_conv
+        gen_mel = (
+            self.decoder(
+                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+                condition=vq_recon_features,
+            )
+            * mel_masks_float_conv
+        )
 
         # Mel Loss
         loss_mel = (gen_mel - gt_mels).abs().mean(
             dim=1, keepdim=True
         ).sum() / mel_masks_float_conv.sum()
 
-        # Reflow, given x_1_aux, we want to reconstruct x_1
-        x_1 = self.norm_spec(gt_mels)
-        t = torch.rand(gt_mels.shape[0], device=gt_mels.device, dtype=torch.float32)
-        t = torch.clamp(t, 1e-6, 1 - 1e-6)  # Avoid 0 and 1
-        x_0 = torch.randn_like(x_1)
-
-        # X_t = t * X_1 + (1 - t) * X_0
-        x_t = x_0 + t[:, None, None] * (x_1 - x_0)
-
-        v_pred = self.reflow(
-            x_t,
-            1000 * t,
-            vq_recon_features.detach(),  # Stop gradients, avoid reflow to destroy the VQ
-        )
-
-        # Log L2 loss with
-        with torch.autocast(device_type=gt_mels.device.type, dtype=torch.float32):
-            weights = (
-                0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
-            )
-            assert (
-                torch.isnan(weights).any() == False
-                and torch.isinf(weights).any() == False
-            ), "Found NaN or Inf in weights."
-
-            loss_reflow = weights[:, None, None] * F.mse_loss(
-                x_1 - x_0, v_pred, reduction="none"
-            )
-            loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
-                dim=1
-            ).sum() / mel_masks_float_conv.sum()
-
         # Total loss
-        loss = (
-            self.weight_vq * loss_vq
-            + self.weight_mel * loss_mel
-            + self.weight_reflow * loss_reflow
-        )
+        loss = self.weight_vq * loss_vq + self.weight_mel * loss_mel
 
         # Log losses
         self.log(
@@ -193,14 +162,6 @@ class VQGAN(L.LightningModule):
             prog_bar=False,
             logger=True,
         )
-        self.log(
-            "train/generator/loss_reflow",
-            loss_reflow,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )
 
         return loss
 
@@ -223,7 +184,13 @@ class VQGAN(L.LightningModule):
         vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
 
         # VQ Decode
-        gen_aux_mels = self.decoder(vq_recon_features) * mel_masks_float_conv
+        gen_aux_mels = (
+            self.decoder(
+                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+                condition=vq_recon_features,
+            )
+            * mel_masks_float_conv
+        )
         loss_mel = (gen_aux_mels - gt_mels).abs().mean(
             dim=1, keepdim=True
         ).sum() / mel_masks_float_conv.sum()
@@ -238,45 +205,8 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
-        # Reflow inference
-        t_start = 0.0
-
-        x_1 = self.norm_spec(gen_aux_mels)
-        x_0 = torch.randn_like(x_1)
-        gen_reflow_mels = (1 - t_start) * x_0 + t_start * x_1
-
-        t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
-        dt = (1.0 - t_start) / self.reflow_inference_steps
-
-        for _ in range(self.reflow_inference_steps):
-            gen_reflow_mels += (
-                self.reflow(
-                    gen_reflow_mels,
-                    1000 * t,
-                    vq_recon_features,
-                )
-                * dt
-            )
-            t += dt
-
-        gen_reflow_mels = self.denorm_spec(gen_reflow_mels) * mel_masks_float_conv
-        loss_reflow_mel = (gen_reflow_mels - gt_mels).abs().mean(
-            dim=1, keepdim=True
-        ).sum() / mel_masks_float_conv.sum()
-
-        self.log(
-            "val/loss_reflow_mel",
-            loss_reflow_mel,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
         recon_audios = self.vocoder(gt_mels)
         gen_aux_audios = self.vocoder(gen_aux_mels)
-        gen_reflow_audios = self.vocoder(gen_reflow_mels)
 
         # only log the first batch
         if batch_idx != 0:
@@ -285,36 +215,33 @@ class VQGAN(L.LightningModule):
         for idx, (
             gt_mel,
             gen_aux_mel,
-            gen_reflow_mel,
             audio,
             gen_aux_audio,
-            gen_reflow_audio,
             recon_audio,
             audio_len,
         ) in enumerate(
             zip(
                 gt_mels,
                 gen_aux_mels,
-                gen_reflow_mels,
-                audios.float(),
-                gen_aux_audios.float(),
-                gen_reflow_audios.float(),
-                recon_audios.float(),
+                audios.cpu().float(),
+                gen_aux_audios.cpu().float(),
+                recon_audios.cpu().float(),
                 audio_lengths,
             )
         ):
+            if idx > 4:
+                break
+
             mel_len = audio_len // self.mel_transform.hop_length
 
             image_mels = plot_mel(
                 [
                     gt_mel[:, :mel_len],
                     gen_aux_mel[:, :mel_len],
-                    gen_reflow_mel[:, :mel_len],
                 ],
                 [
                     "Ground-Truth",
                     "Auxiliary",
-                    "Reflow",
                 ],
             )
 
@@ -333,11 +260,6 @@ class VQGAN(L.LightningModule):
                                 sample_rate=self.sampling_rate,
                                 caption="aux",
                             ),
-                            wandb.Audio(
-                                gen_reflow_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="reflow",
-                            ),
                             wandb.Audio(
                                 recon_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
@@ -365,12 +287,6 @@ class VQGAN(L.LightningModule):
                     self.global_step,
                     sample_rate=self.sampling_rate,
                 )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/reflow",
-                    gen_reflow_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
                 self.logger.experiment.add_audio(
                     f"sample-{idx}/wavs/recon",
                     recon_audio[0, :audio_len],
@@ -379,3 +295,37 @@ class VQGAN(L.LightningModule):
                 )
 
             plt.close(image_mels)
+
+    def encode(self, audios, audio_lengths):
+        audios = audios.float()
+
+        gt_mels = self.mel_transform(audios)
+        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+        gt_mels = gt_mels * mel_masks_float_conv
+
+        # Encode
+        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
+        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+        return self.quantizer.encode(encoded_features), feature_lengths
+
+    def decode(self, indices, feature_lengths, return_audios=False):
+        factor = math.prod(self.quantizer.downsample_factor)
+        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+
+        z = self.quantizer.decode(indices) * mel_masks_float_conv
+        gen_mel = (
+            self.decoder(
+                torch.randn_like(z) * mel_masks_float_conv,
+                condition=z,
+            )
+            * mel_masks_float_conv
+        )
+
+        if return_audios:
+            return self.vocoder(gen_mel)
+
+        return gen_mel

+ 11 - 7
fish_speech/models/vqgan/modules/fsq.py

@@ -1,12 +1,9 @@
 from dataclasses import dataclass
-from typing import Union
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from torch.nn.utils import weight_norm
 from vector_quantize_pytorch import GroupedResidualFSQ
 
 from .firefly import ConvNeXtBlock
@@ -106,10 +103,17 @@ class DownsampleFiniteScalarQuantize(nn.Module):
 
         return result
 
-    # def from_codes(self, codes: torch.Tensor):
-    #     z_q, z_p, codes = self.residual_fsq.get_output_from_indices(codes)
-    #     z_q = self.upsample(z_q)
-    #     return z_q, z_p, codes
+    def encode(self, z):
+        z = self.downsample(z)
+        _, indices = self.residual_fsq(z.mT)
+        indices = rearrange(indices, "g b l r -> b (g r) l")
+        return indices
+
+    def decode(self, indices: torch.Tensor):
+        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
+        z_q = self.residual_fsq.get_output_from_indices(indices)
+        z_q = self.upsample(z_q.mT)
+        return z_q
 
     # def from_latents(self, latents: torch.Tensor):
     #     z_q, z_p, codes = super().from_latents(latents)

+ 3 - 3
fish_speech/models/vqgan/modules/wavenet.py

@@ -89,7 +89,6 @@ class ResidualBlock(nn.Module):
         residual_channels,
         use_linear_bias=False,
         dilation=1,
-        has_condition=True,
         condition_channels=None,
     ):
         super(ResidualBlock, self).__init__()
@@ -102,7 +101,7 @@ class ResidualBlock(nn.Module):
             dilation=dilation,
         )
 
-        if has_condition:
+        if condition_channels is not None:
             self.diffusion_projection = LinearNorm(
                 residual_channels, residual_channels, use_linear_bias
             )
@@ -159,6 +158,8 @@ class WaveNet(nn.Module):
         if input_channels is None:
             input_channels = residual_channels
 
+        self.input_channels = input_channels
+
         # Residual layers
         self.residual_layers = nn.ModuleList(
             [
@@ -166,7 +167,6 @@ class WaveNet(nn.Module):
                     residual_channels=residual_channels,
                     use_linear_bias=False,
                     dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
-                    has_condition=is_diffusion,
                     condition_channels=condition_channels,
                 )
                 for i in range(residual_layers)

+ 2 - 5
tools/vqgan/extract_vq.py

@@ -11,14 +11,12 @@ import click
 import numpy as np
 import torch
 import torchaudio
-from einops import rearrange
 from hydra import compose, initialize
 from hydra.utils import instantiate
 from lightning import LightningModule
 from loguru import logger
 from omegaconf import OmegaConf
 
-from fish_speech.models.vqgan.utils import sequence_mask
 from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 # register eval resolver
@@ -57,7 +55,7 @@ def get_model(
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
 
-    model.load_state_dict(state_dict, strict=True)
+    model.load_state_dict(state_dict, strict=False)
     model.eval()
     model.cuda()
 
@@ -90,8 +88,7 @@ def process_batch(files: list[Path], model) -> float:
 
     # Calculate lengths
     with torch.no_grad():
-        out = model.encode(audios, audio_lengths)
-        indices, feature_lengths = out.indices, out.feature_lengths
+        indices, feature_lengths = model.encode(audios, audio_lengths)
 
     # Save to disk
     outputs = indices.cpu().numpy()

+ 13 - 7
tools/vqgan/inference.py

@@ -26,14 +26,18 @@ OmegaConf.register_new_resolver("eval", eval)
 @click.option(
     "--input-path",
     "-i",
-    default="data/Genshin/Chinese/派蒙/vo_WYLQ103_10_paimon_04.wav",
+    default="data/sft/Rail_ZH/三月七/1fe0cc6fc3fe3e6d.wav",
     type=click.Path(exists=True, path_type=Path),
 )
 @click.option(
     "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
 )
 @click.option("--config-name", "-cfg", default="vqgan_pretrain")
-@click.option("--checkpoint-path", "-ckpt", default="checkpoints/vqgan-v1.pth")
+@click.option(
+    "--checkpoint-path",
+    "-ckpt",
+    default="results/vq-group-fsq-8x1024-wn-20x512-cond-e009/checkpoints/step_000355000.ckpt",
+)
 def main(input_path, output_path, config_name, checkpoint_path):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
         cfg = compose(config_name=config_name)
@@ -45,7 +49,7 @@ def main(input_path, output_path, config_name, checkpoint_path):
     )
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
-    model.load_state_dict(state_dict, strict=True)
+    model.load_state_dict(state_dict, strict=False)
     model.eval()
     model.cuda()
     logger.info("Restored model from checkpoint")
@@ -67,8 +71,7 @@ def main(input_path, output_path, config_name, checkpoint_path):
         audio_lengths = torch.tensor(
             [audios.shape[2]], device=model.device, dtype=torch.long
         )
-        encoded = model.encode(audios, audio_lengths)
-        indices = encoded.indices[0]
+        indices = model.encode(audios, audio_lengths)[0][0]
 
         logger.info(f"Generated indices of shape {indices.shape}")
 
@@ -82,12 +85,15 @@ def main(input_path, output_path, config_name, checkpoint_path):
     else:
         raise ValueError(f"Unknown input type: {input_path}")
 
+    # random destroy 10% of indices
+    # mask = torch.rand_like(indices, dtype=torch.float) > 0.9
+    # indices[mask] = torch.randint(0, 1000, mask.shape, device=indices.device, dtype=indices.dtype)[mask]
+
     # Restore
     feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
-    decoded = model.decode(
+    fake_audios = model.decode(
         indices=indices[None], feature_lengths=feature_lengths, return_audios=True
     )
-    fake_audios = decoded.audios
     audio_time = fake_audios.shape[-1] / model.sampling_rate
 
     logger.info(