Lengyue 2 år sedan
förälder
incheckning
c232a45011
1 ändrade filer med 9 tillägg och 0 borttagningar
  1. 9 0
      fish_speech/models/vqgan/modules/models.py

+ 9 - 0
fish_speech/models/vqgan/modules/models.py

@@ -106,6 +106,15 @@ class SynthesizerTrn(nn.Module):
 
     def forward(self, x, x_lengths, specs):
         x = x.mT
+
+        min_length = min(x.shape[2], specs.shape[2])
+        if min_length % 2 != 0:
+            min_length -= 1
+
+        x = x[:, :, :min_length]
+        specs = specs[:, :, :min_length]
+        x_lengths = torch.clamp(x_lengths, max=min_length)
+
         spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
             specs.dtype
         )