Lengyue 1 год назад
Родитель
Commit
c7505b77d6

+ 4 - 1
fish_speech/configs/vits_decoder.yaml

@@ -34,6 +34,7 @@ train_dataset:
   hop_length: ${hop_length}
   suffix: ".lab"
   tokenizer: ${tokenizer}
+  sentence_mask_ratio: 0.2
 
 val_dataset:
   _target_: fish_speech.datasets.vits.VITSDataset
@@ -80,6 +81,8 @@ model:
     upsample_initial_channel: 512
     upsample_kernel_sizes: [16, 16, 8, 2, 2]
     gin_channels: 512
+    vq_mask_ratio: 0.2
+    ref_mask_ratio: 0.2
 
   discriminator:
     _target_: fish_speech.models.vits_decoder.modules.models.EnsembledDiscriminator
@@ -110,7 +113,7 @@ model:
   lr_scheduler:
     _target_: torch.optim.lr_scheduler.ExponentialLR
     _partial_: true
-    gamma: 0.999875
+    gamma: 0.999999
 
 callbacks:
   grad_norm_monitor:

+ 8 - 1
fish_speech/datasets/vits.py

@@ -1,3 +1,4 @@
+import random
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
@@ -26,6 +27,7 @@ class VITSDataset(Dataset):
         min_duration: float = 1.5,
         max_duration: float = 30.0,
         suffix: str = ".lab",
+        sentence_mask_ratio: float = 0.0,
     ):
         super().__init__()
 
@@ -43,6 +45,7 @@ class VITSDataset(Dataset):
         self.max_duration = max_duration
         self.tokenizer = tokenizer
         self.suffix = suffix
+        self.sentence_mask_ratio = sentence_mask_ratio
 
     def __len__(self):
         return len(self.files)
@@ -68,7 +71,11 @@ class VITSDataset(Dataset):
         if max_value > 1.0:
             audio = audio / max_value
 
-        text = text_file.read_text(encoding="utf-8")
+        if random.random() < self.sentence_mask_ratio:
+            text = "-"
+        else:
+            text = text_file.read_text(encoding="utf-8")
+
         input_ids = self.tokenizer(text, return_tensors="pt").input_ids.squeeze(0)
 
         return {

+ 45 - 33
fish_speech/models/vits_decoder/modules/models.py

@@ -77,6 +77,7 @@ class TextEncoder(nn.Module):
         ).to(y.dtype)
         text = self.text_embedding(text).transpose(1, 2)
         text = self.encoder_text(text * text_mask, text_mask)
+
         y = self.mrte(y, y_mask, text, text_mask, ge)
 
         y = self.encoder2(y * y_mask, y_mask)
@@ -85,25 +86,6 @@ class TextEncoder(nn.Module):
         m, logs = torch.split(stats, self.out_channels, dim=1)
         return y, m, logs, y_mask
 
-    def extract_latent(self, x):
-        x = self.ssl_proj(x)
-        quantized, codes, commit_loss, quantized_list = self.quantizer(x)
-        return codes.transpose(0, 1)
-
-    def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
-        quantized = self.quantizer.decode(codes)
-
-        y = self.vq_proj(quantized) * y_mask
-        y = self.encoder_ssl(y * y_mask, y_mask)
-
-        y = self.mrte(y, y_mask, refer, refer_mask, ge)
-
-        y = self.encoder2(y * y_mask, y_mask)
-
-        stats = self.proj(y) * y_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        return y, m, logs, y_mask, quantized
-
 
 class ResidualCouplingBlock(nn.Module):
     def __init__(
@@ -430,6 +412,8 @@ class SynthesizerTrn(nn.Module):
         upsample_kernel_sizes,
         gin_channels=0,
         codebook_size=264,
+        vq_mask_ratio=0.0,
+        ref_mask_ratio=0.0,
     ):
         super().__init__()
 
@@ -449,6 +433,8 @@ class SynthesizerTrn(nn.Module):
         self.upsample_kernel_sizes = upsample_kernel_sizes
         self.segment_size = segment_size
         self.gin_channels = gin_channels
