# build_dataset.py
import os
import re
from collections import defaultdict
from multiprocessing import Pool
from pathlib import Path

import click
import numpy as np
import yaml
from loguru import logger
from tqdm import tqdm

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.text import g2p
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
  15. def task_generator_yaml(config):
  16. with open(config, "r") as f:
  17. config = yaml.load(f, Loader=yaml.FullLoader)
  18. for row in config["datasets"]:
  19. root, source, languages, extension, parent_level = (
  20. row["root"],
  21. row["source"],
  22. row["languages"],
  23. row["extension"],
  24. row["group_parent_level"],
  25. )
  26. # Load the files
  27. files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
  28. grouped_files = defaultdict(list)
  29. for file in files:
  30. if parent_level == 1:
  31. p = file.parent.name
  32. elif parent_level == 2:
  33. p = file.parent.parent.name
  34. else:
  35. raise ValueError(f"Invalid parent level {parent_level}")
  36. grouped_files[p].append(file)
  37. logger.info(f"Found {len(grouped_files)} groups in {root}")
  38. for name, subset in grouped_files.items():
  39. yield name, subset, source, languages, extension, None
  40. def task_generator_filelist(filelist):
  41. grouped_files = defaultdict(list)
  42. for filename, speaker, languages, text in load_filelist(filelist):
  43. if speaker in grouped_files:
  44. assert (
  45. languages == grouped_files[speaker][0][2]
  46. ), f"Speaker {speaker} has different languages"
  47. grouped_files[speaker].append((Path(filename), text, languages))
  48. logger.info(f"Found {len(grouped_files)} groups in {filelist}")
  49. for speaker, (filename, txt, languages) in grouped_files.items():
  50. yield speaker, filename, "filelist", languages, None, txt
  51. def run_task(task):
  52. name, subset, source, languages, extension, text = task
  53. # Parse the files
  54. sentences = []
  55. for file in subset:
  56. np_file = file.with_suffix(".npy")
  57. if np_file.exists() is False:
  58. logger.warning(f"Can't find {np_file}")
  59. continue
  60. if text is None:
  61. txt_file = file.with_suffix(extension)
  62. if txt_file.exists() is False:
  63. logger.warning(f"Can't find {txt_file}")
  64. continue
  65. with open(txt_file, "r") as f:
  66. text = f.read().strip()
  67. # Simple cleaning: replace { xxx } and < xxx > with space
  68. text = re.sub(r"\{.*?\}", " ", text)
  69. text = re.sub(r"<.*?>", " ", text)
  70. text = re.sub(r"\s+", " ", text)
  71. try:
  72. phones = [v for _, v in g2p(text, order=languages)]
  73. semantics = np.load(np_file)
  74. except Exception as e:
  75. logger.error(f"Failed to parse {file}: {e}")
  76. continue
  77. if isinstance(semantics, np.ndarray):
  78. semantics = semantics.tolist()
  79. sentences.append(
  80. Sentence(
  81. text=text,
  82. phones=phones,
  83. semantics=[Semantics(values=s) for s in semantics],
  84. )
  85. )
  86. # Pack the sentences
  87. return pack_pb_stream(
  88. TextData(
  89. source=source,
  90. name=name,
  91. languages=languages,
  92. sentences=sentences,
  93. )
  94. )
  95. @click.command()
  96. @click.option(
  97. "--config", type=click.Path(), default="fish_speech/configs/data/finetune.yaml"
  98. )
  99. @click.option("--output", type=click.Path(), default="data/quantized-dataset-ft.protos")
  100. @click.option("--filelist", type=click.Path(), default=None)
  101. @click.option("--num_worker", type=int, default=16)
  102. def main(config, output, filelist, num_worker):
  103. dataset_fp = open(output, "wb")
  104. generator_fn = task_generator_yaml if filelist is None else task_generator_filelist
  105. with Pool(num_worker) as p:
  106. for result in tqdm(p.imap_unordered(run_task, generator_fn(config, filelist))):
  107. dataset_fp.write(result)
  108. dataset_fp.close()
# Script entry point: click parses argv and injects the option values.
if __name__ == "__main__":
    main()