# create_train_split.py (2.0 KB)
import math
import os
from pathlib import Path
from random import Random

import click
from tqdm import tqdm

from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
  8. @click.command()
  9. @click.argument("root", type=click.Path(exists=True, path_type=Path))
  10. @click.option("--val-ratio", type=float, default=0.2)
  11. @click.option("--val-count", type=int, default=None)
  12. @click.option("--filelist", default=None, type=Path)
  13. def main(root, val_ratio, val_count, filelist):
  14. if filelist:
  15. with open(filelist, "r", encoding="utf-8") as f:
  16. # files = [Path(line..strip().split("|")[0]) for line in f]
  17. files = set()
  18. countSame = 0
  19. countNotFound = 0
  20. for line in f.readlines():
  21. file = Path(line.strip().split("|")[0])
  22. if file in files:
  23. print(f"重复音频文本:{line}")
  24. countSame += 1
  25. continue
  26. if not os.path.isfile(file):
  27. # 过滤数据集错误:不存在对应音频
  28. print(f"没有找到对应的音频:{file}")
  29. countNotFound += 1
  30. continue
  31. files.add(file)
  32. files = list(files)
  33. else:
  34. files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
  35. print(f"Found {len(files)} files")
  36. files = [str(file.relative_to(root)) for file in tqdm(files)]
  37. Random(42).shuffle(files)
  38. if val_count is not None:
  39. if val_count < 1 or val_count > len(files):
  40. raise ValueError("val_count must be between 1 and number of files")
  41. val_size = val_count
  42. else:
  43. val_size = math.ceil(len(files) * val_ratio)
  44. with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
  45. f.write("\n".join(files[val_size:]))
  46. with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
  47. f.write("\n".join(files[:val_size]))
  48. print("Done")
  49. if __name__ == "__main__":
  50. main()