Lengyue 2 years ago
Parent
Commit
3e6707b7c3

+ 2 - 2
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -84,10 +84,10 @@ model:
 
   denoiser:
     _target_: fish_speech.models.vq_diffusion.convnext_1d.ConvNext1DModel
-    in_channels: 128
+    in_channels: 256
     out_channels: 128
     intermediate_dim: 512
-    condition_dim: 128
+    # condition_dim: 128
     mlp_dim: 2048
     num_layers: 20
     dilation_cycle_length: 2
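The two YAML edits go together: `in_channels` doubles from 128 to 256 because, as of this commit, the conditioning features (128 channels) are concatenated with the noisy sample (128 channels) before the input projection, and `condition_dim` is commented out because the denoiser constructor no longer accepts it (see the Python changes below).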

+ 5 - 3
fish_speech/models/vq_diffusion/convnext_1d.py

@@ -129,7 +129,6 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
         num_layers: int = 20,
         dilation_cycle_length: int = 4,
         time_embedding_type: str = "positional",
-        condition_dim: Optional[int] = None,
     ):
         super().__init__()
 
@@ -157,7 +156,7 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
             timestep_input_dim,
             intermediate_dim,
             act_fn="silu",
-            cond_proj_dim=condition_dim,
+            cond_proj_dim=None,  # No conditional projection for now
         )
 
         # Project to intermediate dim
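Context for the `cond_proj_dim` change: in diffusers' `TimestepEmbedding`, a non-`None` `cond_proj_dim` adds a bias-free linear layer that projects an optional condition and adds it into the timestep embedding. Passing `None` removes that pathway, consistent with dropping the `condition_dim` argument above; conditioning now happens at the model input instead (next hunk).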
@@ -219,9 +218,12 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
 
         # 1. time
        t_emb = self.time_proj(timestep)
-        t_emb = self.time_mlp(sample=t_emb[:, None], condition=condition.mT).mT
+        t_emb = self.time_mlp(t_emb)[..., None]
 
         # 2. pre-process
+        if condition is not None:
+            sample = torch.cat([sample, condition], dim=1)
+
        x = self.in_proj(sample)
 
        if sample_mask.ndim == 2:
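A minimal runnable sketch of the new conditioning path, assuming 128-channel samples and conditions as in the config; the class and layer shapes are illustrative stand-ins, not the actual `ConvNext1DModel` internals. The condition enters by channel concatenation (hence `in_channels: 256`), while the time embedding is now computed from the timestep alone and broadcast over the time axis via the trailing `[..., None]`.

    import torch
    from torch import nn

    class ConcatConditionedDenoiser(nn.Module):  # hypothetical name
        def __init__(self, in_channels: int = 256, intermediate_dim: int = 512):
            super().__init__()
            self.in_proj = nn.Conv1d(in_channels, intermediate_dim, kernel_size=7, padding=3)

        def forward(self, sample, condition=None):
            # Condition joins along the channel axis: 128 + 128 = 256 channels.
            if condition is not None:
                sample = torch.cat([sample, condition], dim=1)
            return self.in_proj(sample)

    x = torch.randn(2, 128, 100)  # noisy features [batch, channels, time]
    c = torch.randn(2, 128, 100)  # conditioning features, same time resolution
    print(ConcatConditionedDenoiser()(x, c).shape)  # torch.Size([2, 512, 100])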

+ 2 - 1
fish_speech/models/vq_diffusion/lit_module.py

@@ -10,6 +10,7 @@ from diffusers.utils.torch_utils import randn_tensor
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
+from tqdm import tqdm
 
 from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vqgan.modules.encoders import (
@@ -183,7 +184,7 @@ class VQDiffusion(L.LightningModule):
         sampled_mels = torch.randn_like(gt_mels)
         self.noise_scheduler_infer.set_timesteps(100)
 
-        for t in self.noise_scheduler_infer.timesteps:
+        for t in tqdm(self.noise_scheduler_infer.timesteps):
             timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)
 
             # 1. predict noise model_output
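The functional line is unchanged; wrapping `timesteps` in `tqdm` only adds a progress bar over the 100 inference steps. A minimal sketch of the loop shape with a diffusers-style scheduler (the zero tensor is a placeholder for the real denoiser call):

    import torch
    from tqdm import tqdm
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(100)  # mirrors noise_scheduler_infer.set_timesteps(100)

    sample = torch.randn(1, 128, 100)  # stand-in for sampled_mels
    for t in tqdm(scheduler.timesteps):  # tqdm wraps the iterator, nothing else changes
        model_output = torch.zeros_like(sample)  # placeholder for the denoiser prediction
        sample = scheduler.step(model_output, t, sample).prev_sample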