# build_dataset.py
  1. import re
  2. from collections import defaultdict
  3. from multiprocessing import Pool
  4. from pathlib import Path
  5. import click
  6. import numpy as np
  7. import yaml
  8. from loguru import logger
  9. from tqdm import tqdm
  10. from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
  11. from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
  12. from fish_speech.text import g2p
  13. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
  14. def task_generator_yaml(config):
  15. with open(config, "r") as f:
  16. config = yaml.load(f, Loader=yaml.FullLoader)
  17. for row in config["datasets"]:
  18. root, source, languages, extension, parent_level = (
  19. row["root"],
  20. row["source"],
  21. row["languages"],
  22. row["extension"],
  23. row["group_parent_level"],
  24. )
  25. # Load the files
  26. files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
  27. grouped_files = defaultdict(list)
  28. for file in files:
  29. if parent_level == 1:
  30. p = file.parent.name
  31. elif parent_level == 2:
  32. p = file.parent.parent.name
  33. else:
  34. raise ValueError(f"Invalid parent level {parent_level}")
  35. grouped_files[p].append(file)
  36. logger.info(f"Found {len(grouped_files)} groups in {root}")
  37. for name, subset in grouped_files.items():
  38. yield name, subset, source, languages, extension
  39. def task_generator_filelist(filelist):
  40. grouped_files = defaultdict(list)
  41. for filename, speaker, languages, text in load_filelist(filelist):
  42. grouped_files[speaker].append((Path(filename), text, languages))
  43. logger.info(f"Found {len(grouped_files)} groups in {filelist}")
  44. for speaker, values in grouped_files.items():
  45. yield speaker, values, "filelist", languages, None
  46. def run_task(task):
  47. name, subset, source, languages, extension = task
  48. # Parse the files
  49. sentences = []
  50. for file in subset:
  51. if isinstance(file, tuple):
  52. file, text, languages = file
  53. else:
  54. text = None
  55. np_file = file.with_suffix(".npy")
  56. if np_file.exists() is False:
  57. logger.warning(f"Can't find {np_file}")
  58. continue
  59. if text is None:
  60. txt_file = file.with_suffix(extension)
  61. if txt_file.exists() is False:
  62. logger.warning(f"Can't find {txt_file}")
  63. continue
  64. with open(txt_file, "r") as f:
  65. text = f.read().strip()
  66. # Simple cleaning: replace { xxx } and < xxx > with space
  67. text = re.sub(r"\{.*?\}", " ", text)
  68. text = re.sub(r"<.*?>", " ", text)
  69. text = re.sub(r"\s+", " ", text)
  70. try:
  71. phones = [v for _, v in g2p(text, order=languages)]
  72. semantics = np.load(np_file)
  73. except Exception as e:
  74. logger.error(f"Failed to parse {file}: {e}")
  75. continue
  76. if isinstance(semantics, np.ndarray):
  77. semantics = semantics.tolist()
  78. sentences.append(
  79. Sentence(
  80. text=text,
  81. phones=phones,
  82. semantics=[Semantics(values=s) for s in semantics],
  83. )
  84. )
  85. # Pack the sentences
  86. return pack_pb_stream(
  87. TextData(
  88. source=source,
  89. name=name,
  90. languages=languages,
  91. sentences=sentences,
  92. )
  93. )
  94. @click.command()
  95. @click.option(
  96. "--config", type=click.Path(), default="fish_speech/configs/data/finetune.yaml"
  97. )
  98. @click.option("--output", type=click.Path(), default="data/quantized-dataset-ft.protos")
  99. @click.option("--filelist", type=click.Path(), default=None)
  100. @click.option("--num-workers", type=int, default=16)
  101. def main(config, output, filelist, num_workers):
  102. dataset_fp = open(output, "wb")
  103. generator_fn = (
  104. task_generator_yaml(config)
  105. if filelist is None
  106. else task_generator_filelist(filelist)
  107. )
  108. with Pool(num_workers) as p:
  109. for result in tqdm(p.imap_unordered(run_task, generator_fn)):
  110. dataset_fp.write(result)
  111. dataset_fp.close()
# CLI entry point: run only when executed as a script, not on import.
if __name__ == "__main__":
    main()