Просмотр исходного кода

Update extract vq & text parser

Lengyue 2 года назад
Родитель
Commit
a6e9f6871b

+ 1 - 1
fish_speech/configs/vqgan.yaml

@@ -51,7 +51,7 @@ model:
   freeze_hifigan: false
   freeze_hifigan: false
 
 
   downsample:
   downsample:
-    _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
+    _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
     dims: ["${num_mels}", 512, 256]
     dims: ["${num_mels}", 512, 256]
     kernel_sizes: [3, 3]
     kernel_sizes: [3, 3]
     strides: [2, 2]
     strides: [2, 2]

+ 15 - 0
fish_speech/models/text2semantic/modules.py

@@ -537,12 +537,22 @@ class FishSpeechTransformer(nn.Module):
         **sampling_kwargs,
         **sampling_kwargs,
     ):
     ):
         new_tokens, new_probs = [], []
         new_tokens, new_probs = [], []
+        # Sliding context window
+        batch_size = 1
+        back_map = torch.zeros(
+            [batch_size, 1], device=cur_token.device, dtype=torch.long
+        )
 
 
         for i in range(num_new_tokens):
         for i in range(num_new_tokens):
             next_token, next_prob = self.sample_decoder(
             next_token, next_prob = self.sample_decoder(
                 cur_token, context, input_pos, **sampling_kwargs
                 cur_token, context, input_pos, **sampling_kwargs
             )
             )
 
 
+            # index_map = torch.arange(6, device=cur_token.device)
+            # index_map = back_map[:, -1:] + index_map.repeat(batch_size, 1)
+            # add = torch.arange(batch_size, device=index_map.device).unsqueeze(1) #N, 1
+            # index_map = index_map + add * t_length
+
             input_pos += 1
             input_pos += 1
             new_tokens.append(next_token.clone())
             new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
             callback(new_tokens[-1])
@@ -555,6 +565,11 @@ class FishSpeechTransformer(nn.Module):
 
 
         return new_tokens, new_probs
         return new_tokens, new_probs
 
 
+    def compile(self):
+        self.sampler_decoder = torch.compile(
+            self.sample_decoder, mode="reduce-overhead", fullgraph=True
+        )
+
     @torch.no_grad()
     @torch.no_grad()
     def inference(self, inputs, prompt=None, max_new_tokens=1024, **sampling_kwargs):
     def inference(self, inputs, prompt=None, max_new_tokens=1024, **sampling_kwargs):
         # inputs: (B, T)
         # inputs: (B, T)

+ 7 - 0
fish_speech/models/vqgan/modules/encoders.py

@@ -1,6 +1,7 @@
 from math import log2
 from math import log2
 from typing import Optional
 from typing import Optional
 
 
+import numpy as np
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.functional as F
@@ -335,3 +336,9 @@ class VQEncoder(nn.Module):
         x = x[:, :, :x_len]
         x = x[:, :, :x_len]
 
 
         return x, indices, loss
         return x, indices, loss
+
+    def decode(self, indices):
+        """Decode quantized codebook indices back into feature space.
+
+        Args:
+            indices: VQ codebook indices as produced by the forward pass.
+
+        Returns:
+            Decoded feature tensor from the output convolution.
+        """
+        # Look up codebook vectors, then .mT swaps the last two dims to
+        # channel-first layout expected by conv_out.
+        q = self.vq.get_output_from_indices(indices).mT
+        x = self.conv_out(q)
+
+        return x

+ 5 - 0
fish_speech/text/japanese.py

@@ -51,6 +51,7 @@ def g2p(text):
     text = symbols_to_japanese(text)
     text = symbols_to_japanese(text)
     sentences = re.split(_japanese_marks, text)
     sentences = re.split(_japanese_marks, text)
     marks = re.findall(_japanese_marks, text)
     marks = re.findall(_japanese_marks, text)
+    ct = text
     text = []
     text = []
     for i, sentence in enumerate(sentences):
     for i, sentence in enumerate(sentences):
         if re.match(_japanese_characters, sentence):
         if re.match(_japanese_characters, sentence):
@@ -60,4 +61,8 @@ def g2p(text):
         if i < len(marks):
         if i < len(marks):
             text += [marks[i].replace(" ", "")]
             text += [marks[i].replace(" ", "")]
 
 
+    # Clean empty strings
+    text = [t for t in text if t.strip() != ""]
+    text = ["-" if t == "pau" else t for t in text]
+
     return text
     return text

+ 5 - 1
fish_speech/text/parser.py

