# create_train_split.py -- writes vq_train_filelist.txt and vq_val_filelist.txt for a dataset directory.
  1. import math
  2. from pathlib import Path
  3. from random import Random
  4. import click
  5. from loguru import logger
  6. from tqdm import tqdm
  7. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
  8. @click.command()
  9. @click.argument("root", type=click.Path(exists=True, path_type=Path))
  10. @click.option("--val-ratio", type=float, default=None)
  11. @click.option("--val-count", type=int, default=None)
  12. @click.option("--filelist", default=None, type=Path)
  13. def main(root, val_ratio, val_count, filelist):
  14. if filelist:
  15. files = [i[0] for i in load_filelist(filelist)]
  16. else:
  17. files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
  18. logger.info(f"Found {len(files)} files")
  19. files = [str(file.relative_to(root)) for file in tqdm(files)]
  20. Random(42).shuffle(files)
  21. if val_count is None and val_ratio is None:
  22. logger.info("Validation ratio and count not specified, using min(20%, 100)")
  23. val_size = min(100, math.ceil(len(files) * 0.2))
  24. elif val_count is not None and val_ratio is not None:
  25. logger.error("Cannot specify both val_count and val_ratio")
  26. return
  27. elif val_count is not None:
  28. if val_count < 1 or val_count > len(files):
  29. logger.error("val_count must be between 1 and number of files")
  30. return
  31. val_size = val_count
  32. else:
  33. val_size = math.ceil(len(files) * val_ratio)
  34. logger.info(f"Using {val_size} files for validation")
  35. with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
  36. f.write("\n".join(files[val_size:]))
  37. with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
  38. f.write("\n".join(files[:val_size]))
  39. logger.info("Done")
  40. if __name__ == "__main__":
  41. main()