
Add 40hz config & infer tool

Lengyue · 2 years ago · commit 288741cc92

+ 10 - 11
fish_speech/configs/vq_naive_lfq.yaml → fish_speech/configs/vq_naive_40hz.yaml

@@ -2,17 +2,17 @@ defaults:
   - base
   - _self_
 
-project: vq_naive_lfq
+project: vq_naive_40hz
 
 # Lightning Trainer
 trainer:
   accelerator: gpu
-  devices: [1]
+  devices: 4
   strategy: ddp_find_unused_parameters_true
   gradient_clip_val: 1.0
   gradient_clip_algorithm: 'norm'
   precision: bf16-mixed
-  max_steps: 100_000
+  max_steps: 1_000_000
   val_check_interval: 5000
 
 sample_rate: 22050
@@ -40,8 +40,8 @@ data:
   _target_: fish_speech.datasets.vqgan.VQGANDataModule
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
-  num_workers: 8
-  batch_size: 128
+  num_workers: 4
+  batch_size: 32
   val_batch_size: 16
 
 # Model Configuration
@@ -52,9 +52,9 @@ model:
 
   downsample:
     _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
-    dims: ["${num_mels}", 512, 256]
-    kernel_sizes: [3, 3]
-    strides: [2, 2]
+    dims: ["${num_mels}", 256]
+    kernel_sizes: [3]
+    strides: [2]
 
   mel_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
@@ -71,10 +71,9 @@ model:
   vq_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
     in_channels: 256
-    vq_channels: 14
-    codebook_size: 16384
+    vq_channels: 256
+    codebook_size: 4096
     downsample: 1
-    use_lfq: true
 
   speaker_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
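
For orientation, the new name tracks the token rate this config implies: one stride-2 downsample on top of the mel frame rate (the old strides [2, 2] halved it again). A quick back-of-the-envelope check, as a sketch; hop_length is not visible in this diff and 256 is assumed from the base config:

    # Sketch: token rate implied by the "40hz" config above.
    sample_rate = 22050      # from the config
    hop_length = 256         # assumption -- inherited from `base`, not shown here
    strides = [2]            # new ConvDownSample strides (was [2, 2])

    rate = sample_rate / hop_length      # ~86.1 Hz mel frames
    for s in strides:
        rate /= s                        # one stride-2 stage -> ~43.1 Hz tokens
    print(f"{rate:.1f} Hz")              # ~43.1 Hz, the "40hz" regime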

+ 26 - 37
fish_speech/models/vqgan/lit_module.py

@@ -478,12 +478,7 @@ class VQNaive(L.LightningModule):
             },
         }
 
-    def training_step(self, batch, batch_idx):
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
+    def vq_encode(self, audios, audio_lengths):
         with torch.no_grad():
             features = gt_mels = self.mel_transform(
                 audios, sample_rate=self.sampling_rate
@@ -506,17 +501,34 @@ class VQNaive(L.LightningModule):
             gt_mels.dtype
         )
 
-        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-
         # vq_features is 50 hz, need to convert to true mel size
         text_features = self.mel_encoder(features, feature_masks)
-        text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
+        text_features, indices, loss_vq = self.vq_encoder(text_features, feature_masks)
+
+        return mel_masks, gt_mels, text_features, indices, loss_vq
+
+    def vq_decode(self, text_features, speaker_features, gt_mels, mel_masks):
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
-        # Sample mels
         decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+
+        return decoded_mels
+
+    def training_step(self, batch, batch_idx):
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+        audios = audios.float()
+        audios = audios[:, None, :]
+
+        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
+            audios, audio_lengths
+        )
+        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
+        decoded_mels = self.vq_decode(
+            text_features, speaker_features, gt_mels, mel_masks
+        )
         loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
         loss = loss_mel + loss_vq
 
@@ -556,36 +568,13 @@ class VQNaive(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]
 
-        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-
-        if self.downsample is not None:
-            features = self.downsample(features)
-
-        mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-            gt_mels.dtype
+        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
+            audios, audio_lengths
         )
-
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.mel_encoder(features, feature_masks)
-        text_features, _ = self.vq_encoder(text_features, feature_masks)
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
+        decoded_mels = self.vq_decode(
+            text_features, speaker_features, gt_mels, mel_masks
         )
-
-        # Sample mels
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
         fake_audios = self.vocoder(decoded_mels)
 
         mel_loss = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
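
The net effect of this refactor: the mel/VQ pipeline that was duplicated across training_step and validation_step now lives in two reusable helpers, and both paths surface the codebook indices. A minimal round-trip sketch against the new API (assuming `model` is a loaded VQNaive, `audios` a (B, 1, T) float tensor, and `audio_lengths` a (B,) long tensor, prepared as in training_step):

    # Sketch: encode to VQ space and decode back, outside the Lightning loop.
    mel_masks, gt_mels, text_features, indices, loss_vq = model.vq_encode(
        audios, audio_lengths
    )
    speaker_features = model.speaker_encoder(gt_mels, mel_masks)
    decoded_mels = model.vq_decode(text_features, speaker_features, gt_mels, mel_masks)
    fake_audios = model.vocoder(decoded_mels)    # mel -> waveform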

+ 3 - 3
fish_speech/models/vqgan/modules/encoders.py

@@ -360,12 +360,12 @@ class VQEncoder(nn.Module):
 
         if self.use_lfq:
             x = self.ln(x.mT)
-            q, _, loss = self.vq(x)
+            q, indices, loss = self.vq(x)
             q = q.mT
         else:
-            q, _, loss = self.vq(x)
+            q, indices, loss = self.vq(x)
 
         x = self.conv_out(q) * x_mask
         x = x[:, :, :x_len]
 
-        return x, loss
+        return x, indices, loss
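
Note this is a breaking signature change: VQEncoder.forward now returns a three-tuple, so every call site must unpack the extra indices (the lit_module callers above were updated in this commit):

    # Before this commit: x, loss = vq_encoder(text_features, feature_masks)
    # After:
    x, indices, loss = vq_encoder(text_features, feature_masks)
    # `indices` holds the codebook ids; tools/infer_vq.py below uses
    # indices.shape[1] as the token count when reporting the effective Hz.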

+ 68 - 0
tools/infer_vq.py

@@ -0,0 +1,68 @@
+import librosa
+import soundfile as sf
+import torch
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", enabled=True)
+def main():
+    with initialize(version_base="1.3", config_path="../fish_speech/configs"):
+        cfg = compose(config_name="vq_naive_50hz")
+
+    model: LightningModule = instantiate(cfg.model)
+    state_dict = torch.load(
+        "results/vq_naive_25hz/checkpoints/step_000100000.ckpt",
+        map_location=model.device,
+    )["state_dict"]
+    model.load_state_dict(state_dict, strict=True)
+    model.eval()
+    logger.info("Restored model from checkpoint")
+
+    # Load audio
+    audio = librosa.load("record.wav", sr=model.sampling_rate, mono=True)[0]
+    audios = torch.from_numpy(audio).to(model.device)[None, None, :]
+    logger.info(
+        f"Loaded {audios.shape[2] / model.sampling_rate:.2f} seconds of audio"
+    )
+
+    # VQ Encoder
+    audio_lengths = torch.tensor(
+        [audios.shape[2]], device=model.device, dtype=torch.long
+    )
+    mel_masks, gt_mels, text_features, indices, loss_vq = model.vq_encode(
+        audios, audio_lengths
+    )
+    logger.info(
+        f"VQ Encoded, indices: {indices.shape} equavilent to "
+        + f"{1/(audios.shape[2] / model.sampling_rate / indices.shape[1]):.2f} Hz"
+    )
+
+    # VQ Decoder
+    audioa = librosa.load(
+        "data/AiShell/wav/train/S0121/BAC009S0121W0125.wav",
+        sr=model.sampling_rate,
+        mono=True,
+    )[0]
+    audioa = torch.from_numpy(audioa).to(model.device)[None, None, :]
+    mel = model.mel_transform(audioa)
+    mel1_masks = torch.ones([mel.shape[0], 1, mel.shape[2]], device=model.device)
+
+    speaker_features = model.speaker_encoder(mel, mel1_masks)
+    decoded_mels = model.vq_decode(text_features, speaker_features, gt_mels, mel_masks)
+    fake_audios = model.vocoder(decoded_mels)
+
+    # Save audio
+    fake_audio = fake_audios[0, 0].cpu().numpy()
+    sf.write("fake.wav", fake_audio, model.sampling_rate)
+
+
+if __name__ == "__main__":
+    main()
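
As written, the script is a voice-conversion demo: it VQ-encodes record.wav as the content source, takes speaker features from the hard-coded AiShell utterance, and writes the reconstruction to fake.wav. Assuming the checkpoint and both audio files exist at the hard-coded paths, it runs with no arguments:

    python tools/infer_vq.py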