Selaa lähdekoodia

Support VITS Filelist Input (#18)

* add vits filelist support

* add vits filelist support

* Update create_train_split.py

* Update create_train_split.py

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* fix list not subscriptable

* fix list not subscriptable

* fix path lib

* Add files via upload

* Add files via upload

* fix parent

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add files via upload

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add files via upload

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add files via upload

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Stardust·减 2 vuotta sitten
vanhempi
commit
42e1442ffd
3 muutettua tiedostoa jossa 78 lisäystä ja 10 poistoa
  1. 30 7
      tools/llama/build_dataset.py
  2. 24 2
      tools/vqgan/create_train_split.py
  3. 24 1
      tools/vqgan/extract_vq.py

+ 30 - 7
tools/llama/build_dataset.py

@@ -1,6 +1,8 @@
+import os
 import re
 from collections import defaultdict
 from multiprocessing import Pool
+from pathlib import Path
 
 import click
 import numpy as np
@@ -14,7 +16,7 @@ from fish_speech.text import g2p
 from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 
 
-def task_generator(config):
+def task_generator(config, filelist):
     with open(config, "r") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
 
@@ -28,7 +30,26 @@ def task_generator(config):
         )
 
         # Load the files
-        files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
+        if filelist:
+            with open(filelist, "r", encoding="utf-8") as f:
+                # files = [Path(line.strip().split("|")[0]) for line in f]
+                files = set()
+                countSame = 0
+                countNotFound = 0
+                for line in f.readlines():
+                    file = Path(line.strip().split("|")[0])
+                    if file in files:
+                        print(f"重复音频文本:{line}")
+                        countSame += 1
+                        continue
+                    if not os.path.isfile(file):
+                        # 过滤数据集错误:不存在对应音频
+                        print(f"没有找到对应的音频:{file}")
+                        countNotFound += 1
+                        continue
+                    files.add(file)
+        else:
+            files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
 
         grouped_files = defaultdict(list)
         for file in files:
@@ -38,7 +59,6 @@ def task_generator(config):
                 p = file.parent.parent.name
             else:
                 raise ValueError(f"Invalid parent level {parent_level}")
-
             grouped_files[p].append(file)
 
         logger.info(f"Found {len(grouped_files)} groups in {root}")
@@ -57,7 +77,6 @@ def run_task(task):
         if np_file.exists() is False or txt_file.exists() is False:
             logger.warning(f"Can't find {np_file} or {txt_file}")
             continue
-
         with open(txt_file, "r") as f:
             text = f.read().strip()
 
@@ -100,10 +119,14 @@ def run_task(task):
     "--config", type=click.Path(), default="fish_speech/configs/data/finetune.yaml"
 )
 @click.option("--output", type=click.Path(), default="data/quantized-dataset-ft.protos")
-def main(config, output):
+@click.option("--filelist", type=click.Path(), default=None)
+@click.option("--num_worker", type=int, default=16)
+def main(config, output, filelist, num_worker):
     dataset_fp = open(output, "wb")
-    with Pool(16) as p:
-        for result in tqdm(p.imap_unordered(run_task, task_generator(config))):
+    with Pool(num_worker) as p:
+        for result in tqdm(
+            p.imap_unordered(run_task, task_generator(config, filelist))
+        ):
             dataset_fp.write(result)
 
     dataset_fp.close()

+ 24 - 2
tools/vqgan/create_train_split.py

@@ -1,4 +1,5 @@
 import math
+import os
 from pathlib import Path
 from random import Random
 
@@ -12,8 +13,29 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 @click.argument("root", type=click.Path(exists=True, path_type=Path))
 @click.option("--val-ratio", type=float, default=0.2)
 @click.option("--val-count", type=int, default=None)
-def main(root, val_ratio, val_count):
-    files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
+@click.option("--filelist", default=None, type=Path)
+def main(root, val_ratio, val_count, filelist):
+    if filelist:
+        with open(filelist, "r", encoding="utf-8") as f:
+            # files = [Path(line.strip().split("|")[0]) for line in f]
+            files = set()
+            countSame = 0
+            countNotFound = 0
+            for line in f.readlines():
+                file = Path(line.strip().split("|")[0])
+                if file in files:
+                    print(f"重复音频文本:{line}")
+                    countSame += 1
+                    continue
+                if not os.path.isfile(file):
+                    # 过滤数据集错误:不存在对应音频
+                    print(f"没有找到对应的音频:{file}")
+                    countNotFound += 1
+                    continue
+                files.add(file)
+        files = list(files)
+    else:
+        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
     print(f"Found {len(files)} files")
 
     files = [str(file.relative_to(root)) for file in tqdm(files)]

+ 24 - 1
tools/vqgan/extract_vq.py

@@ -145,12 +145,14 @@ def process_batch(files: list[Path], model) -> float:
     default="checkpoints/vqgan-v1.pth",
 )
 @click.option("--batch-size", default=64)
+@click.option("--filelist", default=None, type=Path)
 def main(
     folder: str,
     num_workers: int,
     config_name: str,
     checkpoint_path: str,
     batch_size: int,
+    filelist: Path,
 ):
     if num_workers > 1 and WORLD_SIZE != num_workers:
         assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
@@ -185,7 +187,28 @@ def main(
 
     # This is a worker
     logger.info(f"Starting worker")
-    files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
+    if filelist:
+        with open(filelist, "r", encoding="utf-8") as f:
+            # files = [Path(line.strip().split("|")[0]) for line in f]
+            files = set()
+            countSame = 0
+            countNotFound = 0
+            for line in f.readlines():
+                file = Path(line.strip().split("|")[0])
+                if file in files:
+                    print(f"重复音频文本:{line}")
+                    countSame += 1
+                    continue
+                if not os.path.isfile(file):
+                    # 过滤数据集错误:不存在对应音频
+                    print(f"没有找到对应的音频:{file}")
+                    countNotFound += 1
+                    continue
+                files.add(file)
+        files = list(files)
+        print(f"总重复音频数:{countSame},总未找到的音频数:{countNotFound}")
+    else:
+        files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
     Random(42).shuffle(files)
 
     total_files = len(files)