build_dataset.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import os
  2. import re
  3. from collections import defaultdict
  4. from multiprocessing import Pool
  5. from pathlib import Path
  6. import click
  7. import numpy as np
  8. import yaml
  9. from loguru import logger
  10. from tqdm import tqdm
  11. from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
  12. from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
  13. from fish_speech.text import g2p
  14. from fish_speech.utils.file import load_filelist
  15. # To avoid CPU overload
  16. os.environ["MKL_NUM_THREADS"] = "1"
  17. os.environ["OMP_NUM_THREADS"] = "1"
  18. def task_generator_yaml(config):
  19. with open(config, "r") as f:
  20. config = yaml.load(f, Loader=yaml.FullLoader)
  21. for row in config["datasets"]:
  22. root, source, languages, extension, parent_level = (
  23. row["root"],
  24. row["source"],
  25. row["languages"],
  26. row["extension"],
  27. row["group_parent_level"],
  28. )
  29. if isinstance(parent_level, int):
  30. parent_level = [parent_level]
  31. # Load the files
  32. files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
  33. files = sorted(files)
  34. grouped_files = defaultdict(list)
  35. for file in tqdm(files, desc=f"Grouping {root}"):
  36. all_parents = []
  37. pointer = file
  38. while pointer.parent.name:
  39. all_parents.append(pointer.parent.name)
  40. pointer = pointer.parent
  41. ps = []
  42. for level in parent_level:
  43. ps.append(all_parents[level - 1])
  44. p = "-".join(ps)
  45. grouped_files[p].append(file)
  46. logger.info(f"Found {len(grouped_files)} groups in {root}")
  47. for name, subset in grouped_files.items():
  48. yield name, subset, source, languages, extension
  49. def task_generator_filelist(filelist):
  50. grouped_files = defaultdict(list)
  51. for filename, speaker, languages, text in load_filelist(filelist):
  52. grouped_files[speaker].append((Path(filename), text, languages))
  53. logger.info(f"Found {len(grouped_files)} groups in {filelist}")
  54. for speaker, values in grouped_files.items():
  55. yield speaker, values, "filelist", languages, None
  56. def run_task(task):
  57. name, subset, source, languages, extension = task
  58. # Parse the files
  59. sentences = []
  60. for file in subset:
  61. if isinstance(file, tuple):
  62. file, text, languages = file
  63. else:
  64. text = None
  65. np_file = file.with_suffix(".npy")
  66. if np_file.exists() is False:
  67. logger.warning(f"Can't find {np_file}")
  68. continue
  69. if text is None:
  70. txt_file = file.with_suffix(extension)
  71. if txt_file.exists() is False:
  72. logger.warning(f"Can't find {txt_file}")
  73. continue
  74. with open(txt_file, "r") as f:
  75. text = f.read().strip()
  76. # Simple cleaning: replace { xxx } and < xxx > with space
  77. text = re.sub(r"\{.*?\}", " ", text)
  78. text = re.sub(r"<.*?>", " ", text)
  79. text = re.sub(r"\s+", " ", text)
  80. try:
  81. phones = [v for _, v in g2p(text, order=languages)]
  82. semantics = np.load(np_file)
  83. except Exception as e:
  84. logger.error(f"Failed to parse {file}: {e}")
  85. continue
  86. if isinstance(semantics, np.ndarray):
  87. semantics = semantics.tolist()
  88. sentences.append(
  89. Sentence(
  90. text=text,
  91. phones=phones,
  92. semantics=[Semantics(values=s) for s in semantics],
  93. )
  94. )
  95. # Pack the sentences
  96. return pack_pb_stream(
  97. TextData(
  98. source=source,
  99. name=name,
  100. languages=languages,
  101. sentences=sentences,
  102. )
  103. )
  104. @click.command()
  105. @click.option(
  106. "--config", type=click.Path(), default="fish_speech/configs/data/finetune.yaml"
  107. )
  108. @click.option("--output", type=click.Path(), default="data/quantized-dataset-ft.protos")
  109. @click.option("--filelist", type=click.Path(), default=None)
  110. @click.option("--num-workers", type=int, default=16)
  111. def main(config, output, filelist, num_workers):
  112. dataset_fp = open(output, "wb")
  113. generator_fn = (
  114. task_generator_yaml(config)
  115. if filelist is None
  116. else task_generator_filelist(filelist)
  117. )
  118. with Pool(num_workers) as p:
  119. for result in tqdm(p.imap_unordered(run_task, generator_fn)):
  120. dataset_fp.write(result)
  121. dataset_fp.close()
  122. if __name__ == "__main__":
  123. main()