| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- import re
- from collections import defaultdict
- from multiprocessing import Pool
- import click
- import numpy as np
- import yaml
- from loguru import logger
- from tqdm import tqdm
- from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
- from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
- from fish_speech.text import g2p
- from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
- def task_generator(config):
- with open(config, "r") as f:
- config = yaml.load(f, Loader=yaml.FullLoader)
- for row in config["datasets"]:
- root, source, languages, extension, parent_level = (
- row["root"],
- row["source"],
- row["languages"],
- row["extension"],
- row["group_parent_level"],
- )
- # Load the files
- files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
- grouped_files = defaultdict(list)
- for file in files:
- if parent_level == 1:
- p = file.parent.name
- elif parent_level == 2:
- p = file.parent.parent.name
- else:
- raise ValueError(f"Invalid parent level {parent_level}")
- grouped_files[p].append(file)
- logger.info(f"Found {len(grouped_files)} groups in {root}")
- for name, subset in grouped_files.items():
- yield name, subset, source, languages, extension
- def run_task(task):
- name, subset, source, languages, extension = task
- # Parse the files
- sentences = []
- for file in subset:
- np_file = file.with_suffix(".npy")
- txt_file = file.with_suffix(extension)
- if np_file.exists() is False or txt_file.exists() is False:
- logger.warning(f"Can't find {np_file} or {txt_file}")
- continue
- with open(txt_file, "r") as f:
- text = f.read().strip()
- # Simple cleaning: replace { xxx } and < xxx > with space
- text = re.sub(r"\{.*?\}", " ", text)
- text = re.sub(r"<.*?>", " ", text)
- text = re.sub(r"\s+", " ", text)
- try:
- phones = [v for _, v in g2p(text, order=languages)]
- semantics = np.load(np_file)
- except Exception as e:
- logger.error(f"Failed to parse {file}: {e}")
- continue
- if isinstance(semantics, np.ndarray):
- semantics = semantics.tolist()
- sentences.append(
- Sentence(
- text=text,
- phones=phones,
- semantics=[Semantics(values=s) for s in semantics],
- )
- )
- # Pack the sentences
- return pack_pb_stream(
- TextData(
- source=source,
- name=name,
- languages=languages,
- sentences=sentences,
- )
- )
- @click.command()
- @click.option(
- "--config", type=click.Path(), default="fish_speech/configs/data/finetune.yaml"
- )
- @click.option("--output", type=click.Path(), default="data/quantized-dataset-ft.protos")
- def main(config, output):
- dataset_fp = open(output, "wb")
- with Pool(16) as p:
- for result in tqdm(p.imap_unordered(run_task, task_generator(config))):
- dataset_fp.write(result)
- dataset_fp.close()
- if __name__ == "__main__":
- main()
|