瀏覽代碼

update vq baseline

Lengyue 2 年之前
父節點
當前提交
223d0b8f81

+ 12 - 22
fish_speech/configs/vqgan_pretrain.yaml

@@ -2,7 +2,7 @@ defaults:
   - base
   - base
   - _self_
   - _self_
 
 
-project: vq_reflow_shallow_group_fsq_8x1024_wavenet
+project: vq-group-fsq-8x1024-wn-20x768-cond
 
 
 # Lightning Trainer
 # Lightning Trainer
 trainer:
 trainer:
@@ -10,8 +10,8 @@ trainer:
   devices: 1
   devices: 1
   precision: bf16-mixed
   precision: bf16-mixed
   max_steps: 1_000_000
   max_steps: 1_000_000
-  val_check_interval: 1000
-  strategy: ddp
+  val_check_interval: 5000
+  strategy: ddp_find_unused_parameters_true
 
 
 sample_rate: 44100
 sample_rate: 44100
 hop_length: 512
 hop_length: 512
@@ -38,8 +38,8 @@ data:
   train_dataset: ${train_dataset}
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
   num_workers: 4
-  batch_size: 128
-  val_batch_size: 4
+  batch_size: 64
+  val_batch_size: 64
 
 
 # Model Configuration
 # Model Configuration
 model:
 model:
@@ -58,33 +58,24 @@ model:
   encoder:
   encoder:
     _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
     _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
     input_channels: ${num_mels}
     input_channels: ${num_mels}
-    residual_channels: 512
+    residual_channels: 768
     residual_layers: 20
     residual_layers: 20
     dilation_cycle: 4
     dilation_cycle: 4
   
   
   quantizer:
   quantizer:
     _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
     _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
-    input_dim: 512
-    n_codebooks: 8
-    n_groups: 1
+    input_dim: 768
+    n_codebooks: 1
+    n_groups: 8
     levels: [8, 5, 5, 5]
     levels: [8, 5, 5, 5]
-  
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    output_channels: ${num_mels}
-    residual_channels: 512
-    residual_layers: 20
-    dilation_cycle: 4
 
 
-  reflow:
+  decoder:
     _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
     _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    input_channels: ${num_mels}
     output_channels: ${num_mels}
     output_channels: ${num_mels}
-    residual_channels: 512
-    condition_channels: 512
+    residual_channels: 768
     residual_layers: 20
     residual_layers: 20
     dilation_cycle: 4
     dilation_cycle: 4
-    is_diffusion: true
+    condition_channels: 768
 
 
   vocoder:
   vocoder:
     _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
     _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
@@ -129,4 +120,3 @@ callbacks:
       - encoder
       - encoder
       - decoder
       - decoder
       - quantizer
       - quantizer
-      - reflow

+ 1 - 1
fish_speech/datasets/vqgan.py

@@ -122,7 +122,7 @@ class VQGANDataModule(LightningDataModule):
     def val_dataloader(self):
     def val_dataloader(self):
         return DataLoader(
         return DataLoader(
             self.val_dataset,
             self.val_dataset,
-            batch_size=self.batch_size,
+            batch_size=self.val_batch_size,
             collate_fn=VQGANCollator(),
             collate_fn=VQGANCollator(),
             num_workers=self.num_workers,
             num_workers=self.num_workers,
         )
         )

+ 59 - 109
fish_speech/models/vqgan/lit_module.py

@@ -1,6 +1,5 @@
-import itertools
-from dataclasses import dataclass
-from typing import Any, Callable, Literal, Optional
+import math
+from typing import Any, Callable
 
 
 import lightning as L
 import lightning as L
 import torch
 import torch
@@ -22,7 +21,7 @@ class VQGAN(L.LightningModule):
         encoder: WaveNet,
         encoder: WaveNet,
         quantizer: nn.Module,
         quantizer: nn.Module,
         decoder: WaveNet,
         decoder: WaveNet,
-        reflow: nn.Module,
+        # reflow: nn.Module,
         vocoder: nn.Module,
         vocoder: nn.Module,
         mel_transform: nn.Module,
         mel_transform: nn.Module,
         weight_reflow: float = 1.0,
         weight_reflow: float = 1.0,
@@ -44,7 +43,7 @@ class VQGAN(L.LightningModule):
         self.quantizer = quantizer
         self.quantizer = quantizer
         self.decoder = decoder
         self.decoder = decoder
         self.vocoder = vocoder
         self.vocoder = vocoder
