build_dataset.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import itertools
  2. import os
  3. import re
  4. from collections import defaultdict
  5. from functools import partial
  6. from multiprocessing import Pool
  7. from pathlib import Path
  8. import click
  9. import numpy as np
  10. from loguru import logger
  11. from tqdm import tqdm
  12. from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
  13. from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
  14. from fish_speech.utils.file import load_filelist
  15. # To avoid CPU overload
  16. os.environ["MKL_NUM_THREADS"] = "1"
  17. os.environ["OMP_NUM_THREADS"] = "1"
  18. def task_generator_folder(root: Path, text_extension: str):
  19. files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
  20. files = sorted(files)
  21. grouped_files = defaultdict(list)
  22. for file in tqdm(files, desc=f"Grouping {root}"):
  23. p = str(file.parent)
  24. try:
  25. if isinstance(text_extension, str):
  26. texts = [file.with_suffix(text_extension).read_text()]
  27. else:
  28. texts = [file.with_suffix(ext).read_text() for ext in text_extension]
  29. except Exception as e:
  30. logger.error(f"Failed to read text {file}: {e}")
  31. continue
  32. grouped_files[p].append((file, texts))
  33. logger.info(
  34. f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
  35. )
  36. for name, subset in grouped_files.items():
  37. yield name, subset, "folder"
  38. def task_generator_filelist(filelist):
  39. grouped_files = defaultdict(list)
  40. for filename, speaker, _, text in load_filelist(filelist):
  41. grouped_files[speaker].append((Path(filename), [text]))
  42. logger.info(f"Found {len(grouped_files)} groups in {filelist}")
  43. for speaker, values in grouped_files.items():
  44. yield speaker, values, "filelist"
  45. def run_task(task):
  46. name, subset, source = task
  47. # Parse the files
  48. sentences = []
  49. for file in subset:
  50. file, texts = file
  51. np_file = file.with_suffix(".npy")
  52. if np_file.exists() is False:
  53. logger.warning(f"Can't find {np_file}")
  54. continue
  55. new_texts = []
  56. for text in texts:
  57. # Simple cleaning: replace { xxx } and < xxx > with space
  58. text = re.sub(r"\{.*?\}", " ", text)
  59. text = re.sub(r"<.*?>", " ", text)
  60. text = re.sub(r"\s+", " ", text)
  61. new_texts.append(text)
  62. try:
  63. semantics = np.load(np_file)
  64. except Exception as e:
  65. logger.error(f"Failed to parse {file}: {e}")
  66. continue
  67. if isinstance(semantics, np.ndarray):
  68. semantics = semantics.tolist()
  69. sentences.append(
  70. Sentence(
  71. texts=new_texts,
  72. semantics=[Semantics(values=s) for s in semantics],
  73. )
  74. )
  75. # Pack the sentences
  76. return pack_pb_stream(
  77. TextData(
  78. source=source,
  79. name=name,
  80. sentences=sentences,
  81. )
  82. )
  83. @click.command()
  84. @click.option(
  85. "--input",
  86. type=click.Path(path_type=Path),
  87. required=True,
  88. help="A folder containing the dataset or a filelist",
  89. multiple=True,
  90. )
  91. @click.option(
  92. "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
  93. )
  94. @click.option("--num-workers", type=int, default=16)
  95. @click.option("--text-extension", type=str, default=[".txt"], multiple=True)
  96. @click.option(
  97. "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
  98. )
  99. def main(input, output, num_workers, text_extension, shard_size):
  100. generator_fns = []
  101. for f in input:
  102. assert f.exists(), f"{f} not found"
  103. if f.is_dir():
  104. generator_fn = task_generator_folder(f, text_extension)
  105. else:
  106. generator_fn = task_generator_filelist(f)
  107. generator_fns.append(generator_fn)
  108. generator_fn = itertools.chain(*generator_fns)
  109. output.mkdir(parents=True, exist_ok=True)
  110. dataset_fp = None
  111. tar_idx = 0
  112. written_size = 0
  113. with Pool(num_workers) as p:
  114. for result in tqdm(p.imap_unordered(run_task, generator_fn)):
  115. if dataset_fp is None:
  116. dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
  117. dataset_fp.write(result)
  118. written_size += len(result)
  119. if written_size > shard_size * 1024 * 1024:
  120. logger.info(f"Finished writing {tar_idx} shards to {output}")
  121. dataset_fp.close()
  122. dataset_fp = None
  123. written_size = 0
  124. tar_idx += 1
  125. if dataset_fp is not None:
  126. dataset_fp.close()
  127. logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
  128. if __name__ == "__main__":
  129. main()