Support VITS inference

Lengyue 1 year ago
parent
commit fcae7a9ef3

+ 1 - 3
fish_speech/models/vits_decoder/lit_module.py

@@ -1,6 +1,4 @@
-import itertools
-from dataclasses import dataclass
-from typing import Any, Callable, Literal, Optional
+from typing import Any, Callable
 
 import lightning as L
 import torch

+ 31 - 34
fish_speech/models/vits_decoder/modules/models.py

@@ -542,21 +542,18 @@ class SynthesizerTrn(nn.Module):
         text_lengths,
         noise_scale=0.5,
     ):
-        y_mask = torch.unsqueeze(
-            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
-        ).to(gt_specs.dtype)
-        ge = self.ref_enc(gt_specs * y_mask, y_mask)
         quantized = self.vq(audio, audio_lengths)
-
-        x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, audio_lengths, text, text_lengths, ge
+        quantized_lengths = audio_lengths // 512
+        ge = self.encode_ref(gt_specs, gt_spec_lengths)
+
+        return self.decode(
+            quantized,
+            quantized_lengths,
+            text,
+            text_lengths,
+            noise_scale=noise_scale,
+            ge=ge,
         )
-        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
-
-        z = self.flow(z_p, y_mask, g=ge, reverse=True)
-
-        o = self.dec(z * y_mask, g=ge)
-        return o
 
     @torch.no_grad()
     def infer_posterior(
@@ -574,35 +571,35 @@ class SynthesizerTrn(nn.Module):
         return o
 
     @torch.no_grad()
-    def decode(self, codes, text, refer, noise_scale=0.5):
-        # TODO: not tested yet
-
-        ge = None
-        if refer is not None:
-            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
-            refer_mask = torch.unsqueeze(
-                commons.sequence_mask(refer_lengths, refer.size(2)), 1
-            ).to(refer.dtype)
-            ge = self.ref_enc(refer * refer_mask, refer_mask)
-
-        y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
-        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
-
-        quantized = self.quantizer.decode(codes)
-        quantized = F.interpolate(
-            quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
-        )
-
+    def decode(
+        self,
+        quantized,
+        quantized_lengths,
+        text,
+        text_lengths,
+        noise_scale=0.5,
+        ge=None,
+    ):
         x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, y_lengths, text, text_lengths, ge
+            quantized, quantized_lengths, text, text_lengths, ge
         )
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
 
         z = self.flow(z_p, y_mask, g=ge, reverse=True)
 
-        o = self.dec((z * y_mask)[:, :, :], g=ge)
+        o = self.dec(z * y_mask, g=ge)
+
         return o
 
+    @torch.no_grad()
+    def encode_ref(self, gt_specs, gt_spec_lengths):
+        y_mask = torch.unsqueeze(
+            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
+        ).to(gt_specs.dtype)
+        ge = self.ref_enc(gt_specs * y_mask, y_mask)
+
+        return ge
+
 
 if __name__ == "__main__":
     import librosa
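
For context, a minimal sketch of how the refactored pair composes at inference time. `net_g` (an already-loaded SynthesizerTrn) and the shape comment are assumptions based on this hunk, not part of the commit:

import torch

@torch.no_grad()
def synthesize(net_g, quantized, quantized_lengths, text, text_lengths,
               ref_specs, ref_spec_lengths, noise_scale=0.5):
    # encode_ref() is now a standalone helper, so the speaker embedding
    # can be computed once and reused across utterances.
    ge = net_g.encode_ref(ref_specs, ref_spec_lengths)  # assumed (B, gin_channels, 1)
    # decode() accepts precomputed VQ features plus the embedding directly.
    return net_g.decode(
        quantized, quantized_lengths, text, text_lengths,
        noise_scale=noise_scale, ge=ge,
    )

This mirrors what the old inline path did, minus the duplicated masking and flow code.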

+ 16 - 0
fish_speech/models/vits_decoder/modules/vq_encoder.py

@@ -1,8 +1,11 @@
+import math
+
 import torch
 from torch import nn
 
 from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize
 from fish_speech.models.vqgan.modules.wavenet import WaveNet
+from fish_speech.models.vqgan.utils import sequence_mask
 from fish_speech.utils.spectrogram import LogMelSpectrogram
 
 
@@ -67,3 +70,16 @@ class VQEncoder(nn.Module):
         encoded_features = self.quantizer(encoded_features).z * mel_masks_float_conv
 
         return encoded_features
+
+    @torch.no_grad()
+    def indicies_to_vq_features(
+        self,
+        indices,
+        feature_lengths,
+    ):
+        factor = math.prod(self.quantizer.downsample_factor)
+        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+        z = self.quantizer.decode(indices) * mel_masks_float_conv
+
+        return z
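
A hedged usage sketch for the new helper. Here `encoder` is assumed to be an initialized VQEncoder, and the (batch, groups, frames) index layout plus the 1024-entry codebook are inferred from the rest of the commit (the default checkpoint name suggests 2x1024), not guaranteed by it:

import torch

# Dummy codebook ids purely for illustration: (batch, groups, frames).
indices = torch.randint(0, 1024, (1, 2, 100), dtype=torch.long)
feature_lengths = torch.tensor([indices.shape[2]])

z = encoder.indicies_to_vq_features(indices, feature_lengths)
# z: (batch, channels, frames * prod(downsample_factor)); frames past each
# true length are zeroed by the sequence mask.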

+ 153 - 0
tools/vits_decoder/inference.py

@@ -0,0 +1,153 @@
+from pathlib import Path
+
+import click
+import hydra
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+from transformers import AutoTokenizer
+
+from fish_speech.utils.file import AUDIO_EXTENSIONS
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
+def load_model(config_name, checkpoint_path, device="cuda"):
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+        cfg = compose(config_name=config_name)
+
+    model: LightningModule = instantiate(cfg.model)
+    state_dict = torch.load(
+        checkpoint_path,
+        map_location=model.device,
+    )
+
+    if "state_dict" in state_dict:
+        state_dict = state_dict["state_dict"]
+
+    model.load_state_dict(state_dict, strict=False)
+    model.eval()
+    model.to(device)
+    logger.info("Restored model from checkpoint")
+
+    return model
+
+
+@torch.no_grad()
+@click.command()
+@click.option(
+    "--input-path",
+    "-i",
+    default="test.npy",
+    type=click.Path(exists=True, path_type=Path),
+)
+@click.option(
+    "--reference-path",
+    "-r",
+    type=click.Path(exists=True, path_type=Path),
+    default=None,
+)
+@click.option(
+    "--text",
+    type=str,
+    default="-",
+)
+@click.option(
+    "--tokenizer",
+    type=str,
+    default="fishaudio/fish-speech-1",
+)
+@click.option(
+    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
+)
+@click.option("--config-name", "-cfg", default="vits_decoder")
+@click.option(
+    "--checkpoint-path",
+    "-ckpt",
+    default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+)
+@click.option(
+    "--device",
+    "-d",
+    default="cuda",
+)
+def main(
+    input_path,
+    reference_path,
+    text,
+    tokenizer,
+    output_path,
+    config_name,
+    checkpoint_path,
+    device,
+):
+    model = load_model(config_name, checkpoint_path, device=device)
+
+    assert input_path.suffix == ".npy", f"Expected .npy file, got {input_path.suffix}"
+
+    logger.info(f"Processing precomputed indices from {input_path}")
+    indices = np.load(input_path)
+    indices = torch.from_numpy(indices).to(model.device).long()
+    assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
+
+    # Extract reference audio
+    if reference_path is not None:
+        assert (
+            reference_path.suffix in AUDIO_EXTENSIONS
+        ), f"Expected audio file, got {reference_path.suffix}"
+        reference_audio, sr = librosa.load(reference_path, sr=model.sampling_rate)
+        reference_audio = torch.from_numpy(reference_audio).to(model.device).float()
+        reference_spec = model.spec_transform(reference_audio[None])
+        reference_embedding = model.generator.encode_ref(
+            reference_spec,
+            torch.tensor([reference_spec.shape[-1]], device=model.device),
+        )
+        logger.info(
+            f"Loaded reference audio from {reference_path}, shape: {reference_audio.shape}"
+        )
+    else:
+        reference_embedding = torch.zeros(
+            1, model.generator.gin_channels, 1, device=model.device
+        )
+        logger.info("No reference audio provided, use zero embedding")
+
+    # Extract text
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+    encoded_text = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
+    logger.info(f"Encoded text: {encoded_text.shape}")
+
+    # Restore
+    feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
+    quantized = model.generator.vq.indicies_to_vq_features(
+        indices=indices[None], feature_lengths=feature_lengths
+    )
+    logger.info(f"Restored VQ features: {quantized.shape}")
+
+    # Decode
+    fake_audios = model.generator.decode(
+        quantized,
+        torch.tensor([quantized.shape[-1]], device=model.device),
+        encoded_text,
+        torch.tensor([encoded_text.shape[-1]], device=model.device),
+        ge=reference_embedding,
+    )
+    logger.info(
+        f"Generated audio: {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
+    )
+
+    # Save audio
+    fake_audio = fake_audios[0, 0].float().cpu().numpy()
+    sf.write(output_path, fake_audio, model.sampling_rate)
+    logger.info(f"Saved audio to {output_path}")
+
+
+if __name__ == "__main__":
+    main()
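
One way to exercise the script end to end is to dump 2-D codebook indices with numpy; the shape and codebook size below are illustrative assumptions, and the script adds the batch dimension itself:

import numpy as np

# Hypothetical indices: (groups, frames); any integer dtype works since
# the script casts to long before decoding.
indices = np.random.randint(0, 1024, size=(2, 200))
np.save("test.npy", indices)

# Then, for example:
#   python tools/vits_decoder/inference.py -i test.npy -r reference.wav \
#       --text "Hello world" -ckpt checkpoints/vq-gan-group-fsq-2x1024.pth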