@@ -72,6 +72,7 @@ SYMBOLS_MAPPING = {
     "—": "-",
     "—": "-",
     "~": "-",
     "~": "-",
     "~": "-",
     "~": "-",
+    "・": "-",
     "「": "'",
     "「": "'",
     "」": "'",
     "」": "'",
     ";": ",",
     ";": ",",
@@ -196,6 +197,9 @@ def segments_to_phones(
 
 
     for segment in segments:
     for segment in segments:
         for phone in segment.phones:
         for phone in segment.phones:
+            if phone.strip() == "":
+                continue
+
             q0 = (segment.language, phone)
             q0 = (segment.language, phone)
             q1 = (None, phone)
             q1 = (None, phone)
 
 
@@ -206,7 +210,7 @@ def segments_to_phones(
                 phones.append(q1)
                 phones.append(q1)
                 ids.append(symbols_to_id[q1])
                 ids.append(symbols_to_id[q1])
             else:
             else:
-                raise ValueError(f"Unknown phone: {segment.language} {phone}")
+                raise ValueError(f"Unknown phone: {segment.language} - {phone} -")
 
 
     return phones, ids
     return phones, ids
 
 

+ 24 - 5
tools/infer_vq.py

@@ -3,6 +3,7 @@ import numpy as np
 import soundfile as sf
 import soundfile as sf
 import torch
 import torch
 import torch.nn.functional as F
 import torch.nn.functional as F
+from einops import rearrange
 from hydra import compose, initialize
 from hydra import compose, initialize
 from hydra.utils import instantiate
 from hydra.utils import instantiate
 from lightning import LightningModule
 from lightning import LightningModule
@@ -23,7 +24,7 @@ def main():
 
 
     model: LightningModule = instantiate(cfg.model)
     model: LightningModule = instantiate(cfg.model)
     state_dict = torch.load(
     state_dict = torch.load(
-        "results/vqgan/checkpoints/step_000110000.ckpt",
+        "checkpoints/vqgan/step_000380000.ckpt",
         map_location=model.device,
         map_location=model.device,
     )["state_dict"]
     )["state_dict"]
     model.load_state_dict(state_dict, strict=True)
     model.load_state_dict(state_dict, strict=True)
@@ -32,7 +33,7 @@ def main():
     logger.info("Restored model from checkpoint")
     logger.info("Restored model from checkpoint")
 
 
     # Load audio
     # Load audio
-    audio = librosa.load("0.wav", sr=model.sampling_rate, mono=True)[0]
+    audio = librosa.load("test.wav", sr=model.sampling_rate, mono=True)[0]
     audios = torch.from_numpy(audio).to(model.device)[None, None, :]
     audios = torch.from_numpy(audio).to(model.device)[None, None, :]
     logger.info(
     logger.info(
         f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
         f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
@@ -64,14 +65,32 @@ def main():
 
 
     # vq_features is 50 hz, need to convert to true mel size
     # vq_features is 50 hz, need to convert to true mel size
     text_features = model.mel_encoder(features, feature_masks)
     text_features = model.mel_encoder(features, feature_masks)
-    text_features, indices, _ = model.vq_encoder(text_features, feature_masks)
+    _, indices, _ = model.vq_encoder(text_features, feature_masks)
+    print(indices.shape)
 
 
+    # Restore
+    indices = np.load(
+        "data/LibriTTS_R/train-clean-100/7226/86964/7226_86964_000012_000003.npy"
+    )
+    indices = torch.from_numpy(indices).to(model.device)
+    indices = indices.unsqueeze(1).unsqueeze(-1)
+    mel_lengths = indices.shape[2] * (
+        model.downsample.total_strides if model.downsample is not None else 1
+    )
+    mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
+    mel_masks = torch.ones(
+        (1, 1, mel_lengths), device=model.device, dtype=torch.float32
+    )
+
+    print(mel_lengths)
+
+    text_features = model.vq_encoder.decode(indices)
     logger.info(
     logger.info(
         f"VQ Encoded, indices: {indices.shape} equivalent to "
         f"VQ Encoded, indices: {indices.shape} equivalent to "
-        + f"{1/(audios.shape[2] / model.sampling_rate / indices.shape[2]):.2f} Hz"
+        + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
     )
     )
 
 
-    text_features = F.interpolate(text_features, size=gt_mels.shape[2], mode="nearest")
+    text_features = F.interpolate(text_features, size=mel_lengths[0], mode="nearest")
 
 
     # Sample mels
     # Sample mels
     decoded_mels = model.decoder(text_features, mel_masks)
     decoded_mels = model.decoder(text_features, mel_masks)

+ 111 - 0
tools/llama/build_dataset.py

@@ -0,0 +1,111 @@
+import json
+import re
+from collections import defaultdict
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.text import g2p
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+# Define datasets
+# Each entry: (root_dir, source_name, language_order_for_g2p, transcript_extension,
+#              parent_level — how many directories above the file identify its group)
+DATASETS = [
+    ("data/StarRail/Chinese", "StarRail", ["ZH", "EN"], ".lab", 1),
+    ("data/StarRail/English", "StarRail", ["EN"], ".lab", 1),
+    ("data/StarRail/Japanese", "StarRail", ["JP", "EN"], ".lab", 1),
+    ("data/Genshin/Chinese", "Genshin", ["ZH", "EN"], ".lab", 1),
+    ("data/Genshin/English", "Genshin", ["EN"], ".lab", 1),
+    ("data/Genshin/Japanese", "Genshin", ["JP", "EN"], ".lab", 1),
+    ("data/LibriTTS_R", "LibriTTS_R", ["EN"], ".normalized.txt", 2),
+    ("data/WenetSpeech", "WenetSpeech", ["ZH", "EN"], ".txt", 1),
+]
+
+
+@dataclass
+class Sentence:
+    """One utterance: raw text, its phoneme sequence, and its semantic tokens."""
+
+    text: str
+    phones: list[str]
+    # Support multiple codebooks
+    semantics: Union[list[int], list[list[int]]]
+
+
+@dataclass
+class PackedSentences:
+    """All sentences of one group (e.g. a speaker) from one dataset source."""
+
+    source: str
+    name: str
+    languages: list[str]
+    sentences: list[Sentence]
+
+
+dataset_fp = open("data/quantized-dataset-1205.json", "w")
+
+for root, source, languages, extension, parent_level in DATASETS:
+    # Load the files
+    exts = extension.split(".")
+    files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
+    logger.info(f"Found {len(files)} files in {root}")
+
+    grouped_files = defaultdict(list)
+    for file in files:
+        if parent_level == 1:
+            p = file.parent.name
+        elif parent_level == 2:
+            p = file.parent.parent.name
+        else:
+            raise ValueError(f"Invalid parent level {parent_level}")
+
+        grouped_files[p].append(file)
+
+    for name, subset in tqdm(grouped_files.items()):
+        # Parse the files
+        sentences = []
+        for file in subset:
+            np_file = file.with_suffix(".npy")
+            txt_file = file.with_suffix(extension)
+            if np_file.exists() is False or txt_file.exists() is False:
+                continue
+
+            with open(txt_file, "r") as f:
+                text = f.read().strip()
+
+            # Simple cleaning: replace { xxx } and < xxx > with space
+            text = re.sub(r"\{.*?\}", " ", text)
+            text = re.sub(r"<.*?>", " ", text)
+            text = re.sub(r"\s+", " ", text)
+
+            try:
+                phones = [v for _, v in g2p(text, order=languages)]
+                semantics = np.load(np_file)
+            except Exception as e:
+                logger.error(f"Failed to parse {file}: {e}")
+                continue
+
+            if isinstance(semantics, np.ndarray):
+                semantics = semantics.tolist()
+
+            sentences.append(
+                Sentence(
+                    text=text,
+                    phones=phones,
+                    semantics=semantics,
+                )
+            )
+
+        # Pack the sentences
+        packed_sentences = PackedSentences(
+            source=source,
+            name=name,
+            languages=languages,
+            sentences=sentences,
+        )
+
+        dataset_fp.write(
+            json.dumps(asdict(packed_sentences), ensure_ascii=False) + "\n"
+        )
+
+
+dataset_fp.close()

+ 200 - 0
tools/vqgan/extract_vq.py

@@ -0,0 +1,200 @@
+import os
+import subprocess as sp
+import sys
+import time
+from datetime import timedelta
+from functools import lru_cache
+from pathlib import Path
+from random import Random
+
+import click
+import numpy as np
+import torch
+import torchaudio
+from einops import rearrange
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from lightning import LightningModule
+from loguru import logger
+from omegaconf import OmegaConf
+
+from fish_speech.models.vqgan.utils import sequence_mask
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+# This file extracts VQ (vector-quantized) indices from audio files and saves
+# them as .npy next to each source file — it is used to generate training data.
+# NOTE(review): the original comment mentioned Whisper/text conversion, which
+# appears copy-pasted from a different tool; nothing here calls Whisper.
+
+
+# Distributed rank/world-size: taken from SLURM env vars when running under
+# SLURM, otherwise defaulting to a single worker (overridden by the launcher
+# in main() when --num-workers > 1).
+RANK = int(os.environ.get("SLURM_PROCID", 0))
+WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
+
+# Loguru format including the worker rank so interleaved logs are attributable.
+logger_format = (
+    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
+    "<level>{level: <8}</level> | "
+    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
+    "{extra[rank]} - <level>{message}</level>"
+)
+logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
+logger.remove()
+logger.add(sys.stderr, format=logger_format)
+
+
+@lru_cache(maxsize=1)
+def get_model():
+    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
+        cfg = compose(config_name="vqgan")
+
+    model: LightningModule = instantiate(cfg.model)
+    state_dict = torch.load(
+        "checkpoints/vqgan/step_000380000.ckpt",
+        map_location=model.device,
+    )["state_dict"]
+    model.load_state_dict(state_dict, strict=True)
+    model.eval()
+    model.cuda()
+    logger.info("Restored model from checkpoint")
+
+    logger.info(f"Loaded model")
+    return model
+
+
+def process_batch(files: list[Path]) -> float:
+    """Extract VQ indices for a batch of audio files and save each as .npy.
+
+    Args:
+        files: audio file paths; the result for each is written next to the
+            source file with its suffix replaced by ``.npy``.
+
+    Returns:
+        Total audio duration of the batch in seconds.
+    """
+    model = get_model()
+
+    wavs = []
+    audio_lengths = []
+    max_length = total_time = 0
+
+    for file in files:
+        wav, sr = torchaudio.load(file)
+        # Downmix multi-channel audio to mono before resampling.
+        if wav.shape[0] > 1:
+            wav = wav.mean(dim=0, keepdim=True)
+
+        wav = torchaudio.functional.resample(wav.cuda(), sr, model.sampling_rate)[0]
+        wavs.append(wav)
+        total_time += len(wav) / model.sampling_rate
+        max_length = max(max_length, len(wav))
+        audio_lengths.append(len(wav))
+
+    # Pad to max length
+    for i, wav in enumerate(wavs):
+        wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
+
+    audios = torch.stack(wavs, dim=0)[:, None]
+    audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
+
+    # Calculate lengths
+    with torch.no_grad():
+        # VQ Encoder
+        features = gt_mels = model.mel_transform(
+            audios, sample_rate=model.sampling_rate
+        )
+
+        if model.downsample is not None:
+            features = model.downsample(features)
+
+        # Valid frames per sample: samples / hop, further divided by the
+        # downsampler's total stride when a downsampler is configured.
+        feature_lengths = (
+            audio_lengths
+            / model.hop_length
+            / (model.downsample.total_strides if model.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+
+        # vq_features is 50 hz, need to convert to true mel size
+        text_features = model.mel_encoder(features, feature_masks)
+        _, indices, _ = model.vq_encoder(text_features, feature_masks)
+        indices = indices.squeeze(-1)
+        indices = rearrange(indices, "c b t -> b c t")
+
+    # Save to disk
+    outputs = indices.cpu().numpy()
+
+    for file, length, feature, audio in zip(files, feature_lengths, outputs, audios):
+        # Trim padding; per-file array is (num_codebooks, T) after the
+        # rearrange above (the original "(T,)" comment was stale).
+        feature = feature[:, :length]
+
+        with open(file.with_suffix(".npy"), "wb") as f:
+            np.save(f, feature)
+
+    return total_time
+
+
+@click.command()
+@click.argument("folder")
+@click.option("--num-workers", default=1)
+def main(folder: str, num_workers: int):
+    if num_workers > 1 and WORLD_SIZE != num_workers:
+        assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
+
+        logger.info(f"Spawning {num_workers} workers")
+
+        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        if visible_devices is None:
+            visible_devices = list(range(torch.cuda.device_count()))
+        else:
+            visible_devices = visible_devices.split(",")
+
+        processes = []
+        for i in range(num_workers):
+            env = os.environ.copy()
+            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
+            env["SLURM_PROCID"] = str(i)
+            env["SLURM_NTASKS"] = str(num_workers)
+
+            processes.append(
+                sp.Popen(
+                    [sys.executable] + sys.argv.copy(),
+                    env=env,
+                )
+            )
+
+        for p in processes:
+            p.wait()
+
+        logger.info(f"All workers finished")
+        return
+
+    # This is a worker
+    logger.info(f"Starting worker")
+    files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
+    Random(42).shuffle(files)
+
+    total_files = len(files)
+    files = files[RANK::WORLD_SIZE]
+    logger.info(f"Processing {len(files)}/{total_files} files")
+
+    # Batch size 64
+    total_time = 0
+    begin_time = time.time()
+    processed_files = 0
+
+    for n_batch, idx in enumerate(range(0, len(files), 32)):
+        batch = files[idx : idx + 32]
+        batch_time = process_batch(batch)
+
+        total_time += batch_time
+        processed_files += len(batch)
+
+        if (n_batch + 1) % 10 == 0:
+            eta = (
+                (time.time() - begin_time)
+                / processed_files
+                * (len(files) - processed_files)
+            )
+            logger.info(
+                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
+                + f"ETA: {timedelta(seconds=round(eta))}s"
+            )
+
+    logger.info(
+        f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
+    )
+
+
+if __name__ == "__main__":
+    main()