|
@@ -106,6 +106,15 @@ class SynthesizerTrn(nn.Module):
|
|
|
|
|
|
|
|
def forward(self, x, x_lengths, specs):
|
|
def forward(self, x, x_lengths, specs):
|
|
|
x = x.mT
|
|
x = x.mT
|
|
|
|
|
+
|
|
|
|
|
+ min_length = min(x.shape[2], specs.shape[2])
|
|
|
|
|
+ if min_length % 2 != 0:
|
|
|
|
|
+ min_length -= 1
|
|
|
|
|
+
|
|
|
|
|
+ x = x[:, :, :min_length]
|
|
|
|
|
+ specs = specs[:, :, :min_length]
|
|
|
|
|
+ x_lengths = torch.clamp(x_lengths, max=min_length)
|
|
|
|
|
+
|
|
|
spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
|
|
spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
|
|
|
specs.dtype
|
|
specs.dtype
|
|
|
)
|
|
)
|