Bladeren bron

Update vqgan toolchain

Lengyue 2 jaren geleden
bovenliggende
commit
c25c946695
5 gewijzigde bestanden met toevoegingen van 149 en 392 verwijderingen
  1. 1 9
      tools/extract_model.py
  2. 0 129
      tools/infer_vq.py
  3. 0 184
      tools/vqgan/calculate_hubert_features.py
  4. 148 0
      tools/vqgan/inference.py
  5. 0 70
      tools/vqgan/migrate_from_vits.py

+ 1 - 9
tools/extract_model.py

@@ -9,18 +9,10 @@ from loguru import logger
 def main(model_path, output_path):
     if model_path == output_path:
         logger.error("Model path and output path are the same")
-        click.Abort()
+        return
 
     logger.info(f"Loading model from {model_path}")
     state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
-    logger.info("Extracting model")
-
-    state_dict = {
-        state_dict: value
-        for state_dict, value in state_dict.items()
-        if state_dict.startswith("model.")
-    }
-
     torch.save(state_dict, output_path)
     logger.info(f"Model saved to {output_path}")
 

+ 0 - 129
tools/infer_vq.py

@@ -1,129 +0,0 @@
-import librosa
-import numpy as np
-import soundfile as sf
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-from hydra import compose, initialize
-from hydra.utils import instantiate
-from lightning import LightningModule
-from loguru import logger
-from omegaconf import OmegaConf
-
-from fish_speech.models.vqgan.utils import sequence_mask
-
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-
-
-@torch.no_grad()
-@torch.autocast(device_type="cuda", enabled=True)
-def main():
-    with initialize(version_base="1.3", config_path="../fish_speech/configs"):
-        cfg = compose(config_name="vqgan_single_2x")
-
-    model: LightningModule = instantiate(cfg.model)
-    state_dict = torch.load(
-        "results/vqgan_single_2x/checkpoints/step_000160000.ckpt",
-        map_location=model.device,
-    )["state_dict"]
-    model.load_state_dict(state_dict, strict=True)
-    model.eval()
-    model.cuda()
-    logger.info("Restored model from checkpoint")
-
-    # Load audio
-    audio = librosa.load(
-        "data/Genshin/Chinese/派蒙/vo_WYLQ103_10_paimon_04.wav",
-        sr=model.sampling_rate,
-        mono=True,
-    )[0]
-    audios = torch.from_numpy(audio).to(model.device)[None, None, :]
-    logger.info(
-        f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
-    )
-
-    # VQ Encoder
-    audio_lengths = torch.tensor(
-        [audios.shape[2]], device=model.device, dtype=torch.long
-    )
-
-    features = gt_mels = model.mel_transform(audios, sample_rate=model.sampling_rate)
-
-    if model.downsample is not None:
-        features = model.downsample(features)
-
-    mel_lengths = audio_lengths // model.hop_length
-    feature_lengths = (
-        audio_lengths
-        / model.hop_length
-        / (model.downsample.total_strides if model.downsample is not None else 1)
-    ).long()
-
-    feature_masks = torch.unsqueeze(
-        sequence_mask(feature_lengths, features.shape[2]), 1
-    ).to(gt_mels.dtype)
-    mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-        gt_mels.dtype
-    )
-
-    # vq_features is 50 hz, need to convert to true mel size
-    text_features = model.mel_encoder(features, feature_masks)
-    _, indices, _ = model.vq_encoder(text_features, feature_masks)
-    print(indices.shape)
-
-    speaker_features = model.speaker_encoder(gt_mels, mel_masks)
-
-    # Restore
-    indices = np.load("codes_0.npy")
-    indices = torch.from_numpy(indices).to(model.device).long()
-    print(indices)
-    # indices = indices.unsqueeze(1).unsqueeze(-1)
-    mel_lengths = indices.shape[1] * (
-        model.downsample.total_strides if model.downsample is not None else 1
-    )
-    mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
-    mel_masks = torch.ones(
-        (1, 1, mel_lengths), device=model.device, dtype=torch.float32
-    )
-
-    print(mel_lengths)
-
-    # Reference speaker
-    # ref_audio = librosa.load(
-    #     "data/StarRail/Chinese/符玄/chapter2_8_fuxuan_104.wav",
-    #     sr=model.sampling_rate,
-    #     mono=True,
-    # )[0]
-    # ref_audios = torch.from_numpy(ref_audio).to(model.device)[None, None, :]
-    # ref_audio_lengths = torch.tensor(
-    #     [ref_audios.shape[2]], device=model.device, dtype=torch.long
-    # )
-    # ref_mels = model.mel_transform(ref_audios, sample_rate=model.sampling_rate)
-    # ref_mel_lengths = ref_audio_lengths // model.hop_length
-    # ref_mel_masks = torch.unsqueeze(
-    #     sequence_mask(ref_mel_lengths, ref_mels.shape[2]), 1
-    # ).to(gt_mels.dtype)
-    # speaker_features = model.speaker_encoder(ref_mels, ref_mel_masks)
-
-    print("indices", indices.shape)
-    text_features = model.vq_encoder.decode(indices)
-
-    logger.info(
-        f"VQ Encoded, indices: {indices.shape} equivalent to "
-        + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[1]):.2f} Hz"
-    )
-
-    text_features = F.interpolate(text_features, size=mel_lengths[0], mode="nearest")
-
-    # Sample mels
-    decoded_mels = model.decoder(text_features, mel_masks, g=speaker_features)
-    fake_audios = model.generator(decoded_mels)
-
-    # Save audio
-    fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
-    sf.write("fake.wav", fake_audio, model.sampling_rate)
-
-
-if __name__ == "__main__":
-    main()

