Lengyue 2 лет назад
Родитель
Commit
7ad4a967f2

+ 2 - 4
fish_speech/configs/hubert_vq.yaml

@@ -8,9 +8,7 @@ project: hubert_vq
 trainer:
 trainer:
   accelerator: gpu
   accelerator: gpu
   devices: 4
   devices: 4
-  strategy:
-    _target_: lightning.pytorch.strategies.DDPStrategy
-    static_graph: true
+  strategy: ddp_find_unused_parameters_true
   precision: 32
   precision: 32
   max_steps: 1_000_000
   max_steps: 1_000_000
   val_check_interval: 5000
   val_check_interval: 5000
@@ -45,7 +43,7 @@ data:
 
 
 # Model Configuration
 # Model Configuration
 model:
 model:
-  _target_: fish_speech.models.vq_diffusion.VQGAN
+  _target_: fish_speech.models.vqgan.VQGAN
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   hop_length: ${hop_length}
   segment_size: 20480
   segment_size: 20480

+ 1 - 3
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -8,9 +8,7 @@ project: hubert_vq_diffusion
 trainer:
 trainer:
   accelerator: gpu
   accelerator: gpu
   devices: 4
   devices: 4
-  strategy:
-    _target_: lightning.pytorch.strategies.DDPStrategy
-    static_graph: true
+  strategy: ddp_find_unused_parameters_true
   gradient_clip_val: 1.0
   gradient_clip_val: 1.0
   gradient_clip_algorithm: 'norm'
   gradient_clip_algorithm: 'norm'
   precision: 16-mixed
   precision: 16-mixed

+ 1 - 1
fish_speech/models/vq_diffusion/lit_module.py

@@ -130,7 +130,7 @@ class VQDiffusion(L.LightningModule):
         model_output = self.denoiser(noisy_images, timesteps, mel_masks, text_features)
         model_output = self.denoiser(noisy_images, timesteps, mel_masks, text_features)
 
 
        # L1 loss over the masked (valid) mel frames
        # L1 loss over the masked (valid) mel frames
-        noise_loss = ((model_output * mel_masks - noise * mel_masks) ** 2).sum() / (
+        noise_loss = (torch.abs(model_output * mel_masks - noise * mel_masks)).sum() / (
             mel_masks.sum() * gt_mels.shape[1]
             mel_masks.sum() * gt_mels.shape[1]
         )
         )
 
 

+ 1 - 2
fish_speech/models/vqgan/modules/encoders.py

@@ -129,7 +129,7 @@ class PosteriorEncoder(nn.Module):
     def forward(
     def forward(
         self,
         self,
         x: torch.Tensor,
         x: torch.Tensor,
-        x_lengths: torch.Tensor,
+        x_mask: torch.Tensor,
         g: torch.Tensor,
         g: torch.Tensor,
         noise_scale: float = 1,
         noise_scale: float = 1,
     ):
     ):
@@ -139,7 +139,6 @@ class PosteriorEncoder(nn.Module):
            - x_mask: :math:`[B, 1, T]`
            - x_mask: :math:`[B, 1, T]`
             - g: :math:`[B, C, 1]`
             - g: :math:`[B, C, 1]`
         """
         """
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
         x = self.pre(x) * x_mask
         x = self.pre(x) * x_mask
         x = self.enc(x, x_mask, g=g)
         x = self.enc(x, x_mask, g=g)
         stats = self.proj(x) * x_mask
         stats = self.proj(x) * x_mask

+ 33 - 19
fish_speech/models/vqgan/modules/models.py

@@ -9,7 +9,7 @@ from fish_speech.models.vqgan.modules.encoders import (
     VQEncoder,
     VQEncoder,
 )
 )
 from fish_speech.models.vqgan.modules.flow import ResidualCouplingBlock
 from fish_speech.models.vqgan.modules.flow import ResidualCouplingBlock
-from fish_speech.models.vqgan.utils import rand_slice_segments
+from fish_speech.models.vqgan.utils import rand_slice_segments, sequence_mask
 
 
 
 
 class SynthesizerTrn(nn.Module):
 class SynthesizerTrn(nn.Module):
