Explorar el Código

Optimize bert-vits2 parsing

Lengyue hace 2 años
padre
commit
4f02d630f8

+ 50 - 0
fish_speech/utils/file.py

@@ -2,6 +2,8 @@ import os
 from pathlib import Path
 from typing import Union
 
+from loguru import logger
+
 AUDIO_EXTENSIONS = {
     ".mp3",
     ".wav",
@@ -72,3 +74,51 @@ def get_latest_checkpoint(path: Path | str) -> Path | None:
         return None
 
     return ckpts[-1]
+
+
+def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
+    """
+    Load a Bert-VITS2 style filelist.
+    """
+
+    files = set()
+    results = []
+    count_duplicated, count_not_found = 0, 0
+
+    LANGUAGE_TO_LANGUAGES = {
+        "zh": ["zh", "en"],
+        "jp": ["jp", "en"],
+        "en": ["en"],
+    }
+
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f.readlines():
+            filename, speaker, language, text = line.strip().split("|")
+            file = Path(filename)
+            language = language.strip().lower()
+
+            if language == "ja":
+                language = "jp"
+
+            assert language in ["zh", "jp", "en"], f"Invalid language {language}"
+            languages = LANGUAGE_TO_LANGUAGES[language]
+
+            if file in files:
+                logger.warning(f"Duplicated file: {file}")
+                count_duplicated += 1
+                continue
+
+            if not file.exists():
+                logger.warning(f"File not found: {file}")
+                count_not_found += 1
+                continue
+
+            results.append((file, speaker, languages, text))
+
+    if count_duplicated > 0:
+        logger.warning(f"Total duplicated files: {count_duplicated}")
+
+    if count_not_found > 0:
+        logger.warning(f"Total files not found: {count_not_found}")
+
+    return results

+ 32 - 29
tools/llama/build_dataset.py

@@ -13,10 +13,10 @@ from tqdm import tqdm
 from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
 from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
 from fish_speech.text import g2p
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 
-def task_generator(config, filelist):
+def task_generator_yaml(config):
     with open(config, "r") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
 
@@ -30,26 +30,7 @@ def task_generator(config, filelist):
         )
 
         # Load the files
-        if filelist:
-            with open(filelist, "r", encoding="utf-8") as f:
-                # files = [Path(line..strip().split("|")[0]) for line in f]
-                files = set()
-                countSame = 0
-                countNotFound = 0
-                for line in f.readlines():
-                    file = Path(line.strip().split("|")[0])
-                    if file in files:
-                        print(f"重复音频文本:{line}")
-                        countSame += 1
-                        continue
-                    if not os.path.isfile(file):
-                        # 过滤数据集错误:不存在对应音频
-                        print(f"没有找到对应的音频:{file}")
-                        countNotFound += 1
-                        continue
-                    files.add(file)
-        else:
-            files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
+        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
 
         grouped_files = defaultdict(list)
         for file in files:
@@ -63,22 +44,44 @@ def task_generator(config, filelist):
 
         logger.info(f"Found {len(grouped_files)} groups in {root}")
         for name, subset in grouped_files.items():
-            yield name, subset, source, languages, extension
+            yield name, subset, source, languages, extension, None
+
+
+def task_generator_filelist(filelist):
+    grouped_files = defaultdict(list)
+    for filename, speaker, languages, text in load_filelist(filelist):
+        if speaker in grouped_files:
+            assert (
+                languages == grouped_files[speaker][0][2]
+            ), f"Speaker {speaker} has different languages"
+
+        grouped_files[speaker].append((Path(filename), text, languages))
+
+    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
+    for speaker, (filename, txt, languages) in grouped_files.items():
+        yield speaker, filename, "filelist", languages, None, txt
 
 
 def run_task(task):
-    name, subset, source, languages, extension = task
+    name, subset, source, languages, extension, text = task
 
     # Parse the files
     sentences = []
     for file in subset:
         np_file = file.with_suffix(".npy")
-        txt_file = file.with_suffix(extension)
-        if np_file.exists() is False or txt_file.exists() is False:
-            logger.warning(f"Can't find {np_file} or {txt_file}")
+        if np_file.exists() is False:
+            logger.warning(f"Can't find {np_file}")
             continue
-        with open(txt_file, "r") as f:
-            text = f.read().strip()
+
+        if text is None:
+            txt_file = file.with_suffix(extension)
+
+            if txt_file.exists() is False:
+                logger.warning(f"Can't find {txt_file}")
+                continue
+
+            with open(txt_file, "r") as f:
+                text = f.read().strip()
 
         # Simple cleaning: replace { xxx } and < xxx > with space
         text = re.sub(r"\{.*?\}", " ", text)

+ 3 - 20
tools/vqgan/create_train_split.py

@@ -6,7 +6,7 @@ from random import Random
 import click
 from tqdm import tqdm
 
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 
 @click.command()
@@ -16,28 +16,11 @@ from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
 @click.option("--filelist", default=None, type=Path)
 def main(root, val_ratio, val_count, filelist):
     if filelist:
-        with open(filelist, "r", encoding="utf-8") as f:
-            # files = [Path(line..strip().split("|")[0]) for line in f]
-            files = set()
-            countSame = 0
-            countNotFound = 0
-            for line in f.readlines():
-                file = Path(line.strip().split("|")[0])
-                if file in files:
-                    print(f"重复音频文本:{line}")
-                    countSame += 1
-                    continue
-                if not os.path.isfile(file):
-                    # 过滤数据集错误:不存在对应音频
-                    print(f"没有找到对应的音频:{file}")
-                    countNotFound += 1
-                    continue
-                files.add(file)
-        files = list(files)
+        files = [i[0] for i in load_filelist(filelist)]
     else:
         files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
-    print(f"Found {len(files)} files")
 
+    print(f"Found {len(files)} files")
     files = [str(file.relative_to(root)) for file in tqdm(files)]
 
     Random(42).shuffle(files)

+ 3 - 20
tools/vqgan/extract_vq.py

@@ -19,7 +19,7 @@ from loguru import logger
 from omegaconf import OmegaConf
 
 from fish_speech.models.vqgan.utils import sequence_mask
-from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files
+from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
 
 # register eval resolver
 OmegaConf.register_new_resolver("eval", eval)
@@ -188,27 +188,10 @@ def main(
     # This is a worker
     logger.info(f"Starting worker")
     if filelist:
-        with open(filelist, "r", encoding="utf-8") as f:
-            # files = [Path(line..strip().split("|")[0]) for line in f]
-            files = set()
-            countSame = 0
-            countNotFound = 0
-            for line in f.readlines():
-                file = Path(line.strip().split("|")[0])
-                if file in files:
-                    print(f"重复音频文本:{line}")
-                    countSame += 1
-                    continue
-                if not os.path.isfile(file):
-                    # 过滤数据集错误:不存在对应音频
-                    print(f"没有找到对应的音频:{file}")
-                    countNotFound += 1
-                    continue
-                files.add(file)
-        files = list(files)
-        print(f"总重复音频数:{countSame},总未找到的音频数:{countNotFound}")
+        files = [i[0] for i in load_filelist(filelist)]
     else:
         files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
+
     Random(42).shuffle(files)
 
     total_files = len(files)