+ 0 - 184
tools/vqgan/calculate_hubert_features.py

@@ -1,184 +0,0 @@
-# This file is used to convert the audio files to text files using the Whisper model.
-# It's mainly used to generate the training data for the VQ model.
-
-import os
-import subprocess as sp
-import sys
-import time
-from datetime import timedelta
-from functools import lru_cache
-from pathlib import Path
-from random import Random
-
-import click
-import numpy as np
-import torch
-import torchaudio
-from loguru import logger
-from transformers import HubertModel
-
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
-
-RANK = int(os.environ.get("SLURM_PROCID", 0))
-WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
-
-logger_format = (
-    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
-    "<level>{level: <8}</level> | "
-    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
-    "{extra[rank]} - <level>{message}</level>"
-)
-logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
-logger.remove()
-logger.add(sys.stderr, format=logger_format)
-
-
-@lru_cache(maxsize=1)
-def get_hubert_model():
-    model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-large")
-    model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-    model = model.half()
-    model.eval()
-
-    logger.info(f"Loaded model")
-    return model
-
-
-def process_batch(files: list[Path], kmeans_centers: torch.Tensor) -> float:
-    model = get_hubert_model()
-
-    wavs = []
-    max_length = total_time = 0
-
-    for file in files:
-        wav, sr = torchaudio.load(file)
-        if wav.shape[0] > 1:
-            wav = wav.mean(dim=0, keepdim=True)
-
-        wav = torchaudio.functional.resample(wav.cuda(), sr, 16000)[0]
-
-        if len(wav) > sr * 60:
-            wav = wav[: sr * 60]
-
-        wavs.append(wav)
-        total_time += len(wav) / sr
-        max_length = max(max_length, len(wav))
-
-    # Pad to max length
-    attention_mask = torch.ones(len(wavs), max_length, dtype=torch.float)
-    feature_lengths = []
-
-    if max_length % 320 != 0:
-        max_length += 320 - max_length % 320
-
-    for i, wav in enumerate(wavs):
-        attention_mask[i, len(wav) :] = 0
-        feature_lengths.append(int(len(wav) / 320))
-        wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
-
-    wavs = torch.stack(wavs, dim=0).half()
-    attention_mask = attention_mask.cuda()
-
-    # Calculate lengths
-    with torch.no_grad():
-        outputs = model(wavs, attention_mask=attention_mask).last_hidden_state
-
-        # Find closest centroids
-        kmeans_centers = kmeans_centers.to(dtype=outputs.dtype, device=outputs.device)
-        distances = torch.cdist(outputs, kmeans_centers)
-        outputs = torch.min(distances, dim=-1)
-        avg_distance = torch.mean(outputs.values)
-
-    # Save to disk
-    outputs = outputs.indices.cpu().numpy()
-
-    for file, length, feature, wav in zip(files, feature_lengths, outputs, wavs):
-        feature = feature[:length]
-
-        # (T,)
-        with open(file.with_suffix(".npy"), "wb") as f:
-            np.save(f, feature)
-
-    return total_time, avg_distance
-
-
-@click.command()
-@click.argument("folder")
-@click.option("--num-workers", default=1)
-@click.option("--kmeans", default="results/hubert-vq-pretrain/kmeans.pt")
-def main(folder: str, num_workers: int, kmeans: str):
-    if num_workers > 1 and WORLD_SIZE != num_workers:
-        assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
-
-        logger.info(f"Spawning {num_workers} workers")
-
-        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
-        if visible_devices is None:
-            visible_devices = list(range(torch.cuda.device_count()))
-        else:
-            visible_devices = visible_devices.split(",")
-
-        processes = []
-        for i in range(num_workers):
-            env = os.environ.copy()
-            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
-            env["SLURM_PROCID"] = str(i)
-            env["SLURM_NTASKS"] = str(num_workers)
-
-            processes.append(
-                sp.Popen(
-                    [sys.executable] + sys.argv.copy(),
-                    env=env,
-                )
-            )
-
-        for p in processes:
-            p.wait()
-
-        logger.info(f"All workers finished")
-        return
-
-    # This is a worker
-    logger.info(f"Starting worker")
-    files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
-    Random(42).shuffle(files)
-
-    total_files = len(files)
-    files = files[RANK::WORLD_SIZE]
-    logger.info(f"Processing {len(files)}/{total_files} files")
-
-    # Load kmeans
-    kmeans_centers = torch.load(kmeans)["centroids"]
-
-    # Batch size 64
-    total_time = 0
-    begin_time = time.time()
-    processed_files = 0
-    total_distance = 0
-
-    for n_batch, idx in enumerate(range(0, len(files), 32)):
-        batch = files[idx : idx + 32]
-        batch_time, avg_distance = process_batch(batch, kmeans_centers)
-
-        total_time += batch_time
-        processed_files += len(batch)
-        total_distance += avg_distance
-
-        if (n_batch + 1) % 10 == 0:
-            eta = (
-                (time.time() - begin_time)
-                / processed_files
-                * (len(files) - processed_files)
-            )
-            logger.info(
-                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
-                + f"err {total_distance/(n_batch+1):.2f}, ETA: {timedelta(seconds=round(eta))}s"
-            )
-
-    logger.info(
-        f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
-    )
-
-
-if __name__ == "__main__":
-    main()