@@ -105,12 +105,18 @@ class SynthesizerTrn(nn.Module):
         )
         )
 
 
     def forward(self, x, x_lengths, specs):
     def forward(self, x, x_lengths, specs):
-        g = self.enc_spk(specs, x_lengths)
-        x, vq_loss = self.vq(x)
+        x = x.mT
+        spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
+            specs.dtype
+        )
+        x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
+
+        g = self.enc_spk(specs, spec_masks)
+        x, vq_loss = self.vq(x, x_masks)
 
 
-        _, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
-        z_q, m_q, logs_q, y_mask = self.enc_q(specs, x_lengths, g=g)
-        z_p = self.flow(z_q, y_mask, g=g, reverse=False)
+        _, m_p, logs_p, _, _ = self.enc_p(x, x_masks, g=g)
+        z_q, m_q, logs_q, _ = self.enc_q(specs, spec_masks, g=g)
+        z_p = self.flow(z_q, spec_masks, g=g, reverse=False)
 
 
         z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
         z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
         o = self.dec(z_slice, g=g)
         o = self.dec(z_slice, g=g)
@@ -118,8 +124,8 @@ class SynthesizerTrn(nn.Module):
         return (
         return (
             o,
             o,
             ids_slice,
             ids_slice,
-            x_mask,
-            y_mask,
+            x_masks,
+            spec_masks,
             (z_q, z_p),
             (z_q, z_p),
             (m_p, logs_p),
             (m_p, logs_p),
             (m_q, logs_q),
             (m_q, logs_q),
@@ -127,21 +133,29 @@ class SynthesizerTrn(nn.Module):
         )
         )
 
 
     def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
     def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
-        g = self.enc_spk(specs, x_lengths)
-        x, vq_loss = self.vq(x)
-        z_p, m_p, logs_p, h_text, x_mask = self.enc_p(
-            x, x_lengths, g=g, noise_scale=noise_scale
+        x = x.mT
+        spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
+            specs.dtype
         )
         )
-        z_p = self.flow(z_p, x_mask, g=g, reverse=True)
+        x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
+        g = self.enc_spk(specs, spec_masks)
+        x, vq_loss = self.vq(x, x_masks)
+        z_p, m_p, logs_p, h_text, _ = self.enc_p(
+            x, x_masks, g=g, noise_scale=noise_scale
+        )
+        z_p = self.flow(z_p, x_masks, g=g, reverse=True)
 
 
-        o = self.dec((z_p * x_mask)[:, :, :max_len], g=g)
+        o = self.dec((z_p * x_masks)[:, :, :max_len], g=g)
         return o
         return o
 
 
-    def reconstruct(self, x, x_lengths, max_len=None, noise_scale=0.35):
-        g = self.enc_spk(x, x_lengths)
-        z_q, m_q, logs_q, x_mask = self.enc_q(
-            x, x_lengths, g=g, noise_scale=noise_scale
+    def reconstruct(self, specs, spec_lengths, max_len=None, noise_scale=0.35):
+        spec_masks = torch.unsqueeze(sequence_mask(spec_lengths, specs.shape[2]), 1).to(
+            specs.dtype
+        )
+        g = self.enc_spk(specs, spec_masks)
+        z_q, m_q, logs_q, _ = self.enc_q(
+            specs, spec_masks, g=g, noise_scale=noise_scale
         )
         )
-        o = self.dec((z_q * x_mask)[:, :, :max_len], g=g)
+        o = self.dec((z_q * spec_masks)[:, :, :max_len], g=g)
 
 
         return o
         return o

+ 1 - 1
fish_speech/models/vqgan/utils.py

@@ -67,7 +67,7 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
     if x_lengths is None:
     if x_lengths is None:
         x_lengths = t
         x_lengths = t
     ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
     ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
-    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
     ret = slice_segments(x, ids_str, segment_size)
     ret = slice_segments(x, ids_str, segment_size)
     return ret, ids_str
     return ret, ids_str