Lengyue 2 years ago
Parent
Commit
d2429c2f1d

+ 3 - 3
fish_speech/configs/hubert_vq.yaml

@@ -22,14 +22,14 @@ win_length: 2048
 # Dataset Configuration
 train_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_train_filelist.txt
+  filelist: data/filelist.split.train
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   slice_frames: 512
 
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_val_filelist.txt
+  filelist: data/filelist.split.valid
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
 
@@ -50,7 +50,7 @@ model:
 
   generator:
     _target_: fish_speech.models.vqgan.modules.models.SynthesizerTrn
-    in_channels: 1024
+    in_channels: 2048
     spec_channels: ${num_mels}
     segment_size: "${eval: '${model.segment_size} // ${hop_length}'}"
     inter_channels: 192
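
Note: the filelists now point at pre-split files (data/filelist.split.train / data/filelist.split.valid) instead of the old vq_*_filelist.txt pair. A minimal sketch of how such a split could be produced, assuming a single source list at data/filelist.txt (the source path, 99/1 ratio, and seed are assumptions, not part of this commit):

import random

with open("data/filelist.txt") as f:      # assumed source filelist
    lines = [line.strip() for line in f if line.strip()]

random.seed(42)                           # assumed seed, for reproducibility
random.shuffle(lines)
n_train = int(len(lines) * 0.99)          # assumed 99/1 train/valid ratio

with open("data/filelist.split.train", "w") as f:
    f.write("\n".join(lines[:n_train]) + "\n")
with open("data/filelist.split.valid", "w") as f:
    f.write("\n".join(lines[n_train:]) + "\n")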

+ 15 - 8
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -51,7 +51,7 @@ model:
 
   text_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
-    in_channels: 2048
+    in_channels: 128
     out_channels: 128
     hidden_channels: 192
     hidden_channels_ffn: 768
@@ -61,21 +61,20 @@ model:
     dropout: 0.1
     use_vae: false
     gin_channels: 512
-    speaker_cond_layer: 2
+    speaker_cond_layer: 0
 
   vq_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: 1024
-    vq_channels: 1024
-    codebook_size: 2048
-    downsample: 2
-    kmeans_ckpt: results/hubert-vq-pretrain/kmeans.pt
+    in_channels: 128
+    vq_channels: 128
+    codebook_size: 4096
+    downsample: 1
 
   speaker_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
     in_channels: 128
     hidden_channels: 192
-    out_channels: 512
+    out_channels: 128
     num_heads: 2
     num_layers: 4
     p_dropout: 0.1
@@ -104,6 +103,14 @@ model:
     f_min: 40
     f_max: 16000
 
+  feature_mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: 32000
+    n_fft: 2048
+    hop_length: 640
+    win_length: 2048
+    n_mels: 128
+
   optimizer:
     _target_: torch.optim.AdamW
     _partial_: true
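
Note: the new feature_mel_transform runs at 32 kHz with hop 640, i.e. 32000 / 640 = 50 frames per second, matching the "vq_features is 50 hz" comment in lit_module.py; its n_mels=128 also matches the text_encoder's new in_channels=128. A quick sanity check, assuming the constructor arguments shown in this config and the [B, 1, T] input used by the Lightning module:

import torch
from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram

transform = LogMelSpectrogram(
    sample_rate=32000, n_fft=2048, hop_length=640, win_length=2048, n_mels=128
)
audio = torch.randn(1, 1, 32000)   # one second of audio at 32 kHz
mels = transform(audio)
print(mels.shape)                  # roughly [1, 128, 50]; the exact frame
                                   # count depends on STFT padding/centering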

+ 31 - 21
fish_speech/models/vq_diffusion/lit_module.py

@@ -28,6 +28,7 @@ class VQDiffusion(L.LightningModule):
         optimizer: Callable,
         lr_scheduler: Callable,
         mel_transform: nn.Module,
+        feature_mel_transform: nn.Module,
         vq_encoder: VQEncoder,
         speaker_encoder: SpeakerEncoder,
         text_encoder: TextEncoder,
@@ -44,6 +45,7 @@ class VQDiffusion(L.LightningModule):
 
         # Generator and discriminators
         self.mel_transform = mel_transform
+        self.feature_mel_transform = feature_mel_transform
         self.noise_scheduler_train = DDIMScheduler(num_train_timesteps=1000)
         self.noise_scheduler_infer = UniPCMultistepScheduler(num_train_timesteps=1000)
 
@@ -91,26 +93,30 @@ class VQDiffusion(L.LightningModule):
         audios = audios[:, None, :]
 
         with torch.no_grad():
-            gt_mels = self.mel_transform(audios)
+            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
+            features = self.feature_mel_transform(
+                audios, sample_rate=self.sampling_rate
+            )
 
         mel_lengths = audio_lengths // self.hop_length
-
+        feature_lengths = audio_lengths // self.hop_length // 2
         feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[1]), 1
+            sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)
-
         # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
+        text_features = self.text_encoder(features, feature_masks)
+        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
+        text_features = text_features + speaker_features
+
         # Sample noise that we'll add to the images
         normalized_gt_mels = self.normalize_mels(gt_mels)
         noise = torch.randn_like(normalized_gt_mels)
@@ -147,17 +153,17 @@ class VQDiffusion(L.LightningModule):
             sync_dist=True,
         )
 
-        # self.log(
-        #     "train/vq_loss",
-        #     vq_loss,
-        #     on_step=True,
-        #     on_epoch=False,
-        #     prog_bar=True,
-        #     logger=True,
-        #     sync_dist=True,
-        # )
+        self.log(
+            "train/vq_loss",
+            vq_loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
 
-        return noise_loss  # + vq_loss
+        return noise_loss + vq_loss
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
@@ -166,25 +172,29 @@ class VQDiffusion(L.LightningModule):
         audios = audios.float()
         # features = features.float().mT
         audios = audios[:, None, :]
-        gt_mels = self.mel_transform(audios)
-        mel_lengths = audio_lengths // self.hop_length
+        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
+        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)
 
+        mel_lengths = audio_lengths // self.hop_length
+        feature_lengths = audio_lengths // self.hop_length // 2
         feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[1]), 1
+            sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)
 
         # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
+        text_features = self.text_encoder(features, feature_masks)
+        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
+        text_features = text_features + speaker_features
+
         # Begin sampling
         sampled_mels = torch.randn_like(gt_mels)
         self.noise_scheduler_infer.set_timesteps(100)
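
Note: the mask fix (features.shape[1] -> features.shape[2]) reflects that the mel transforms return [B, n_mels, T], so the time axis is dim 2, and feature_lengths = audio_lengths // hop_length // 2 encodes that the feature mels run at half the mel frame rate. A shape-bookkeeping sketch of the new conditioning path, assuming hop_length=320 at 32 kHz (consistent with the 50 Hz comment, but not stated in this diff):

import torch
import torch.nn.functional as F

B, T_audio, hop = 2, 32000, 320
mel_len = T_audio // hop           # 100 mel frames (100 Hz)
feat_len = T_audio // hop // 2     # 50 feature frames (50 Hz)

gt_mels = torch.randn(B, 128, mel_len)     # [B, n_mels, T]
features = torch.randn(B, 128, feat_len)   # time axis is dim 2, hence shape[2]

# the 50 Hz features are upsampled back to the mel frame rate before being
# combined with the speaker embedding (assumed shape [B, C, 1])
up = F.interpolate(features, size=gt_mels.shape[2], mode="nearest")
speaker = torch.randn(B, 128, 1)           # speaker_encoder out_channels=128
cond = up + speaker                        # broadcasts over time
assert cond.shape == gt_mels.shape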

+ 7 - 4
fish_speech/models/vqgan/lit_module.py

@@ -89,7 +89,7 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
 
         audios = audios.float()
-        features = features.float()
+        # features = features.long()
 
         with torch.no_grad():
             gt_mels = self.mel_transform(audios)
@@ -152,8 +152,11 @@ class VQGAN(L.LightningModule):
             # then 500 steps 0.1
             # then go back to 0
 
-            beta = self.global_step % 1000
-            beta = min(beta, 500) / 500 * 0.1 + 1e-6
+            if self.global_step < 100000:
+                beta = 1e-6
+            else:
+                beta = self.global_step % 1000
+                beta = min(beta, 500) / 500 * 0.1 + 1e-6
 
             loss_gen_all = (
                 loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta
@@ -231,7 +234,7 @@ class VQGAN(L.LightningModule):
         features, feature_lengths = batch["features"], batch["feature_lengths"]
 
         audios = audios.float()
-        features = features.float()
+        # features = features.float()
         audios = audios[:, None, :]
 
         gt_mels = self.mel_transform(audios)
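
Note: the KL weight is now held at its floor for the first 100k steps before the cyclic ramp kicks in. Extracted as a pure function for clarity (a sketch mirroring the diff, not code from the repo):

def kl_beta(global_step: int) -> float:
    if global_step < 100_000:
        return 1e-6                               # warm-up: KL effectively off
    step = global_step % 1000                     # 1000-step cycle
    return min(step, 500) / 500 * 0.1 + 1e-6     # ramp to ~0.1 over 500 steps,
                                                  # hold for 500, then reset

assert kl_beta(50_000) == 1e-6
assert abs(kl_beta(100_500) - (0.1 + 1e-6)) < 1e-12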

+ 15 - 4
fish_speech/models/vqgan/modules/encoders.py

@@ -28,6 +28,7 @@ class TextEncoder(nn.Module):
         gin_channels=0,
         speaker_cond_layer=0,
         use_vae=True,
+        use_embedding=False,
     ):
         """Text Encoder for VITS model.
 
@@ -45,9 +46,12 @@ class TextEncoder(nn.Module):
         super().__init__()
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
+        self.use_embedding = use_embedding
 
-        # self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
-        self.proj_in = nn.Embedding(in_channels, hidden_channels)
+        if use_embedding:
+            self.proj_in = nn.Embedding(in_channels, hidden_channels)
+        else:
+            self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
 
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
@@ -79,7 +83,12 @@ class TextEncoder(nn.Module):
             - x: :math:`[B, T]`
             - x_length: :math:`[B]`
         """
-        x = self.proj_in(x).mT * x_mask
+
+        if self.use_embedding:
+            x = self.proj_in(x.long()).mT * x_mask
+        else:
+            x = self.proj_in(x) * x_mask
+
         x = self.encoder(x, x_mask, g=g)
         x = self.proj_out(x) * x_mask
 
@@ -237,7 +246,9 @@ class VQEncoder(nn.Module):
             in_channels, vq_channels, kernel_size=downsample, stride=downsample
         )
         self.conv_out = nn.Sequential(
-            nn.Upsample(scale_factor=downsample, mode="nearest"),
+            nn.Upsample(scale_factor=downsample, mode="nearest")
+            if downsample > 1
+            else nn.Identity(),
             nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
         )
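
Note: with use_embedding=True the encoder consumes token ids of shape [B, T]; with the default Conv1d projection it consumes feature maps of shape [B, in_channels, T]. A minimal stand-in for just the proj_in switch (illustrative, not the repo class):

import torch
from torch import nn

class ProjIn(nn.Module):
    def __init__(self, in_channels, hidden_channels, use_embedding=False):
        super().__init__()
        self.use_embedding = use_embedding
        if use_embedding:
            # in_channels acts as a vocabulary size here
            self.proj_in = nn.Embedding(in_channels, hidden_channels)
        else:
            # in_channels is a feature dimension
            self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)

    def forward(self, x, x_mask):
        if self.use_embedding:
            return self.proj_in(x.long()).mT * x_mask   # [B, T] -> [B, H, T]
        return self.proj_in(x) * x_mask                 # [B, C, T] -> [B, H, T]

mask = torch.ones(2, 1, 50)
tokens = torch.randint(0, 4096, (2, 50))
feats = torch.randn(2, 128, 50)
assert ProjIn(4096, 192, use_embedding=True)(tokens, mask).shape == (2, 192, 50)
assert ProjIn(128, 192)(feats, mask).shape == (2, 192, 50)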
 

+ 13 - 12
fish_speech/models/vqgan/modules/models.py

@@ -49,12 +49,12 @@ class SynthesizerTrn(nn.Module):
 
         self.segment_size = segment_size
 
-        self.vq = VQEncoder(
-            in_channels=in_channels,
-            vq_channels=in_channels,
-            codebook_size=codebook_size,
-            kmeans_ckpt=kmeans_ckpt,
-        )
+        # self.vq = VQEncoder(
+        #     in_channels=in_channels,
+        #     vq_channels=in_channels,
+        #     codebook_size=codebook_size,
+        #     kmeans_ckpt=kmeans_ckpt,
+        # )
         self.enc_p = TextEncoder(
             in_channels,
             inter_channels,
@@ -105,20 +105,20 @@ class SynthesizerTrn(nn.Module):
         )
 
     def forward(self, x, x_lengths, specs):
-        x = x.mT
+        # x = x.mT
 
-        min_length = min(x.shape[2], specs.shape[2])
+        min_length = min(x.shape[1], specs.shape[2])
         if min_length % 2 != 0:
             min_length -= 1
 
-        x = x[:, :, :min_length]
+        x = x[:, :min_length]
         specs = specs[:, :, :min_length]
         x_lengths = torch.clamp(x_lengths, max=min_length)
 
         spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
             specs.dtype
         )
-        x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
+        x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
 
         g = self.enc_spk(specs, spec_masks)
 
@@ -145,11 +145,12 @@ class SynthesizerTrn(nn.Module):
         )
 
     def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
-        x = x.mT
+        # x = x.mT
         spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
             specs.dtype
         )
-        x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[2]), 1).to(x.dtype)
+        # print(x_lengths, x.shape)
+        x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
         g = self.enc_spk(specs, spec_masks)
         # x, vq_loss = self.vq(x, x_masks)
         z_p, m_p, logs_p, h_text, _ = self.enc_p(
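
Note: with the x.mT transpose removed, forward/infer now take x as [B, T] (e.g. token ids) while specs stays [B, n_mels, T'], so the two time axes live on different dims. The alignment step in isolation (a sketch mirroring the diff; the even-length truncation presumably keeps downsampling by 2 exact):

import torch

x = torch.randint(0, 4096, (1, 103))   # [B, T] discrete inputs
specs = torch.randn(1, 128, 105)       # [B, n_mels, T']

min_length = min(x.shape[1], specs.shape[2])
if min_length % 2 != 0:                # force an even shared length
    min_length -= 1

x = x[:, :min_length]
specs = specs[:, :, :min_length]
x_lengths = torch.clamp(torch.tensor([103]), max=min_length)
assert x.shape[1] == specs.shape[2] == 102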

+ 7 - 1
fish_speech/models/vqgan/spectrogram.py

@@ -1,4 +1,5 @@
 import torch
+import torchaudio.functional as F
 from torch import Tensor, nn
 from torchaudio.transforms import MelScale
 
@@ -96,7 +97,12 @@ class LogMelSpectrogram(nn.Module):
     def decompress(self, x: Tensor) -> Tensor:
         return torch.exp(x)
 
-    def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
+    def forward(
+        self, x: Tensor, return_linear: bool = False, sample_rate: int = None
+    ) -> Tensor:
+        if sample_rate is not None and sample_rate != self.sample_rate:
+            x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
+
         linear = self.spectrogram(x)
         x = self.mel_scale(linear)
         x = self.compress(x)
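
Note: forward now accepts the source sample_rate and resamples to the transform's native rate before the STFT, so callers can pass audio at heterogeneous rates. A usage sketch, assuming the constructor arguments from the configs above:

import torch
from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram

transform = LogMelSpectrogram(
    sample_rate=32000, n_fft=2048, hop_length=640, win_length=2048, n_mels=128
)
mels_native = transform(torch.randn(1, 1, 32000))                  # no resampling
mels_44k = transform(torch.randn(1, 1, 44100), sample_rate=44100)  # resampled to 32 kHz
# both cover one second of audio, so both yield roughly 50 frames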