-        self.reflow = reflow
+        # self.reflow = reflow
         self.mel_transform = mel_transform
         self.mel_transform = mel_transform
 
 
         # Freeze vocoder
         # Freeze vocoder
@@ -122,51 +121,21 @@ class VQGAN(L.LightningModule):
         vq_recon_features = vq_result.z * mel_masks_float_conv
         vq_recon_features = vq_result.z * mel_masks_float_conv
 
 
         # VQ Decode
         # VQ Decode
-        gen_mel = self.decoder(vq_recon_features) * mel_masks_float_conv
+        gen_mel = (
+            self.decoder(
+                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+                condition=vq_recon_features,
+            )
+            * mel_masks_float_conv
+        )
 
 
         # Mel Loss
         # Mel Loss
         loss_mel = (gen_mel - gt_mels).abs().mean(
         loss_mel = (gen_mel - gt_mels).abs().mean(
             dim=1, keepdim=True
             dim=1, keepdim=True
         ).sum() / mel_masks_float_conv.sum()
         ).sum() / mel_masks_float_conv.sum()
 
 
-        # Reflow, given x_1_aux, we want to reconstruct x_1
-        x_1 = self.norm_spec(gt_mels)
-        t = torch.rand(gt_mels.shape[0], device=gt_mels.device, dtype=torch.float32)
-        t = torch.clamp(t, 1e-6, 1 - 1e-6)  # Avoid 0 and 1
-        x_0 = torch.randn_like(x_1)
-
-        # X_t = t * X_1 + (1 - t) * X_0
-        x_t = x_0 + t[:, None, None] * (x_1 - x_0)
-
-        v_pred = self.reflow(
-            x_t,
-            1000 * t,
-            vq_recon_features.detach(),  # Stop gradients, avoid reflow to destroy the VQ
-        )
-
-        # Log L2 loss with
-        with torch.autocast(device_type=gt_mels.device.type, dtype=torch.float32):
-            weights = (
-                0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
-            )
-            assert (
-                torch.isnan(weights).any() == False
-                and torch.isinf(weights).any() == False
-            ), "Found NaN or Inf in weights."
-
-            loss_reflow = weights[:, None, None] * F.mse_loss(
-                x_1 - x_0, v_pred, reduction="none"
-            )
-            loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
-                dim=1
-            ).sum() / mel_masks_float_conv.sum()
-
         # Total loss
         # Total loss
-        loss = (
-            self.weight_vq * loss_vq
-            + self.weight_mel * loss_mel
-            + self.weight_reflow * loss_reflow
-        )
+        loss = self.weight_vq * loss_vq + self.weight_mel * loss_mel
 
 
         # Log losses
         # Log losses
         self.log(
         self.log(
@@ -193,14 +162,6 @@ class VQGAN(L.LightningModule):
             prog_bar=False,
             prog_bar=False,
             logger=True,
             logger=True,
         )
         )
-        self.log(
-            "train/generator/loss_reflow",
-            loss_reflow,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )
 
 
         return loss
         return loss
 
 
@@ -223,7 +184,13 @@ class VQGAN(L.LightningModule):
         vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
         vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
 
 
         # VQ Decode
         # VQ Decode
-        gen_aux_mels = self.decoder(vq_recon_features) * mel_masks_float_conv
+        gen_aux_mels = (
+            self.decoder(
+                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+                condition=vq_recon_features,
+            )
+            * mel_masks_float_conv
+        )
         loss_mel = (gen_aux_mels - gt_mels).abs().mean(
         loss_mel = (gen_aux_mels - gt_mels).abs().mean(
             dim=1, keepdim=True
             dim=1, keepdim=True
         ).sum() / mel_masks_float_conv.sum()
         ).sum() / mel_masks_float_conv.sum()
@@ -238,45 +205,8 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
             sync_dist=True,
         )
         )
 
 
