# split_filelist.py — split a line-based file list into train/test subsets.
import random
from pathlib import Path
from typing import List, Optional

import click
from loguru import logger
  5. @click.command()
  6. @click.argument('list-file', type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path))
  7. @click.option('--train-proportion', '-p', type=float, default=0.95)
  8. def main(list_file, train_proportion):
  9. lines = list_file.read_text().splitlines()
  10. logger.info(f'Found {len(lines)} lines in {list_file}')
  11. random.shuffle(lines)
  12. train_size = int(len(lines) * train_proportion)
  13. train_file = list_file.with_suffix(f'.train{list_file.suffix}')
  14. train_file.write_text('\n'.join(lines[:train_size]))
  15. test_file = list_file.with_suffix(f'.test{list_file.suffix}')
  16. test_file.write_text('\n'.join(lines[train_size:]))
  17. logger.info(f'Wrote {len(lines[:train_size])} lines to {train_file}')
  18. logger.info(f'Wrote {len(lines[train_size:])} lines to {test_file}')
  19. if __name__ == '__main__':
  20. main()