+ 148 - 0
tools/vqgan/inference.py

@@ -0,0 +1,148 @@
+from pathlib import Path
+
+import click
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+
+from fish_speech.models.vqgan.utils import sequence_mask
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+
# NOTE: click decorators must be outermost so that `main` stays a click.Command;
# the torch context decorators are applied to the underlying function.
# (The original order wrapped the Command object itself, which only works by accident.)
@click.command()
@click.option(
    "--input-path",
    "-i",
    default="data/Genshin/Chinese/派蒙/vo_WYLQ103_10_paimon_04.wav",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
)
@click.option("--config-name", "-cfg", default="vqgan_pretrain")
@click.option(
    "--checkpoint-path", "-ckpt", default="checkpoints/vqgan/step_000380000_wo.ckpt"
)
@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=True)
def main(input_path, output_path, config_name, checkpoint_path):
    """Run VQGAN inference on a single input.

    Two input modes, selected by file extension:
      * ``.wav``  — encode the audio to VQ indices (saved as ``<output>.npy``),
        then reconstruct audio from those indices (round-trip check).
      * ``.npy``  — load precomputed 2-D VQ indices and only decode to audio.

    The reconstructed waveform is written to ``output_path``.

    Raises:
        ValueError: if ``input_path`` has an unsupported extension.
    """
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        cfg = compose(config_name=config_name)

    model: LightningModule = instantiate(cfg.model)
    state_dict = torch.load(
        checkpoint_path,
        map_location=model.device,
    )
    # Lightning checkpoints nest weights under "state_dict"; raw dumps do not.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    model.cuda()
    logger.info("Restored model from checkpoint")

    if input_path.suffix == ".wav":
        logger.info(f"Processing in-place reconstruction of {input_path}")
        # Load audio as mono at the model's sampling rate.
        audio, _ = librosa.load(
            input_path,
            sr=model.sampling_rate,
            mono=True,
        )
        audios = torch.from_numpy(audio).to(model.device)[None, None, :]
        logger.info(
            f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=model.device, dtype=torch.long
        )

        features = gt_mels = model.mel_transform(
            audios, sample_rate=model.sampling_rate
        )

        # Optional temporal downsampling before the VQ bottleneck.
        if model.downsample is not None:
            features = model.downsample(features)

        mel_lengths = audio_lengths // model.hop_length
        feature_lengths = (
            audio_lengths
            / model.hop_length
            / (model.downsample.total_strides if model.downsample is not None else 1)
        ).long()

        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        # vq_features is 50 hz, need to convert to true mel size
        text_features = model.mel_encoder(features, feature_masks)
        _, indices, _ = model.vq_encoder(text_features, feature_masks)

        # Squeeze the singleton codebook/group axes so indices are (batch, time).
        if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1:
            indices = indices[:, 0, :, 0]
        else:
            logger.error(f"Unknown indices shape: {indices.shape}")
            return

        logger.info(f"Generated indices of shape {indices.shape}")

        # Save indices next to the requested output, with .npy suffix.
        np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
    elif input_path.suffix == ".npy":
        logger.info(f"Processing precomputed indices from {input_path}")
        indices = np.load(input_path)
        indices = torch.from_numpy(indices).to(model.device).long()
        assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
    else:
        raise ValueError(f"Unknown input type: {input_path}")

    # Restore: re-add the singleton axes expected by the VQ decoder.
    indices = indices.unsqueeze(1).unsqueeze(-1)
    mel_lengths = indices.shape[2] * (
        model.downsample.total_strides if model.downsample is not None else 1
    )
    mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
    mel_masks = torch.ones(
        (1, 1, mel_lengths), device=model.device, dtype=torch.float32
    )

    text_features = model.vq_encoder.decode(indices)

    logger.info(
        f"VQ Encoded, indices: {indices.shape} equivalent to "
        + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
    )

    # Upsample decoded features back to mel-frame resolution.
    text_features = F.interpolate(text_features, size=mel_lengths[0], mode="nearest")

    # Sample mels
    decoded_mels = model.decoder(text_features, mel_masks)
    fake_audios = model.generator(decoded_mels)
    logger.info(
        f"Generated audio of shape {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
    )

    # Save audio to the user-requested path.
    # BUGFIX: previously hard-coded "fake.wav", silently ignoring --output-path.
    fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
    sf.write(output_path, fake_audio, model.sampling_rate)
    logger.info(f"Saved audio to {output_path}")


