Преглед изворног кода

Improve file list and vq performance (#131)

* Update extract_vq.py

* Update file.py

* Update extract_vq.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Stardust·减 пре 1 година
родитељ
комит
4b335e218e
2 измењена фајла са 10 додатих и 8 уклоњених редова
  1. 4 3
      fish_speech/utils/file.py
  2. 6 5
      tools/vqgan/extract_vq.py

+ 4 - 3
fish_speech/utils/file.py

@@ -44,10 +44,11 @@ def list_files(
     if not path.exists():
         raise FileNotFoundError(f"Directory {path} does not exist.")
 
-    files = [Path(f) for f in glob(str(path / "**/*"), recursive=recursive)]
+    # files = [Path(f) for f in glob(str(path / "**/*"), recursive=recursive)]
+    files = [file for ext in extensions for file in directory.glob(f"**/*{ext}")]
 
-    if extensions is not None:
-        files = [f for f in files if f.suffix in extensions]
+    # if extensions is not None:
+    #    files = [f for f in files if f.suffix in extensions]
 
     if sort:
         files = natsorted(files)

+ 6 - 5
tools/vqgan/extract_vq.py

@@ -41,7 +41,7 @@ logger.add(sys.stderr, format=logger_format)
 
 @lru_cache(maxsize=1)
 def get_model(
-    config_name: str = "vqgan",
+    config_name: str = "vqgan_pretrain",
     checkpoint_path: str = "checkpoints/vqgan/step_000380000.ckpt",
 ):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
@@ -72,7 +72,9 @@ def process_batch(files: list[Path], model) -> float:
 
     for file in files:
         try:
-            wav, sr = torchaudio.load(file)
+            wav, sr = torchaudio.load(
+                str(file), backend="sox"
+            )  # Need to install libsox-dev
         except Exception as e:
             logger.error(f"Error reading {file}: {e}")
             continue
@@ -169,11 +171,10 @@ def main(
     if filelist:
         files = [i[0] for i in load_filelist(filelist)]
     else:
-        files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=True)
+        files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
 
     print(f"Found {len(files)} files")
-    files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
-    Random(42).shuffle(files)
+    # files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
 
     total_files = len(files)
     files = files[RANK::WORLD_SIZE]