# build_dataset.py
  1. import os
  2. import re
  3. from collections import defaultdict
  4. from multiprocessing import Pool
  5. from pathlib import Path
  6. import click
  7. import numpy as np
  8. import yaml
  9. from loguru import logger
  10. from tqdm import tqdm
  11. from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
  12. from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
  13. from fish_speech.text import g2p
  14. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
  15. def task_generator(config, filelist):
  16. with open(config, "r") as f:
  17. config = yaml.load(f, Loader=yaml.FullLoader)
  18. for row in config["datasets"]:
  19. root, source, languages, extension, parent_level = (
  20. row["root"],
  21. row["source"],
  22. row["languages"],
  23. row["extension"],
  24. row["group_parent_level"],
  25. )
  26. # Load the files
  27. if filelist:
  28. with open(filelist, "r", encoding="utf-8") as f:
  29. # files = [Path(line..strip().split("|")[0]) for line in f]
  30. files = set()
  31. countSame = 0
  32. countNotFound = 0
  33. for line in f.readlines():
  34. file = Path(line.strip().split("|")[0])
  35. if file in files:
  36. print(f"重复音频文本:{line}")
  37. countSame += 1
  38. continue
  39. if not os.path.isfile(file):
  40. # 过滤数据集错误:不存在对应音频
  41. print(f"没有找到对应的音频:{file}")
  42. countNotFound += 1
  43. continue
  44. files.add(file)
  45. else:
  46. files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
  47. grouped_files = defaultdict(list)
  48. for file in files:
  49. if parent_level == 1:
  50. p = file.parent.name
  51. elif parent_level == 2:
  52. p = file.parent.parent.name
  53. else:
  54. raise ValueError(f"Invalid parent level {parent_level}")
  55. grouped_files[p].append(file)
  56. logger.info(f"Found {len(grouped_files)} groups in {root}")
  57. for name, subset in grouped_files.items():
  58. yield name, subset, source, languages, extension
  59. def run_task(task):
  60. name, subset, source, languages, extension = task
  61. # Parse the files
  62. sentences = []
  63. for file in subset:
  64. np_file = file.with_suffix(".npy")
  65. txt_file = file.with_suffix(extension)
  66. if np_file.exists() is False or txt_file.exists() is False:
  67. logger.warning(f"Can't find {np_file} or {txt_file}")
  68. continue
  69. with open(txt_file, "r") as f:
  70. text = f.read().strip()
  71. # Simple cleaning: replace { xxx } and < xxx > with space
  72. text = re.sub(r"\{.*?\}", " ", text)
  73. text = re.sub(r"<.*?>", " ", text)
  74. text = re.sub(r"\s+", " ", text)
  75. try:
  76. phones = [v for _, v in g2p(text, order=languages)]
  77. semantics = np.load(np_file)
  78. except Exception as e:
  79. logger.error(f"Failed to parse {file}: {e}")
  80. continue
  81. if isinstance(semantics, np.ndarray):
  82. semantics = semantics.tolist()
  83. sentences.append(
  84. Sentence(
  85. text=text,
  86. phones=phones,
  87. semantics=[Semantics(values=s) for s in semantics],
  88. )
  89. )
  90. # Pack the sentences
  91. return pack_pb_stream(
  92. TextData(
  93. source=source,
  94. name=name,
  95. languages=languages,
  96. sentences=sentences,
  97. )
  98. )
  99. @click.command()
  100. @click.option(
  101. "--config", type=click.Path(), default="fish_speech/configs/data/finetune.yaml"
  102. )
  103. @click.option("--output", type=click.Path(), default="data/quantized-dataset-ft.protos")
  104. @click.option("--filelist", type=click.Path(), default=None)
  105. @click.option("--num_worker", type=int, default=16)
  106. def main(config, output, filelist, num_worker):
  107. dataset_fp = open(output, "wb")
  108. with Pool(num_worker) as p:
  109. for result in tqdm(
  110. p.imap_unordered(run_task, task_generator(config, filelist))
  111. ):
  112. dataset_fp.write(result)
  113. dataset_fp.close()
  114. if __name__ == "__main__":
  115. main()