Lengyue 2 years ago
Parent
Commit
c232a45011
1 changed file with 9 additions and 0 deletions
  1. fish_speech/models/vqgan/modules/models.py (+9, -0)

@@ -106,6 +106,15 @@ class SynthesizerTrn(nn.Module):
 
     def forward(self, x, x_lengths, specs):
         x = x.mT
+
+        min_length = min(x.shape[2], specs.shape[2])
+        if min_length % 2 != 0:
+            min_length -= 1
+
+        x = x[:, :, :min_length]
+        specs = specs[:, :, :min_length]
+        x_lengths = torch.clamp(x_lengths, max=min_length)
+
         spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
             specs.dtype
         )
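
Note on the added lines: the hunk trims `x` and `specs` to a shared frame count, rounds that count down to an even number, and clamps `x_lengths` so that no entry exceeds the trimmed size. The commit does not state why an even length is required. The length arithmetic can be sketched in plain Python as follows; `even_common_length` and `clamp_lengths` are hypothetical helper names used only for illustration and are not part of the repository.

```python
from typing import List


def even_common_length(feature_frames: int, spec_frames: int) -> int:
    """Largest even frame count that fits in both sequences.

    Mirrors the min(...) and parity check in the diff above.
    """
    min_length = min(feature_frames, spec_frames)
    if min_length % 2 != 0:
        # Drop one frame so the shared length is even.
        min_length -= 1
    return min_length


def clamp_lengths(lengths: List[int], max_length: int) -> List[int]:
    """Cap each per-sample length at max_length.

    Plain-Python stand-in for torch.clamp(x_lengths, max=min_length).
    """
    return [min(n, max_length) for n in lengths]


if __name__ == "__main__":
    # Illustrative values: features have 101 frames, spectrogram has 103.
    shared = even_common_length(101, 103)
    print(shared)
    print(clamp_lengths([101, 80], shared))
```

With 101 feature frames and 103 spectrogram frames, both tensors would be cut to 100 frames, and a per-sample length of 101 would be clamped to 100 while a length of 80 stays unchanged.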