Browse Source

Optimize validation set logic

Lengyue 2 năm trước cách đây
mục cha
commit
198ccfa4a0
1 tập tin đã thay đổi với 15 bổ sung6 xóa
  1. 15 6
      tools/vqgan/create_train_split.py

+ 15 - 6
tools/vqgan/create_train_split.py

@@ -1,9 +1,9 @@
 import math
-import os
 from pathlib import Path
 from random import Random
 
 import click
+from loguru import logger
 from tqdm import tqdm
 
 from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
@@ -11,7 +11,7 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 @click.command()
 @click.argument("root", type=click.Path(exists=True, path_type=Path))
-@click.option("--val-ratio", type=float, default=0.2)
+@click.option("--val-ratio", type=float, default=None)
 @click.option("--val-count", type=int, default=None)
 @click.option("--filelist", default=None, type=Path)
 def main(root, val_ratio, val_count, filelist):
@@ -20,25 +20,34 @@ def main(root, val_ratio, val_count, filelist):
     else:
         files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
 
-    print(f"Found {len(files)} files")
+    logger.info(f"Found {len(files)} files")
     files = [str(file.relative_to(root)) for file in tqdm(files)]
 
     Random(42).shuffle(files)
 
-    if val_count is not None:
+    if val_count is None and val_ratio is None:
+        logger.info("Validation ratio and count not specified, using max(20%, 100)")
+        val_size = max(1, math.ceil(len(files) * 0.2))
+    elif val_count is not None and val_ratio is not None:
+        logger.error("Cannot specify both val_count and val_ratio")
+        return
+    elif val_count is not None:
         if val_count < 1 or val_count > len(files):
-            raise ValueError("val_count must be between 1 and number of files")
+            logger.error("val_count must be between 1 and number of files")
+            return
         val_size = val_count
     else:
         val_size = math.ceil(len(files) * val_ratio)
 
+    logger.info(f"Using {val_size} files for validation")
+
     with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
         f.write("\n".join(files[val_size:]))
 
     with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
         f.write("\n".join(files[:val_size]))
 
-    print("Done")
+    logger.info("Done")
 
 
 if __name__ == "__main__":