build_dataset.py

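# Builds a quantized text/semantics dataset: pairs each .npy file of semantic
# token values with its transcript(s), packs them into TextData protobuf
# messages, and writes them out as size-limited .protos shards.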
import itertools
import os
import re
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import click
import numpy as np
from loguru import logger
from tqdm import tqdm

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.utils.file import load_filelist

# To avoid CPU overload
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

def task_generator_folder(root: Path, text_extension: str):
    # Collect all quantized .npy files under the root and group them by parent
    # folder; the parent folder name is used as the speaker name.
    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
    files = sorted(files)

    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        p = str(file.parent)
        speaker = file.parent.name

        try:
            if isinstance(text_extension, str):
                texts = [file.with_suffix(text_extension).read_text()]
            else:
                texts = [file.with_suffix(ext).read_text() for ext in text_extension]
        except Exception as e:
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[p].append((speaker, file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    # Yield one task per folder: the speaker name plus the list of (file, texts)
    # pairs expected by run_task.
    for group in grouped_files.values():
        speaker = group[0][0]
        values = [(file, texts) for _, file, texts in group]
        yield speaker, values, "folder"

def task_generator_filelist(filelist):
    grouped_files = defaultdict(list)
    for filename, speaker, _, text in load_filelist(filelist):
        grouped_files[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
    for speaker, values in grouped_files.items():
        yield speaker, values, "filelist"

def run_task(task):
    name, subset, source = task

    # Parse the files
    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        new_texts = []
        for text in texts:
            # Simple cleaning: replace { xxx } and < xxx > with space
            text = re.sub(r"\{.*?\}", " ", text)
            text = re.sub(r"<.*?>", " ", text)
            text = re.sub(r"\s+", " ", text)
            new_texts.append(text)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                texts=new_texts,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack the sentences
    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            sentences=sentences,
        )
    )

@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in MB"
)
def main(input, output, num_workers, text_extension, shard_size):
    generator_fns = []

    for f in input:
        assert f.exists(), f"{f} not found"

        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    dataset_fp = None
    tar_idx = 0
    written_size = 0

    with Pool(num_workers) as p:
        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
            if dataset_fp is None:
                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

            dataset_fp.write(result)
            written_size += len(result)

            # Roll over to a new shard once the current one exceeds the limit
            if written_size > shard_size * 1024 * 1024:
                logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
                dataset_fp.close()
                dataset_fp = None
                written_size = 0
                tar_idx += 1

    if dataset_fp is not None:
        dataset_fp.close()
        logger.info(f"Finished writing {tar_idx + 1} shards to {output}")

if __name__ == "__main__":
    main()
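
# Example invocation (a sketch: the --input path and --text-extension value
# below are placeholder assumptions; the other values match the defaults
# defined above):
#
#   python build_dataset.py \
#       --input data/demo \
#       --output data/quantized-dataset-ft \
#       --num-workers 16 \
#       --text-extension .lab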