|
@@ -9,7 +9,7 @@ from fish_speech.models.vqgan.modules.encoders import (
|
|
|
VQEncoder,
|
|
VQEncoder,
|
|
|
)
|
|
)
|
|
|
from fish_speech.models.vqgan.modules.flow import ResidualCouplingBlock
|
|
from fish_speech.models.vqgan.modules.flow import ResidualCouplingBlock
|
|
|
-from fish_speech.models.vqgan.utils import rand_slice_segments
|
|
|
|
|
|
|
+from fish_speech.models.vqgan.utils import rand_slice_segments, sequence_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
class SynthesizerTrn(nn.Module):
|
|
class SynthesizerTrn(nn.Module):
|
|
@@ -105,12 +105,18 @@ class SynthesizerTrn(nn.Module):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
def forward(self, x, x_lengths, specs):
|
|
def forward(self, x, x_lengths, specs):
|
|
|
- g = self.enc_spk(specs, x_lengths)
|
|
|
|
|
- x, vq_loss = self.vq(x)
|
|
|
|
|
|
|
+ x = x.mT
|
|
|
|
|
+ spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
|
|
|
|
|
+ specs.dtype
|
|
|
|
|
+ )
|
|
|
|
|
+ x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
|
|
|
|
|
+
|
|
|
|
|
+ g = self.enc_spk(specs, spec_masks)
|
|
|
|
|
+ x, vq_loss = self.vq(x, x_masks)
|
|
|
|
|
|
|
|
- _, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
|
|
|
|
|
- z_q, m_q, logs_q, y_mask = self.enc_q(specs, x_lengths, g=g)
|
|
|
|
|
- z_p = self.flow(z_q, y_mask, g=g, reverse=False)
|
|
|
|
|
|
|
+ _, m_p, logs_p, _, _ = self.enc_p(x, x_masks, g=g)
|
|
|
|
|
+ z_q, m_q, logs_q, _ = self.enc_q(specs, spec_masks, g=g)
|
|
|
|
|
+ z_p = self.flow(z_q, spec_masks, g=g, reverse=False)
|
|
|
|
|
|
|
|
z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
|
|
z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
|
|
|
o = self.dec(z_slice, g=g)
|
|
o = self.dec(z_slice, g=g)
|
|
@@ -118,8 +124,8 @@ class SynthesizerTrn(nn.Module):
|
|
|
return (
|
|
return (
|
|
|
o,
|
|
o,
|
|
|
ids_slice,
|
|
ids_slice,
|
|
|
- x_mask,
|
|
|
|
|
- y_mask,
|
|
|
|
|
|
|
+ x_masks,
|
|
|
|
|
+ spec_masks,
|
|
|
(z_q, z_p),
|
|
(z_q, z_p),
|
|
|
(m_p, logs_p),
|
|
(m_p, logs_p),
|
|
|
(m_q, logs_q),
|
|
(m_q, logs_q),
|
|
@@ -127,21 +133,29 @@ class SynthesizerTrn(nn.Module):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
|
|
def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
|
|
|
- g = self.enc_spk(specs, x_lengths)
|
|
|
|
|
- x, vq_loss = self.vq(x)
|
|
|
|
|
- z_p, m_p, logs_p, h_text, x_mask = self.enc_p(
|
|
|
|
|
- x, x_lengths, g=g, noise_scale=noise_scale
|
|
|
|
|
|
|
+ x = x.mT
|
|
|
|
|
+ spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
|
|
|
|
|
+ specs.dtype
|
|
|
)
|
|
)
|
|
|
- z_p = self.flow(z_p, x_mask, g=g, reverse=True)
|
|
|
|
|
|
|
+ x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
|
|
|
|
|
+ g = self.enc_spk(specs, spec_masks)
|
|
|
|
|
+ x, vq_loss = self.vq(x, x_masks)
|
|
|
|
|
+ z_p, m_p, logs_p, h_text, _ = self.enc_p(
|
|
|
|
|
+ x, x_masks, g=g, noise_scale=noise_scale
|
|
|
|
|
+ )
|
|
|
|
|
+ z_p = self.flow(z_p, x_masks, g=g, reverse=True)
|
|
|
|
|
|
|
|
- o = self.dec((z_p * x_mask)[:, :, :max_len], g=g)
|
|
|
|
|
|
|
+ o = self.dec((z_p * x_masks)[:, :, :max_len], g=g)
|
|
|
return o
|
|
return o
|
|
|
|
|
|
|
|
- def reconstruct(self, x, x_lengths, max_len=None, noise_scale=0.35):
|
|
|
|
|
- g = self.enc_spk(x, x_lengths)
|
|
|
|
|
- z_q, m_q, logs_q, x_mask = self.enc_q(
|
|
|
|
|
- x, x_lengths, g=g, noise_scale=noise_scale
|
|
|
|
|
|
|
+ def reconstruct(self, specs, spec_lengths, max_len=None, noise_scale=0.35):
|
|
|
|
|
+ spec_masks = torch.unsqueeze(sequence_mask(spec_lengths, specs.shape[2]), 1).to(
|
|
|
|
|
+ specs.dtype
|
|
|
|
|
+ )
|
|
|
|
|
+ g = self.enc_spk(specs, spec_masks)
|
|
|
|
|
+ z_q, m_q, logs_q, _ = self.enc_q(
|
|
|
|
|
+ specs, spec_masks, g=g, noise_scale=noise_scale
|
|
|
)
|
|
)
|
|
|
- o = self.dec((z_q * x_mask)[:, :, :max_len], g=g)
|
|
|
|
|
|
|
+ o = self.dec((z_q * spec_masks)[:, :, :max_len], g=g)
|
|
|
|
|
|
|
|
return o
|
|
return o
|