# split_filelist.py (947 B)

  1. import random
  2. from pathlib import Path
  3. import click
  4. from loguru import logger
  5. @click.command()
  6. @click.argument(
  7. "list-file",
  8. type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path),
  9. )
  10. @click.option("--train-proportion", "-p", type=float, default=0.95)
  11. def main(list_file, train_proportion):
  12. lines = list_file.read_text().splitlines()
  13. logger.info(f"Found {len(lines)} lines in {list_file}")
  14. random.shuffle(lines)
  15. train_size = int(len(lines) * train_proportion)
  16. train_file = list_file.with_suffix(f".train{list_file.suffix}")
  17. train_file.write_text("\n".join(lines[:train_size]))
  18. test_file = list_file.with_suffix(f".test{list_file.suffix}")
  19. test_file.write_text("\n".join(lines[train_size:]))
  20. logger.info(f"Wrote {len(lines[:train_size])} lines to {train_file}")
  21. logger.info(f"Wrote {len(lines[train_size:])} lines to {test_file}")
  22. if __name__ == "__main__":
  23. main()