Przeglądaj źródła

Use a ratio to split train vs val set in tools/vqgan/create_train_split.py (#11)

* feat(tools): vqgan/create_train_split: use a ratio to split train vs val set

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feat(tools): vqgan/create_train_split: support using --val_ratio and --val_count to adjust split

* feat(tools): vqgan/create_train_split: support using --val_ratio and --val_count to adjust split

* fix(tools): vqgan/create_train_split: fix wrong option name

* Clean some code

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Leng Yue <lengyue@lengyue.me>
Jesse Cheng 2 lata temu
rodzic
commit
8dac6ec35c
1 zmieniony plik, 15 dodań i 5 usunięć
  1. 15 5
      tools/vqgan/create_train_split.py

+ 15 - 5
tools/vqgan/create_train_split.py

@@ -1,3 +1,4 @@
+import math
 from pathlib import Path
 from random import Random
 
@@ -9,7 +10,9 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 
 @click.command()
 @click.argument("root", type=click.Path(exists=True, path_type=Path))
-def main(root):
+@click.option("--val-ratio", type=float, default=0.2)
+@click.option("--val-count", type=int, default=None)
+def main(root, val_ratio, val_count):
     files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
     print(f"Found {len(files)} files")
 
@@ -17,11 +20,18 @@ def main(root):
 
     Random(42).shuffle(files)
 
-    with open(root / "vq_train_filelist.txt", "w") as f:
-        f.write("\n".join(files[:-100]))
+    if val_count is not None:
+        if val_count < 1 or val_count > len(files):
+            raise ValueError("val_count must be between 1 and number of files")
+        val_size = val_count
+    else:
+        val_size = math.ceil(len(files) * val_ratio)
 
-    with open(root / "vq_val_filelist.txt", "w") as f:
-        f.write("\n".join(files[-100:]))
+    with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
+        f.write("\n".join(files[val_size:]))
+
+    with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
+        f.write("\n".join(files[:val_size]))
 
     print("Done")