create_train_split.py 804 B

12345678910111213141516171819202122232425262728293031
  1. from pathlib import Path
  2. from random import Random
  3. import click
  4. from tqdm import tqdm
  5. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
  6. @click.command()
  7. @click.argument("root", type=click.Path(exists=True, path_type=Path))
  8. def main(root):
  9. files = list_files(root, AUDIO_EXTENSIONS, recursive=True, show_progress=True)
  10. print(f"Found {len(files)} files")
  11. files = [str(file) for file in tqdm(files) if file.with_suffix(".npy").exists()]
  12. print(f"Found {len(files)} files with features")
  13. Random(42).shuffle(files)
  14. with open(root / "vq_train_filelist.txt", "w") as f:
  15. f.write("\n".join(files[:-100]))
  16. with open(root / "vq_val_filelist.txt", "w") as f:
  17. f.write("\n".join(files[-100:]))
  18. print("Done")
  19. if __name__ == "__main__":
  20. main()