-        # Reflow inference
-        t_start = 0.0
-
-        x_1 = self.norm_spec(gen_aux_mels)
-        x_0 = torch.randn_like(x_1)
-        gen_reflow_mels = (1 - t_start) * x_0 + t_start * x_1
-
-        t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
-        dt = (1.0 - t_start) / self.reflow_inference_steps
-
-        for _ in range(self.reflow_inference_steps):
-            gen_reflow_mels += (
-                self.reflow(
-                    gen_reflow_mels,
-                    1000 * t,
-                    vq_recon_features,
-                )
-                * dt
-            )
-            t += dt
-
-        gen_reflow_mels = self.denorm_spec(gen_reflow_mels) * mel_masks_float_conv
-        loss_reflow_mel = (gen_reflow_mels - gt_mels).abs().mean(
-            dim=1, keepdim=True
-        ).sum() / mel_masks_float_conv.sum()
-
-        self.log(
-            "val/loss_reflow_mel",
-            loss_reflow_mel,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
         recon_audios = self.vocoder(gt_mels)
         recon_audios = self.vocoder(gt_mels)
         gen_aux_audios = self.vocoder(gen_aux_mels)
         gen_aux_audios = self.vocoder(gen_aux_mels)
-        gen_reflow_audios = self.vocoder(gen_reflow_mels)
 
 
         # only log the first batch
         # only log the first batch
         if batch_idx != 0:
         if batch_idx != 0:
@@ -285,36 +215,33 @@ class VQGAN(L.LightningModule):
         for idx, (
         for idx, (
             gt_mel,
             gt_mel,
             gen_aux_mel,
             gen_aux_mel,
-            gen_reflow_mel,
             audio,
             audio,
             gen_aux_audio,
             gen_aux_audio,
-            gen_reflow_audio,
             recon_audio,
             recon_audio,
             audio_len,
             audio_len,
         ) in enumerate(
         ) in enumerate(
             zip(
             zip(
                 gt_mels,
                 gt_mels,
                 gen_aux_mels,
                 gen_aux_mels,
-                gen_reflow_mels,
-                audios.float(),
-                gen_aux_audios.float(),
-                gen_reflow_audios.float(),
-                recon_audios.float(),
+                audios.cpu().float(),
+                gen_aux_audios.cpu().float(),
+                recon_audios.cpu().float(),
                 audio_lengths,
                 audio_lengths,
             )
             )
         ):
         ):
+            if idx > 4:
+                break
+
             mel_len = audio_len // self.mel_transform.hop_length
             mel_len = audio_len // self.mel_transform.hop_length
 
 
             image_mels = plot_mel(
             image_mels = plot_mel(
                 [
                 [
                     gt_mel[:, :mel_len],
                     gt_mel[:, :mel_len],
                     gen_aux_mel[:, :mel_len],
                     gen_aux_mel[:, :mel_len],
-                    gen_reflow_mel[:, :mel_len],
                 ],
                 ],
                 [
                 [
                     "Ground-Truth",
                     "Ground-Truth",
                     "Auxiliary",
                     "Auxiliary",
-                    "Reflow",
                 ],
                 ],
             )
             )
 
 
@@ -333,11 +260,6 @@ class VQGAN(L.LightningModule):
                                 sample_rate=self.sampling_rate,
                                 sample_rate=self.sampling_rate,
                                 caption="aux",
                                 caption="aux",
                             ),
                             ),
-                            wandb.Audio(
-                                gen_reflow_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="reflow",
-                            ),
                             wandb.Audio(
                             wandb.Audio(
                                 recon_audio[0, :audio_len],
                                 recon_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
                                 sample_rate=self.sampling_rate,
@@ -365,12 +287,6 @@ class VQGAN(L.LightningModule):
                     self.global_step,
                     self.global_step,
                     sample_rate=self.sampling_rate,
                     sample_rate=self.sampling_rate,
                 )
                 )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/reflow",
-                    gen_reflow_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
                 self.logger.experiment.add_audio(
                 self.logger.experiment.add_audio(
                     f"sample-{idx}/wavs/recon",
                     f"sample-{idx}/wavs/recon",
                     recon_audio[0, :audio_len],
                     recon_audio[0, :audio_len],
@@ -379,3 +295,37 @@ class VQGAN(L.LightningModule):
                 )
                 )
 
 
             plt.close(image_mels)
             plt.close(image_mels)
+
+    def encode(self, audios, audio_lengths):
+        audios = audios.float()
+
+        gt_mels = self.mel_transform(audios)
+        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+        gt_mels = gt_mels * mel_masks_float_conv
+
+        # Encode
+        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
+        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+        return self.quantizer.encode(encoded_features), feature_lengths
+
+    def decode(self, indices, feature_lengths, return_audios=False):
+        factor = math.prod(self.quantizer.downsample_factor)
+        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+
+        z = self.quantizer.decode(indices) * mel_masks_float_conv
+        gen_mel = (
+            self.decoder(
+                torch.randn_like(z) * mel_masks_float_conv,
+                condition=z,
+            )
+            * mel_masks_float_conv
+        )
+
+        if return_audios:
+            return self.vocoder(gen_mel)
+
+        return gen_mel

