import re
from collections import defaultdict
from multiprocessing import Pool

import numpy as np
from loguru import logger
from tqdm import tqdm

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.text import g2p
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
# Define datasets to process.
# Each entry is (root, source, languages, extension, parent_level):
#   root         - directory scanned recursively for audio files
#   source       - source label stored in the resulting TextData proto
#   languages    - language order passed to g2p (order=...) for phonemization
#   extension    - suffix of the transcript file that sits next to each audio file
#   parent_level - how many directory levels above the file provide the group name
#                  (1 = parent dir, 2 = grandparent dir)
DATASETS = [
    ("data/StarRail/Chinese", "StarRail", ["ZH", "EN"], ".lab", 1),
    ("data/StarRail/English", "StarRail", ["EN"], ".lab", 1),
    ("data/StarRail/Japanese", "StarRail", ["JP", "EN"], ".lab", 1),
    ("data/Genshin/Chinese", "Genshin", ["ZH", "EN"], ".lab", 1),
    ("data/Genshin/English", "Genshin", ["EN"], ".lab", 1),
    ("data/Genshin/Japanese", "Genshin", ["JP", "EN"], ".lab", 1),
    ("data/LibriTTS_R", "LibriTTS_R", ["EN"], ".normalized.txt", 2),
    ("data/WenetSpeech", "WenetSpeech", ["ZH", "EN"], ".txt", 1),
]
def task_generator():
    """Yield one task per file group across every configured dataset.

    Yields:
        Tuples of (group_name, files, source, languages, extension),
        where files are the audio files sharing the same (grand)parent
        directory name under the dataset root.
    """
    for root, source, languages, extension, parent_level in DATASETS:
        # Collect every audio file under this dataset root.
        audio_files = list_files(root, AUDIO_EXTENSIONS, recursive=True)

        # Bucket files by the name of the directory `parent_level` levels up.
        groups = defaultdict(list)
        for audio_file in audio_files:
            if parent_level == 1:
                group_name = audio_file.parent.name
            elif parent_level == 2:
                group_name = audio_file.parent.parent.name
            else:
                raise ValueError(f"Invalid parent level {parent_level}")
            groups[group_name].append(audio_file)

        logger.info(f"Found {len(groups)} groups in {root}")

        for group_name, group_files in groups.items():
            yield group_name, group_files, source, languages, extension
def run_task(task):
    """Parse one group of files and pack them into a protobuf stream chunk.

    Args:
        task: Tuple of (name, subset, source, languages, extension) as
            produced by ``task_generator``.

    Returns:
        Serialized bytes of a ``TextData`` proto containing every
        successfully parsed sentence in the group.
    """
    name, subset, source, languages, extension = task

    # Parse the files
    sentences = []
    for file in subset:
        np_file = file.with_suffix(".npy")
        txt_file = file.with_suffix(extension)

        # Skip samples missing either the quantized codes or the transcript.
        if not np_file.exists() or not txt_file.exists():
            continue

        # Explicit UTF-8: transcripts include CJK text, so the
        # locale-dependent default encoding is not safe here.
        with open(txt_file, "r", encoding="utf-8") as f:
            text = f.read().strip()

        # Simple cleaning: replace { xxx } and < xxx > with space
        text = re.sub(r"\{.*?\}", " ", text)
        text = re.sub(r"<.*?>", " ", text)
        text = re.sub(r"\s+", " ", text)

        try:
            phones = [v for _, v in g2p(text, order=languages)]
            semantics = np.load(np_file)
        except Exception as e:
            # Best-effort: log and drop the bad sample instead of aborting the group.
            logger.error(f"Failed to parse {file}: {e}")
            continue

        # np.load yields an ndarray; the proto wants plain Python lists.
        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                text=text,
                phones=phones,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack the sentences
    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            languages=languages,
            sentences=sentences,
        )
    )
def main():
    """Run all packing tasks in parallel and write one protobuf stream file."""
    # Context manager guarantees the output file is flushed and closed
    # even if a worker task raises (the original left it open on error).
    with open("data/quantized-dataset-1205.protos", "wb") as dataset_fp:
        with Pool(16) as p:
            # imap_unordered: write chunks as soon as any worker finishes;
            # each chunk is self-contained, so output order does not matter.
            for result in tqdm(p.imap_unordered(run_task, task_generator())):
                dataset_fp.write(result)


if __name__ == "__main__":
    main()