build_dataset.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import itertools
  2. import os
  3. import re
  4. from collections import defaultdict
  5. from functools import partial
  6. from multiprocessing import Pool
  7. from pathlib import Path
  8. import click
  9. import numpy as np
  10. from loguru import logger
  11. from tqdm import tqdm
  12. from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
  13. from fish_speech.datasets.protos.text_data_stream import load_filelist, pack_pb_stream
  14. # To avoid CPU overload
  15. os.environ["MKL_NUM_THREADS"] = "1"
  16. os.environ["OMP_NUM_THREADS"] = "1"
  17. def task_generator_folder(root: Path, text_extension: str):
  18. files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
  19. files = sorted(files)
  20. grouped_files = defaultdict(list)
  21. for file in tqdm(files, desc=f"Grouping {root}"):
  22. p = str(file.parent)
  23. speaker = file.parent.name
  24. try:
  25. if isinstance(text_extension, str):
  26. texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
  27. else:
  28. texts = [
  29. file.with_suffix(ext).read_text(encoding="utf-8")
  30. for ext in text_extension
  31. ]
  32. except Exception as e:
  33. logger.error(f"Failed to read text {file}: {e}")
  34. continue
  35. grouped_files[p].append((speaker, file, texts))
  36. logger.info(
  37. f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
  38. )
  39. for i in grouped_files.values():
  40. subset = [(f, t) for _, f, t in i]
  41. yield i[0][0], subset, "folder"
  42. def task_generator_filelist(filelist):
  43. grouped_files = defaultdict(list)
  44. for filename, speaker, _, text in load_filelist(filelist):
  45. grouped_files[speaker].append((Path(filename), [text]))
  46. logger.info(f"Found {len(grouped_files)} groups in {filelist}")
  47. for speaker, values in grouped_files.items():
  48. yield speaker, values, "filelist"
  49. def run_task(task):
  50. name, subset, source = task
  51. # Parse the files
  52. sentences = []
  53. for file, texts in subset:
  54. np_file = file.with_suffix(".npy")
  55. if np_file.exists() is False:
  56. logger.warning(f"Can't find {np_file}")
  57. continue
  58. new_texts = []
  59. for text in texts:
  60. # Simple cleaning: replace { xxx } and < xxx > with space
  61. text = re.sub(r"\{.*?\}", " ", text)
  62. text = re.sub(r"<.*?>", " ", text)
  63. text = re.sub(r"\s+", " ", text)
  64. new_texts.append(text)
  65. try:
  66. semantics = np.load(np_file)
  67. except Exception as e:
  68. logger.error(f"Failed to parse {file}: {e}")
  69. continue
  70. if isinstance(semantics, np.ndarray):
  71. semantics = semantics.tolist()
  72. sentences.append(
  73. Sentence(
  74. texts=new_texts,
  75. semantics=[Semantics(values=s) for s in semantics],
  76. )
  77. )
  78. # Pack the sentences
  79. return pack_pb_stream(
  80. TextData(
  81. source=source,
  82. name=name,
  83. sentences=sentences,
  84. )
  85. )
  86. @click.command()
  87. @click.option(
  88. "--input",
  89. type=click.Path(path_type=Path),
  90. required=True,
  91. help="A folder containing the dataset or a filelist",
  92. multiple=True,
  93. )
  94. @click.option(
  95. "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
  96. )
  97. @click.option("--num-workers", type=int, default=16)
  98. @click.option("--text-extension", type=str, default=[".txt"], multiple=True)
  99. @click.option(
  100. "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
  101. )
  102. def main(input, output, num_workers, text_extension, shard_size):
  103. generator_fns = []
  104. for f in input:
  105. assert f.exists(), f"{f} not found"
  106. if f.is_dir():
  107. generator_fn = task_generator_folder(f, text_extension)
  108. else:
  109. generator_fn = task_generator_filelist(f)
  110. generator_fns.append(generator_fn)
  111. generator_fn = itertools.chain(*generator_fns)
  112. output.mkdir(parents=True, exist_ok=True)
  113. dataset_fp = None
  114. tar_idx = 0
  115. written_size = 0
  116. with Pool(num_workers) as p:
  117. for result in tqdm(p.imap_unordered(run_task, generator_fn)):
  118. if dataset_fp is None:
  119. dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
  120. dataset_fp.write(result)
  121. written_size += len(result)
  122. if written_size > shard_size * 1024 * 1024:
  123. logger.info(f"Finished writing {tar_idx} shards to {output}")
  124. dataset_fp.close()
  125. dataset_fp = None
  126. written_size = 0
  127. tar_idx += 1
  128. if dataset_fp is not None:
  129. dataset_fp.close()
  130. logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
  131. if __name__ == "__main__":
  132. main()