@@ -77,6 +77,7 @@ class TextEncoder(nn.Module):
         ).to(y.dtype)
         text = self.text_embedding(text).transpose(1, 2)
         text = self.encoder_text(text * text_mask, text_mask)
+
         y = self.mrte(y, y_mask, text, text_mask, ge)

         y = self.encoder2(y * y_mask, y_mask)
@@ -85,25 +86,6 @@ class TextEncoder(nn.Module):
         m, logs = torch.split(stats, self.out_channels, dim=1)
         return y, m, logs, y_mask

-    def extract_latent(self, x):
-        x = self.ssl_proj(x)
-        quantized, codes, commit_loss, quantized_list = self.quantizer(x)
-        return codes.transpose(0, 1)
-
-    def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
-        quantized = self.quantizer.decode(codes)
-
-        y = self.vq_proj(quantized) * y_mask
-        y = self.encoder_ssl(y * y_mask, y_mask)
-
-        y = self.mrte(y, y_mask, refer, refer_mask, ge)
-
-        y = self.encoder2(y * y_mask, y_mask)
-
-        stats = self.proj(y) * y_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        return y, m, logs, y_mask, quantized
-

 class ResidualCouplingBlock(nn.Module):
     def __init__(
@@ -430,6 +412,8 @@ class SynthesizerTrn(nn.Module):
         upsample_kernel_sizes,
         gin_channels=0,
         codebook_size=264,
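+        # Training-time masking knobs (0.0, the default, disables each branch):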
+        vq_mask_ratio=0.0,
+        ref_mask_ratio=0.0,
     ):
         super().__init__()

@@ -449,6 +433,8 @@ class SynthesizerTrn(nn.Module):
         self.upsample_kernel_sizes = upsample_kernel_sizes
         self.segment_size = segment_size
         self.gin_channels = gin_channels
+        self.vq_mask_ratio = vq_mask_ratio
+        self.ref_mask_ratio = ref_mask_ratio

         self.enc_p = TextEncoder(
             inter_channels,
@@ -498,8 +484,34 @@ class SynthesizerTrn(nn.Module):
             commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
         ).to(gt_specs.dtype)
         ge = self.ref_enc(gt_specs * y_mask, y_mask)
+
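+        # Reference dropout: zero the reference audio for a random subset of the
+        # batch during training, presumably so the model also learns to synthesize
+        # without a usable speaker reference.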
+        if self.training and self.ref_mask_ratio > 0:
+            bs = audio.size(0)
+            mask_speaker_len = int(bs * self.ref_mask_ratio)
+            mask_indices = torch.randperm(bs)[:mask_speaker_len]
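+            # Note: the assignment below zeroes rows of the input batch in place.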
+            audio[mask_indices] = 0
+
         quantized = self.vq(audio, audio_lengths)
-        quantized = F.interpolate(quantized, size=gt_specs.size(-1), mode="nearest")
+
+        # Block masking, block_size = 4
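+        # Masking contiguous blocks of VQ frames (rather than isolated frames)
+        # presumably forces the model to infill longer spans from text and reference.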
+        block_size = 4
+        if self.training and self.vq_mask_ratio > 0:
+            reduced_length = quantized.size(-1) // block_size
+            mask_length = int(reduced_length * self.vq_mask_ratio)
+            mask_indices = torch.randperm(reduced_length)[:mask_length]
+            short_mask = torch.zeros(
+                quantized.size(0),
+                quantized.size(1),
+                reduced_length,
+                device=quantized.device,
+                dtype=torch.float,
+            )
+            short_mask[:, :, mask_indices] = 1.0
+            long_mask = short_mask.repeat_interleave(block_size, dim=-1)
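+            # repeat_interleave yields reduced_length * block_size frames, which can
+            # fall short of quantized.size(-1); nearest-neighbor interpolation below
+            # stretches the mask to the exact sequence length.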
+            long_mask = F.interpolate(
+                long_mask, size=quantized.size(-1), mode="nearest"
+            )
+            quantized = quantized.masked_fill(long_mask > 0.5, 0)

         x, m_p, logs_p, y_mask = self.enc_p(
             quantized, gt_spec_lengths, text, text_lengths, ge
@@ -621,23 +633,23 @@ if __name__ == "__main__":
     # Try to load the model
     print(f"Loading model from {ckpt}")
     checkpoint = torch.load(ckpt, map_location="cpu", weights_only=True)["model"]
-    d_checkpoint = torch.load(
-        "checkpoints/Bert-VITS2/D_0.pth", map_location="cpu", weights_only=True
-    )["model"]
-    print(checkpoint.keys())
+    # d_checkpoint = torch.load(
+    #     "checkpoints/Bert-VITS2/D_0.pth", map_location="cpu", weights_only=True
+    # )["model"]
+    # print(checkpoint.keys())

     checkpoint.pop("dec.cond.weight")
     checkpoint.pop("enc_q.enc.cond_layer.weight_v")

-    new_checkpoint = {}
-    for k, v in checkpoint.items():
-        new_checkpoint["generator." + k] = v
+    # new_checkpoint = {}
+    # for k, v in checkpoint.items():
+    #     new_checkpoint["generator." + k] = v

-    for k, v in d_checkpoint.items():
-        new_checkpoint["discriminator." + k] = v
+    # for k, v in d_checkpoint.items():
+    #     new_checkpoint["discriminator." + k] = v

-    torch.save(new_checkpoint, "checkpoints/Bert-VITS2/ensemble.pth")
-    exit()
+    # torch.save(new_checkpoint, "checkpoints/Bert-VITS2/ensemble.pth")
+    # exit()

     print(model.load_state_dict(checkpoint, strict=False))

@@ -672,6 +684,6 @@ if __name__ == "__main__":
     print(o.size(), y_mask.size(), z.size(), z_p.size(), m_p.size(), logs_p.size())

     # Save output
-    import soundfile as sf
+    # import soundfile as sf

-    sf.write("output.wav", o.squeeze().detach().numpy(), 32000)
+    # sf.write("output.wav", o.squeeze().detach().numpy(), 32000)