+ 11 - 7
fish_speech/models/vqgan/modules/fsq.py

@@ -1,12 +1,9 @@
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import Union
 
 
-import numpy as np
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.functional as F
 from einops import rearrange
 from einops import rearrange
-from torch.nn.utils import weight_norm
 from vector_quantize_pytorch import GroupedResidualFSQ
 from vector_quantize_pytorch import GroupedResidualFSQ
 
 
 from .firefly import ConvNeXtBlock
 from .firefly import ConvNeXtBlock
@@ -106,10 +103,17 @@ class DownsampleFiniteScalarQuantize(nn.Module):
 
 
         return result
         return result
 
 
-    # def from_codes(self, codes: torch.Tensor):
-    #     z_q, z_p, codes = self.residual_fsq.get_output_from_indices(codes)
-    #     z_q = self.upsample(z_q)
-    #     return z_q, z_p, codes
+    def encode(self, z):
+        z = self.downsample(z)
+        _, indices = self.residual_fsq(z.mT)
+        indices = rearrange(indices, "g b l r -> b (g r) l")
+        return indices
+
+    def decode(self, indices: torch.Tensor):
+        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
+        z_q = self.residual_fsq.get_output_from_indices(indices)
+        z_q = self.upsample(z_q.mT)
+        return z_q
 
 
     # def from_latents(self, latents: torch.Tensor):
     # def from_latents(self, latents: torch.Tensor):
     #     z_q, z_p, codes = super().from_latents(latents)
     #     z_q, z_p, codes = super().from_latents(latents)

+ 3 - 3
fish_speech/models/vqgan/modules/wavenet.py

@@ -89,7 +89,6 @@ class ResidualBlock(nn.Module):
         residual_channels,
         residual_channels,
         use_linear_bias=False,
         use_linear_bias=False,
         dilation=1,
         dilation=1,
-        has_condition=True,
         condition_channels=None,
         condition_channels=None,
     ):
     ):
         super(ResidualBlock, self).__init__()
         super(ResidualBlock, self).__init__()
@@ -102,7 +101,7 @@ class ResidualBlock(nn.Module):
             dilation=dilation,
             dilation=dilation,
         )
         )
 
 
-        if has_condition:
+        if condition_channels is not None:
             self.diffusion_projection = LinearNorm(
             self.diffusion_projection = LinearNorm(
                 residual_channels, residual_channels, use_linear_bias
                 residual_channels, residual_channels, use_linear_bias
             )
             )
@@ -159,6 +158,8 @@ class WaveNet(nn.Module):
         if input_channels is None:
         if input_channels is None:
             input_channels = residual_channels
             input_channels = residual_channels
 
 