if __name__ == "__main__":
    main()

+ 0 - 70
tools/vqgan/migrate_from_vits.py

@@ -1,70 +0,0 @@
-import hydra
-import torch
-from loguru import logger
-from omegaconf import DictConfig, OmegaConf
-
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-
-
-@hydra.main(
-    version_base="1.3",
-    config_path="../../fish_speech/configs",
-    config_name="hubert_vq.yaml",
-)
-def main(cfg: DictConfig):
-    generator_ckpt = cfg.get(
-        "generator_ckpt", "results/hubert-vq-pretrain/rcell/G_23000.pth"
-    )
-    discriminator_ckpt = cfg.get(
-        "discriminator_ckpt", "results/hubert-vq-pretrain/rcell/D_23000.pth"
-    )
-    model = hydra.utils.instantiate(cfg.model)
-
-    # Generator
-    logger.info(f"Model loaded, restoring from {generator_ckpt}")
-    generator_weights = torch.load(generator_ckpt, map_location="cpu")["model"]
-
-    # Decoder
-    generator_state = {
-        k[4:]: v
-        for k, v in generator_weights.items()
-        if k.startswith("dec.") and not k.startswith("dec.cond.")
-    }
-    logger.info(f"Found {len(generator_state)} HiFiGAN weights, restoring...")
-    r = model.generator.dec.load_state_dict(generator_state, strict=False)
-    logger.info(f"Generator weights restored. {r}")
-
-    # Posterior Encoder
-    # encoder_state = {
-    #     k[6:]: v
-    #     for k, v in generator_weights.items()
-    #     if k.startswith("enc_q.") and not k.startswith("enc_q.proj.")
-    # }
-    # logger.info(f"Found {len(encoder_state)} posterior encoder weights, restoring...")
-    # x = model.generator.enc_q.load_state_dict(encoder_state, strict=False)
-    # logger.info(f"Posterior encoder weights restored. {x}")
-
-    # Flow
-    # flow_state = {
-    #     k[5:]: v for k, v in generator_weights.items() if k.startswith("flow.")
-    # }
-    # logger.info(f"Found {len(flow_state)} flow weights, restoring...")
-    # model.flow.load_state_dict(flow_state, strict=True)
-    # logger.info("Flow weights restored.")
-
-    # Discriminator
-    logger.info(f"Model loaded, restoring from {discriminator_ckpt}")
-    discriminator_weights = torch.load(discriminator_ckpt, map_location="cpu")["model"]
-    logger.info(
-        f"Found {len(discriminator_weights)} discriminator weights, restoring..."
-    )
-    model.discriminator.load_state_dict(discriminator_weights, strict=True)
-    logger.info("Discriminator weights restored.")
-
-    torch.save(model.state_dict(), cfg.ckpt_path)
-    logger.info("Done")
-
-
-if __name__ == "__main__":
-    main()