build_dataset.py

import itertools
import os
import re
from collections import defaultdict
from multiprocessing import Pool
from pathlib import Path

import click
import numpy as np
from loguru import logger
from tqdm import tqdm

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from tools.file import load_filelist

# Limit BLAS threads so the worker pool below does not oversubscribe the CPU.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"


def task_generator_folder(root: Path, text_extension: str | tuple[str, ...]):
    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
    files = sorted(files)

    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        p = str(file.parent)
        speaker = file.parent.name

        try:
            if isinstance(text_extension, str):
                texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
            else:
                texts = [
                    file.with_suffix(ext).read_text(encoding="utf-8")
                    for ext in text_extension
                ]
        except Exception as e:
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[p].append((speaker, file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for group in grouped_files.values():
        # All entries in a group share the same parent folder, hence the same speaker.
        subset = [(f, t) for _, f, t in group]
        yield group[0][0], subset, "folder"
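

# Expected on-disk layout for --input folders consumed by the generator above
# (inferred from the globbing: one subfolder per speaker, each .npy of
# quantized semantic tokens paired with a transcript sharing its stem):
#   root/
#       speaker_a/
#           utt_0001.npy
#           utt_0001.txt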


def task_generator_filelist(filelist):
    grouped_files = defaultdict(list)
    for filename, speaker, _, text in load_filelist(filelist):
        grouped_files[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
    for speaker, values in grouped_files.items():
        yield speaker, values, "filelist"
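

# The filelist is whatever format `load_filelist` in tools/file parses;
# judging by the unpacking above it yields four columns per line, of which
# the third (presumably language) is unused here. Illustrative line only:
#   path/to/utt_0001.wav|speaker_a|EN|Hello world.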


def run_task(task):
    name, subset, source = task

    # Parse the files
    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        new_texts = []
        for text in texts:
            # Simple cleaning: replace { xxx } and < xxx > with a space
            text = re.sub(r"\{.*?\}", " ", text)
            text = re.sub(r"<.*?>", " ", text)
            text = re.sub(r"\s+", " ", text)
            new_texts.append(text)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                texts=new_texts,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack this speaker group's sentences into one serialized TextData record
    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            sentences=sentences,
        )
    )
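

# Sketch: reading a shard back. `read_pb_stream` is assumed to be the
# counterpart of `pack_pb_stream`; verify the name against
# fish_speech.datasets.protos.text_data_stream before relying on it.
def iter_shard(shard_path: Path):
    from fish_speech.datasets.protos.text_data_stream import read_pb_stream

    with open(shard_path, "rb") as f:
        # Each record decodes to one TextData message (one speaker group).
        yield from read_pb_stream(f)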


@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in MiB"
)
def main(input, output, num_workers, text_extension, shard_size):
    generator_fns = []

    for f in input:
        assert f.exists(), f"{f} not found"

        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    dataset_fp = None
    tar_idx = 0
    written_size = 0

    with Pool(num_workers) as p:
        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
            # Lazily open the next shard on first write.
            if dataset_fp is None:
                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

            dataset_fp.write(result)
            written_size += len(result)

            # Roll over to a new shard once the size limit is exceeded.
            if written_size > shard_size * 1024 * 1024:
                logger.info(f"Finished writing shard {tar_idx} to {output}")
                dataset_fp.close()
                dataset_fp = None
                written_size = 0
                tar_idx += 1

    if dataset_fp is not None:
        dataset_fp.close()
        tar_idx += 1

    logger.info(f"Finished writing {tar_idx} shards to {output}")


if __name__ == "__main__":
    main()
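
# Example invocation (paths are illustrative, not from the repo):
#   python build_dataset.py \
#       --input data/demo-dataset \
#       --output data/quantized-dataset-ft \
#       --num-workers 16 \
#       --text-extension .txt \
#       --shard-size 10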