create_train_split.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import math
  2. from pathlib import Path
  3. from random import Random
  4. import click
  5. from loguru import logger
  6. from pydub import AudioSegment
  7. from tqdm import tqdm
  8. from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
  9. @click.command()
  10. @click.argument("root", type=click.Path(exists=True, path_type=Path))
  11. @click.option("--val-ratio", type=float, default=None)
  12. @click.option("--val-count", type=int, default=None)
  13. @click.option("--filelist", default=None, type=Path)
  14. @click.option("--min-duration", default=None, type=float)
  15. @click.option("--max-duration", default=None, type=float)
  16. def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
  17. if filelist:
  18. files = [i[0] for i in load_filelist(filelist)]
  19. else:
  20. files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
  21. if min_duration is None and max_duration is None:
  22. filtered_files = list(map(str, [file.relative_to(root) for file in files]))
  23. else:
  24. filtered_files = []
  25. for file in tqdm(files):
  26. try:
  27. audio = AudioSegment.from_file(str(file))
  28. duration = len(audio) / 1000.0
  29. if min_duration is not None and duration < min_duration:
  30. logger.info(
  31. f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
  32. )
  33. continue
  34. if max_duration is not None and duration > max_duration:
  35. logger.info(
  36. f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
  37. )
  38. continue
  39. filtered_files.append(str(file.relative_to(root)))
  40. except Exception as e:
  41. logger.info(f"Error processing {file}: {e}")
  42. logger.info(
  43. f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
  44. )
  45. Random(42).shuffle(filtered_files)
  46. if val_count is None and val_ratio is None:
  47. logger.info("Validation ratio and count not specified, using min(20%, 100)")
  48. val_size = min(100, math.ceil(len(filtered_files) * 0.2))
  49. elif val_count is not None and val_ratio is not None:
  50. logger.error("Cannot specify both val_count and val_ratio")
  51. return
  52. elif val_count is not None:
  53. if val_count < 1 or val_count > len(filtered_files):
  54. logger.error("val_count must be between 1 and number of files")
  55. return
  56. val_size = val_count
  57. else:
  58. val_size = math.ceil(len(filtered_files) * val_ratio)
  59. logger.info(f"Using {val_size} files for validation")
  60. with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
  61. f.write("\n".join(filtered_files[val_size:]))
  62. with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
  63. f.write("\n".join(filtered_files[:val_size]))
  64. logger.info("Done")
  65. if __name__ == "__main__":
  66. main()