create_train_split.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import math
  2. from pathlib import Path
  3. from random import Random
  4. import click
  5. from tqdm import tqdm
  6. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
  @click.command()
  @click.argument("root", type=click.Path(exists=True, path_type=Path))
  @click.option("--val-ratio", type=float, default=0.2)
  @click.option("--val-count", type=int, default=None)
  def main(root, val_ratio, val_count):
      """Split the audio files under ROOT into train and validation file lists.

      Writes vq_train_filelist.txt and vq_val_filelist.txt directly inside ROOT,
      one ROOT-relative path per line.
      \f

      Everything after the form feed above is hidden from ``--help`` by click.

      Args:
          root: Existing directory that is searched recursively for files
              matching ``AUDIO_EXTENSIONS``.
          val_ratio: Fraction of the files assigned to the validation list,
              rounded up to a whole file. Only used when ``val_count`` is
              not given.
          val_count: Exact number of validation files. Takes precedence over
              ``val_ratio`` when provided.

      Raises:
          ValueError: If ``val_count`` is outside ``1..len(files)``, or if
              ``val_ratio`` (when it is the value in use) is outside ``0..1``.
      """
      files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
      print(f"Found {len(files)} files")

      # Store paths relative to ROOT so the lists do not embed its location.
      files = [str(file.relative_to(root)) for file in tqdm(files)]

      # Fixed seed: the same listing order always produces the same split.
      # NOTE(review): reproducibility across runs also relies on list_files
      # returning files in a stable order -- confirm against its implementation.
      Random(42).shuffle(files)

      if val_count is not None:
          if val_count < 1 or val_count > len(files):
              raise ValueError("val_count must be between 1 and number of files")
          val_size = val_count
      else:
          # Without this check a negative ratio yields a negative slice index
          # (train keeps only the last few files, val gets the rest) and a
          # ratio above 1 silently leaves the train list empty. The chained
          # comparison is also False for NaN.
          if not 0 <= val_ratio <= 1:
              raise ValueError("val_ratio must be between 0 and 1")
          val_size = math.ceil(len(files) * val_ratio)

      # The first val_size shuffled entries form the validation set; the
      # remainder is the training set.
      with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
          f.write("\n".join(files[val_size:]))

      with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
          f.write("\n".join(files[:val_size]))

      print("Done")
  # Invoke the click command only when this file is executed as a script,
  # so importing the module has no side effects.
  if __name__ == "__main__":
      main()