# create_train_split.py (2.3 KB)
  1. import math
  2. from pathlib import Path
  3. from random import Random
  4. import click
  5. from loguru import logger
  6. from pydub import AudioSegment
  7. from tqdm import tqdm
  8. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
  9. @click.command()
  10. @click.argument("root", type=click.Path(exists=True, path_type=Path))
  11. @click.option("--val-ratio", type=float, default=None)
  12. @click.option("--val-count", type=int, default=None)
  13. @click.option("--filelist", default=None, type=Path)
  14. @click.option("--min-duration", default=0.2, type=float)
  15. @click.option("--max-duration", default=30, type=float)
  16. def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
  17. if filelist:
  18. files = [i[0] for i in load_filelist(filelist)]
  19. else:
  20. files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
  21. filtered_files = []
  22. for file in tqdm(files):
  23. try:
  24. audio = AudioSegment.from_file(str(file))
  25. duration = len(audio) / 1000.0
  26. if min_duration <= duration <= max_duration:
  27. filtered_files.append(str(file.relative_to(root)))
  28. except Exception as e:
  29. logger.info(f"Error processing {file}: {e}")
  30. logger.info(f"Found {len(files)} files | Got Filtered {len(filtered_files)} files")
  31. Random(42).shuffle(filtered_files)
  32. if val_count is None and val_ratio is None:
  33. logger.info("Validation ratio and count not specified, using min(20%, 100)")
  34. val_size = min(100, math.ceil(len(filtered_files) * 0.2))
  35. elif val_count is not None and val_ratio is not None:
  36. logger.error("Cannot specify both val_count and val_ratio")
  37. return
  38. elif val_count is not None:
  39. if val_count < 1 or val_count > len(filtered_files):
  40. logger.error("val_count must be between 1 and number of files")
  41. return
  42. val_size = val_count
  43. else:
  44. val_size = math.ceil(len(filtered_files) * val_ratio)
  45. logger.info(f"Using {val_size} files for validation")
  46. with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
  47. f.write("\n".join(filtered_files[val_size:]))
  48. with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
  49. f.write("\n".join(filtered_files[:val_size]))
  50. logger.info("Done")
  51. if __name__ == "__main__":
  52. main()