Jelajahi Sumber

Optimize create train split function

Lengyue 1 tahun lalu
induk
melakukan
7cebbc27ef
1 mengubah file dengan 29 tambahan dan 11 penghapusan
  1. 29 11
      tools/vqgan/create_train_split.py

+ 29 - 11
tools/vqgan/create_train_split.py

@@ -15,25 +15,43 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 @click.option("--val-ratio", type=float, default=None)
 @click.option("--val-count", type=int, default=None)
 @click.option("--filelist", default=None, type=Path)
-@click.option("--min-duration", default=0.2, type=float)
-@click.option("--max-duration", default=30, type=float)
+@click.option("--min-duration", default=None, type=float)
+@click.option("--max-duration", default=None, type=float)
 def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
     if filelist:
         files = [i[0] for i in load_filelist(filelist)]
     else:
         files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
 
-    filtered_files = []
-    for file in tqdm(files):
-        try:
-            audio = AudioSegment.from_file(str(file))
-            duration = len(audio) / 1000.0
-            if min_duration <= duration <= max_duration:
+    if min_duration is None and max_duration is None:
+        filtered_files = files
+    else:
+        filtered_files = []
+        for file in tqdm(files):
+            try:
+                audio = AudioSegment.from_file(str(file))
+                duration = len(audio) / 1000.0
+
+                if min_duration is not None and duration < min_duration:
+                    logger.info(
+                        f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
+                    )
+                    continue
+
+                if max_duration is not None and duration > max_duration:
+                    logger.info(
+                        f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
+                    )
+                    continue
+
                 filtered_files.append(str(file.relative_to(root)))
-        except Exception as e:
-            logger.info(f"Error processing {file}: {e}")
+            except Exception as e:
+                logger.info(f"Error processing {file}: {e}")
+
+    logger.info(
+        f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
+    )
 
-    logger.info(f"Found {len(files)} files | Got Filtered {len(filtered_files)} files")
     Random(42).shuffle(filtered_files)
 
     if val_count is None and val_ratio is None: