Преглед изворни кода

Update vq training recipe

Lengyue пре 2 године
родитељ
комит
3e6707b7c3

+ 2 - 2
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -84,10 +84,10 @@ model:
   
   denoiser:
     _target_: fish_speech.models.vq_diffusion.convnext_1d.ConvNext1DModel
-    in_channels: 128
+    in_channels: 256
     out_channels: 128
     intermediate_dim: 512
-    condition_dim: 128
+    # condition_dim: 128
     mlp_dim: 2048
     num_layers: 20
     dilation_cycle_length: 2

+ 5 - 3
fish_speech/models/vq_diffusion/convnext_1d.py

@@ -129,7 +129,6 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
         num_layers: int = 20,
         dilation_cycle_length: int = 4,
         time_embedding_type: str = "positional",
-        condition_dim: Optional[int] = None,
     ):
         super().__init__()
 
@@ -157,7 +156,7 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
             timestep_input_dim,
             intermediate_dim,
             act_fn="silu",
-            cond_proj_dim=condition_dim,
+            cond_proj_dim=None,  # No conditional projection for now
         )
 
         # Project to intermediate dim
@@ -219,9 +218,12 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
 
         # 1. time
         t_emb = self.time_proj(timestep)
-        t_emb = self.time_mlp(sample=t_emb[:, None], condition=condition.mT).mT
+        t_emb = self.time_mlp(t_emb)[..., None]
 
         # 2. pre-process
+        if condition is not None:
+            sample = torch.cat([sample, condition], dim=1)
+
         x = self.in_proj(sample)
 
         if sample_mask.ndim == 2:

+ 2 - 1
fish_speech/models/vq_diffusion/lit_module.py

@@ -10,6 +10,7 @@ from diffusers.utils.torch_utils import randn_tensor
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
+from tqdm import tqdm
 
 from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vqgan.modules.encoders import (
@@ -183,7 +184,7 @@ class VQDiffusion(L.LightningModule):
         sampled_mels = torch.randn_like(gt_mels)
         self.noise_scheduler_infer.set_timesteps(100)
 
-        for t in self.noise_scheduler_infer.timesteps:
+        for t in tqdm(self.noise_scheduler_infer.timesteps):
             timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)
 
             # 1. predict noise model_output