build_dataset.py

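"""Build a quantized dataset for fine-tuning.

Walks the audio roots listed in a YAML config, pairs each audio file with its
.npy semantic codes and its transcript, converts the transcript to phones via
g2p, and packs each group into a TextData protobuf stream.
"""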
import re
from collections import defaultdict
from multiprocessing import Pool

import click
import numpy as np
import yaml
from loguru import logger
from tqdm import tqdm

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from fish_speech.text import g2p
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files

def task_generator(config):
    with open(config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    for row in config["datasets"]:
        root, source, languages, extension, parent_level = (
            row["root"],
            row["source"],
            row["languages"],
            row["extension"],
            row["group_parent_level"],
        )

        # Load the files
        files = list_files(root, AUDIO_EXTENSIONS, recursive=True)

        # Group files by the name of their parent (or grandparent) directory,
        # so each group maps to one speaker/subset
        grouped_files = defaultdict(list)
        for file in files:
            if parent_level == 1:
                p = file.parent.name
            elif parent_level == 2:
                p = file.parent.parent.name
            else:
                raise ValueError(f"Invalid parent level {parent_level}")

            grouped_files[p].append(file)

        logger.info(f"Found {len(grouped_files)} groups in {root}")
        for name, subset in grouped_files.items():
            yield name, subset, source, languages, extension
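
# A config sketch for reference (illustrative values; the keys are exactly the
# ones task_generator reads above, but `.lab` and the paths are assumptions):
#
#   datasets:
#     - root: data/demo               # scanned recursively for audio files
#       source: demo                  # stored as TextData.source
#       languages: [zh, en]           # passed to g2p as `order`
#       extension: .lab               # transcript suffix expected next to audio
#       group_parent_level: 1         # 1 = group by parent dir, 2 = grandparent
#
# With group_parent_level: 1, the layout is roughly (file names illustrative):
#
#   data/demo/speaker_a/utt_0001.wav  # audio (any AUDIO_EXTENSIONS suffix)
#   data/demo/speaker_a/utt_0001.npy  # quantized semantic codes
#   data/demo/speaker_a/utt_0001.lab  # transcript (matches `extension`)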

def run_task(task):
    name, subset, source, languages, extension = task

    # Parse the files
    sentences = []
    for file in subset:
        np_file = file.with_suffix(".npy")
        txt_file = file.with_suffix(extension)

        if not np_file.exists() or not txt_file.exists():
            logger.warning(f"Can't find {np_file} or {txt_file}")
            continue

        with open(txt_file, "r") as f:
            text = f.read().strip()

        # Simple cleaning: replace { xxx } and < xxx > with a space,
        # then collapse runs of whitespace
        text = re.sub(r"\{.*?\}", " ", text)
        text = re.sub(r"<.*?>", " ", text)
        text = re.sub(r"\s+", " ", text)

        try:
            phones = [v for _, v in g2p(text, order=languages)]
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                text=text,
                phones=phones,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack the sentences
    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            languages=languages,
            sentences=sentences,
        )
    )

@click.command()
@click.option(
    "--config", type=click.Path(), default="fish_speech/configs/data/finetune.yaml"
)
@click.option("--output", type=click.Path(), default="data/quantized-dataset-ft.protos")
def main(config, output):
    # Pack groups in parallel and stream each result straight to disk;
    # imap_unordered means group order in the output file is not deterministic
    with open(output, "wb") as dataset_fp, Pool(16) as p:
        for result in tqdm(p.imap_unordered(run_task, task_generator(config))):
            dataset_fp.write(result)


if __name__ == "__main__":
    main()
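
# Example invocation (the option values shown are just the click defaults
# declared above):
#
#   python build_dataset.py \
#       --config fish_speech/configs/data/finetune.yaml \
#       --output data/quantized-dataset-ft.protos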