|
|
@@ -1,3 +1,4 @@
|
|
|
+import math
|
|
|
from pathlib import Path
|
|
|
from random import Random
|
|
|
|
|
|
@@ -9,7 +10,9 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
|
|
|
|
|
|
@click.command()
|
|
|
@click.argument("root", type=click.Path(exists=True, path_type=Path))
|
|
|
-def main(root):
|
|
|
+@click.option("--val-ratio", type=float, default=0.2)
|
|
|
+@click.option("--val-count", type=int, default=None)
|
|
|
+def main(root, val_ratio, val_count):
|
|
|
files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
|
|
|
print(f"Found {len(files)} files")
|
|
|
|
|
|
@@ -17,11 +20,18 @@ def main(root):
|
|
|
|
|
|
Random(42).shuffle(files)
|
|
|
|
|
|
- with open(root / "vq_train_filelist.txt", "w") as f:
|
|
|
- f.write("\n".join(files[:-100]))
|
|
|
+ if val_count is not None:
|
|
|
+ if val_count < 1 or val_count > len(files):
|
|
|
+ raise ValueError("val_count must be between 1 and number of files")
|
|
|
+ val_size = val_count
|
|
|
+ else:
|
|
|
+ val_size = math.ceil(len(files) * val_ratio)
|
|
|
|
|
|
- with open(root / "vq_val_filelist.txt", "w") as f:
|
|
|
- f.write("\n".join(files[-100:]))
|
|
|
+ with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
|
|
|
+ f.write("\n".join(files[val_size:]))
|
|
|
+
|
|
|
+ with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
|
|
|
+ f.write("\n".join(files[:val_size]))
|
|
|
|
|
|
print("Done")
|
|
|
|