# build_dataset.py
  1. import re
  2. from collections import defaultdict
  3. from multiprocessing import Pool
  4. import numpy as np
  5. from loguru import logger
  6. from tqdm import tqdm
  7. from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
  8. from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
  9. from fish_speech.text import g2p
  10. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
# Define datasets
# Each entry is a 5-tuple:
#   (root directory to scan, source/corpus name, language priority order
#    passed to g2p, transcript file extension, parent level — how many
#    directory levels above the audio file to use as the group name)
DATASETS = [
    ("data/StarRail/Chinese", "StarRail", ["ZH", "EN"], ".lab", 1),
    ("data/StarRail/English", "StarRail", ["EN"], ".lab", 1),
    ("data/StarRail/Japanese", "StarRail", ["JP", "EN"], ".lab", 1),
    ("data/Genshin/Chinese", "Genshin", ["ZH", "EN"], ".lab", 1),
    ("data/Genshin/English", "Genshin", ["EN"], ".lab", 1),
    ("data/Genshin/Japanese", "Genshin", ["JP", "EN"], ".lab", 1),
    ("data/LibriTTS_R", "LibriTTS_R", ["EN"], ".normalized.txt", 2),
    ("data/WenetSpeech", "WenetSpeech", ["ZH", "EN"], ".txt", 1),
]
  22. def task_generator():
  23. for root, source, languages, extension, parent_level in DATASETS:
  24. # Load the files
  25. files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
  26. grouped_files = defaultdict(list)
  27. for file in files:
  28. if parent_level == 1:
  29. p = file.parent.name
  30. elif parent_level == 2:
  31. p = file.parent.parent.name
  32. else:
  33. raise ValueError(f"Invalid parent level {parent_level}")
  34. grouped_files[p].append(file)
  35. logger.info(f"Found {len(grouped_files)} groups in {root}")
  36. for name, subset in grouped_files.items():
  37. yield name, subset, source, languages, extension
  38. def run_task(task):
  39. name, subset, source, languages, extension = task
  40. # Parse the files
  41. sentences = []
  42. for file in subset:
  43. np_file = file.with_suffix(".npy")
  44. txt_file = file.with_suffix(extension)
  45. if np_file.exists() is False or txt_file.exists() is False:
  46. continue
  47. with open(txt_file, "r") as f:
  48. text = f.read().strip()
  49. # Simple cleaning: replace { xxx } and < xxx > with space
  50. text = re.sub(r"\{.*?\}", " ", text)
  51. text = re.sub(r"<.*?>", " ", text)
  52. text = re.sub(r"\s+", " ", text)
  53. try:
  54. phones = [v for _, v in g2p(text, order=languages)]
  55. semantics = np.load(np_file)
  56. except Exception as e:
  57. logger.error(f"Failed to parse {file}: {e}")
  58. continue
  59. if isinstance(semantics, np.ndarray):
  60. semantics = semantics.tolist()
  61. sentences.append(
  62. Sentence(
  63. text=text,
  64. phones=phones,
  65. semantics=[Semantics(values=s) for s in semantics],
  66. )
  67. )
  68. # Pack the sentences
  69. return pack_pb_stream(
  70. TextData(
  71. source=source,
  72. name=name,
  73. languages=languages,
  74. sentences=sentences,
  75. )
  76. )
  77. def main():
  78. dataset_fp = open("data/quantized-dataset-1205.protos", "wb")
  79. with Pool(16) as p:
  80. for result in tqdm(p.imap_unordered(run_task, task_generator())):
  81. dataset_fp.write(result)
  82. dataset_fp.close()
# Script entry point: build the dataset only when run directly, not on import
if __name__ == "__main__":
    main()