+        self.vq_mask_ratio = vq_mask_ratio
+        self.ref_mask_ratio = ref_mask_ratio
 
         self.enc_p = TextEncoder(
             inter_channels,
@@ -498,8 +484,34 @@ class SynthesizerTrn(nn.Module):
             commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
         ).to(gt_specs.dtype)
         ge = self.ref_enc(gt_specs * y_mask, y_mask)
+
+        if self.training and self.ref_mask_ratio > 0:
+            bs = audio.size(0)
+            mask_speaker_len = int(bs * self.ref_mask_ratio)
+            mask_indices = torch.randperm(bs)[:mask_speaker_len]
+            audio[mask_indices] = 0
+
         quantized = self.vq(audio, audio_lengths)
-        quantized = F.interpolate(quantized, size=gt_specs.size(-1), mode="nearest")
+
+        # Block masking, block_size = 4
+        block_size = 4
+        if self.training and self.vq_mask_ratio > 0:
+            reduced_length = quantized.size(-1) // block_size
+            mask_length = int(reduced_length * self.vq_mask_ratio)
+            mask_indices = torch.randperm(reduced_length)[:mask_length]
+            short_mask = torch.zeros(
+                quantized.size(0),
+                quantized.size(1),
+                reduced_length,
+                device=quantized.device,
+                dtype=torch.float,
+            )
+            short_mask[:, :, mask_indices] = 1.0
+            long_mask = short_mask.repeat_interleave(block_size, dim=-1)
+            long_mask = F.interpolate(
+                long_mask, size=quantized.size(-1), mode="nearest"
+            )
+            quantized = quantized.masked_fill(long_mask > 0.5, 0)
 
         x, m_p, logs_p, y_mask = self.enc_p(
             quantized, gt_spec_lengths, text, text_lengths, ge
@@ -621,23 +633,23 @@ if __name__ == "__main__":
     # Try to load the model
     print(f"Loading model from {ckpt}")
     checkpoint = torch.load(ckpt, map_location="cpu", weights_only=True)["model"]
-    d_checkpoint = torch.load(
-        "checkpoints/Bert-VITS2/D_0.pth", map_location="cpu", weights_only=True
-    )["model"]
-    print(checkpoint.keys())
+    # d_checkpoint = torch.load(
+    #     "checkpoints/Bert-VITS2/D_0.pth", map_location="cpu", weights_only=True
+    # )["model"]
+    # print(checkpoint.keys())
 
     checkpoint.pop("dec.cond.weight")
     checkpoint.pop("enc_q.enc.cond_layer.weight_v")
 
-    new_checkpoint = {}
-    for k, v in checkpoint.items():
-        new_checkpoint["generator." + k] = v
+    # new_checkpoint = {}
+    # for k, v in checkpoint.items():
+    #     new_checkpoint["generator." + k] = v
 
-    for k, v in d_checkpoint.items():
-        new_checkpoint["discriminator." + k] = v
+    # for k, v in d_checkpoint.items():
+    #     new_checkpoint["discriminator." + k] = v
 
-    torch.save(new_checkpoint, "checkpoints/Bert-VITS2/ensemble.pth")
-    exit()
+    # torch.save(new_checkpoint, "checkpoints/Bert-VITS2/ensemble.pth")
+    # exit()
 
     print(model.load_state_dict(checkpoint, strict=False))
 
@@ -672,6 +684,6 @@ if __name__ == "__main__":
     print(o.size(), y_mask.size(), z.size(), z_p.size(), m_p.size(), logs_p.size())
 
     # Save output
-    import soundfile as sf
+    # import soundfile as sf
 
-    sf.write("output.wav", o.squeeze().detach().numpy(), 32000)
+    # sf.write("output.wav", o.squeeze().detach().numpy(), 32000)

+ 1 - 1
fish_speech/models/vits_decoder/modules/vq_encoder.py

@@ -48,7 +48,7 @@ class VQEncoder(nn.Module):
         ), e.unexpected_keys
 
     @torch.no_grad()
-    def forward(self, audios, audio_lengths, use_decoder=False, sr=None):
+    def forward(self, audios, audio_lengths, sr=None):
         mel_spec = self.spec(audios, sample_rate=sr)
 
         if sr is not None: