|
@@ -1,9 +1,9 @@
|
|
|
import math
|
|
import math
|
|
|
-import os
|
|
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from random import Random
|
|
from random import Random
|
|
|
|
|
|
|
|
import click
|
|
import click
|
|
|
|
|
+from loguru import logger
|
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
|
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
|
@@ -11,7 +11,7 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
|
|
|
|
|
|
|
@click.command()
|
|
@click.command()
|
|
|
@click.argument("root", type=click.Path(exists=True, path_type=Path))
|
|
@click.argument("root", type=click.Path(exists=True, path_type=Path))
|
|
|
-@click.option("--val-ratio", type=float, default=0.2)
|
|
|
|
|
|
|
+@click.option("--val-ratio", type=float, default=None)
|
|
|
@click.option("--val-count", type=int, default=None)
|
|
@click.option("--val-count", type=int, default=None)
|
|
|
@click.option("--filelist", default=None, type=Path)
|
|
@click.option("--filelist", default=None, type=Path)
|
|
|
def main(root, val_ratio, val_count, filelist):
|
|
def main(root, val_ratio, val_count, filelist):
|
|
@@ -20,25 +20,34 @@ def main(root, val_ratio, val_count, filelist):
|
|
|
else:
|
|
else:
|
|
|
files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
|
|
files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
|
|
|
|
|
|
|
|
- print(f"Found {len(files)} files")
|
|
|
|
|
|
|
+ logger.info(f"Found {len(files)} files")
|
|
|
files = [str(file.relative_to(root)) for file in tqdm(files)]
|
|
files = [str(file.relative_to(root)) for file in tqdm(files)]
|
|
|
|
|
|
|
|
Random(42).shuffle(files)
|
|
Random(42).shuffle(files)
|
|
|
|
|
|
|
|
- if val_count is not None:
|
|
|
|
|
|
|
+ if val_count is None and val_ratio is None:
|
|
|
|
|
+ logger.info("Validation ratio and count not specified, using max(20%, 100)")
|
|
|
|
|
+ val_size = max(1, math.ceil(len(files) * 0.2))
|
|
|
|
|
+ elif val_count is not None and val_ratio is not None:
|
|
|
|
|
+ logger.error("Cannot specify both val_count and val_ratio")
|
|
|
|
|
+ return
|
|
|
|
|
+ elif val_count is not None:
|
|
|
if val_count < 1 or val_count > len(files):
|
|
if val_count < 1 or val_count > len(files):
|
|
|
- raise ValueError("val_count must be between 1 and number of files")
|
|
|
|
|
|
|
+ logger.error("val_count must be between 1 and number of files")
|
|
|
|
|
+ return
|
|
|
val_size = val_count
|
|
val_size = val_count
|
|
|
else:
|
|
else:
|
|
|
val_size = math.ceil(len(files) * val_ratio)
|
|
val_size = math.ceil(len(files) * val_ratio)
|
|
|
|
|
|
|
|
|
|
+ logger.info(f"Using {val_size} files for validation")
|
|
|
|
|
+
|
|
|
with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
|
|
with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
|
|
|
f.write("\n".join(files[val_size:]))
|
|
f.write("\n".join(files[val_size:]))
|
|
|
|
|
|
|
|
with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
|
|
with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
|
|
|
f.write("\n".join(files[:val_size]))
|
|
f.write("\n".join(files[:val_size]))
|
|
|
|
|
|
|
|
- print("Done")
|
|
|
|
|
|
|
+ logger.info("Done")
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|