+        self.input_channels = input_channels
+
         # Residual layers
         # Residual layers
         self.residual_layers = nn.ModuleList(
         self.residual_layers = nn.ModuleList(
             [
             [
@@ -166,7 +167,6 @@ class WaveNet(nn.Module):
                     residual_channels=residual_channels,
                     residual_channels=residual_channels,
                     use_linear_bias=False,
                     use_linear_bias=False,
                     dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
                     dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
-                    has_condition=is_diffusion,
                     condition_channels=condition_channels,
                     condition_channels=condition_channels,
                 )
                 )
                 for i in range(residual_layers)
                 for i in range(residual_layers)

+ 2 - 5
tools/vqgan/extract_vq.py

@@ -11,14 +11,12 @@ import click
 import numpy as np
 import numpy as np
 import torch
 import torch
 import torchaudio
 import torchaudio
-from einops import rearrange
 from hydra import compose, initialize
 from hydra import compose, initialize
 from hydra.utils import instantiate
 from hydra.utils import instantiate
 from lightning import LightningModule
 from lightning import LightningModule
 from loguru import logger
 from loguru import logger
 from omegaconf import OmegaConf
 from omegaconf import OmegaConf
 
 
-from fish_speech.models.vqgan.utils import sequence_mask
 from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 
 # register eval resolver
 # register eval resolver
@@ -57,7 +55,7 @@ def get_model(
     if "state_dict" in state_dict:
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
         state_dict = state_dict["state_dict"]
 
 
-    model.load_state_dict(state_dict, strict=True)
+    model.load_state_dict(state_dict, strict=False)
     model.eval()
     model.eval()
     model.cuda()
     model.cuda()
 
 
@@ -90,8 +88,7 @@ def process_batch(files: list[Path], model) -> float:
 
 
     # Calculate lengths
     # Calculate lengths
     with torch.no_grad():
     with torch.no_grad():
-        out = model.encode(audios, audio_lengths)
-        indices, feature_lengths = out.indices, out.feature_lengths
+        indices, feature_lengths = model.encode(audios, audio_lengths)
 
 
     # Save to disk
     # Save to disk
     outputs = indices.cpu().numpy()
     outputs = indices.cpu().numpy()

+ 13 - 7
tools/vqgan/inference.py

@@ -26,14 +26,18 @@ OmegaConf.register_new_resolver("eval", eval)
 @click.option(
 @click.option(
     "--input-path",
     "--input-path",
     "-i",
     "-i",
-    default="data/Genshin/Chinese/派蒙/vo_WYLQ103_10_paimon_04.wav",
+    default="data/sft/Rail_ZH/三月七/1fe0cc6fc3fe3e6d.wav",
     type=click.Path(exists=True, path_type=Path),
     type=click.Path(exists=True, path_type=Path),
 )
 )
 @click.option(
 @click.option(
     "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
     "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
 )
 )
 @click.option("--config-name", "-cfg", default="vqgan_pretrain")
 @click.option("--config-name", "-cfg", default="vqgan_pretrain")
-@click.option("--checkpoint-path", "-ckpt", default="checkpoints/vqgan-v1.pth")
+@click.option(
+    "--checkpoint-path",
+    "-ckpt",
+    default="results/vq-group-fsq-8x1024-wn-20x512-cond-e009/checkpoints/step_000355000.ckpt",
+)
 def main(input_path, output_path, config_name, checkpoint_path):
 def main(input_path, output_path, config_name, checkpoint_path):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
         cfg = compose(config_name=config_name)
         cfg = compose(config_name=config_name)
@@ -45,7 +49,7 @@ def main(input_path, output_path, config_name, checkpoint_path):
     )
     )
     if "state_dict" in state_dict:
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
         state_dict = state_dict["state_dict"]
-    model.load_state_dict(state_dict, strict=True)
+    model.load_state_dict(state_dict, strict=False)
     model.eval()
     model.eval()
     model.cuda()
     model.cuda()
     logger.info("Restored model from checkpoint")
     logger.info("Restored model from checkpoint")
@@ -67,8 +71,7 @@ def main(input_path, output_path, config_name, checkpoint_path):
         audio_lengths = torch.tensor(
         audio_lengths = torch.tensor(
             [audios.shape[2]], device=model.device, dtype=torch.long
             [audios.shape[2]], device=model.device, dtype=torch.long
         )
         )
-        encoded = model.encode(audios, audio_lengths)
-        indices = encoded.indices[0]
+        indices = model.encode(audios, audio_lengths)[0][0]
 
 
         logger.info(f"Generated indices of shape {indices.shape}")
         logger.info(f"Generated indices of shape {indices.shape}")
 
 
@@ -82,12 +85,15 @@ def main(input_path, output_path, config_name, checkpoint_path):
     else:
     else:
         raise ValueError(f"Unknown input type: {input_path}")
         raise ValueError(f"Unknown input type: {input_path}")
 
 
+    # Randomly corrupt about 10% of the indices (disabled; kept for debugging)
+    # mask = torch.rand_like(indices, dtype=torch.float) > 0.9
+    # indices[mask] = torch.randint(0, 1000, mask.shape, device=indices.device, dtype=indices.dtype)[mask]
+
     # Restore
     # Restore
     feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
     feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
-    decoded = model.decode(
+    fake_audios = model.decode(
         indices=indices[None], feature_lengths=feature_lengths, return_audios=True
         indices=indices[None], feature_lengths=feature_lengths, return_audios=True
     )
     )
-    fake_audios = decoded.audios
     audio_time = fake_audios.shape[-1] / model.sampling_rate
     audio_time = fake_audios.shape[-1] / model.sampling_rate
 
 
     logger.info(
     logger.info(