create_train_split.py 712 B

123456789101112131415161718192021222324252627282930
  1. from pathlib import Path
  2. from random import Random
  3. import click
  4. from tqdm import tqdm
  5. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
  6. @click.command()
  7. @click.argument("root", type=click.Path(exists=True, path_type=Path))
  8. def main(root):
  9. files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
  10. print(f"Found {len(files)} files")
  11. files = [str(file.relative_to(root)) for file in tqdm(files)]
  12. Random(42).shuffle(files)
  13. with open(root / "vq_train_filelist.txt", "w") as f:
  14. f.write("\n".join(files[:-100]))
  15. with open(root / "vq_val_filelist.txt", "w") as f:
  16. f.write("\n".join(files[-100:]))
  17. print("Done")
  18. if __name__ == "__main__":
  19. main()