Просмотр исходного кода

Fix API server bugs. (#1019)

* Remove unused code.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove rest asr code.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix API Server bugs.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PoTaTo 10 месяцев назад
Родитель
Commit
23a4beb069

+ 1 - 1
tools/llama/quantize.py

@@ -451,7 +451,7 @@ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -
         precision=precision,
         compile=False,
     )
-    vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+    vq_model = "codec.pth"
     now = timestamp if timestamp != "None" else generate_folder_name()
 
     if mode == "int8":

+ 1 - 1
tools/server/api_utils.py

@@ -22,7 +22,7 @@ def parse_args():
     parser.add_argument(
         "--decoder-checkpoint-path",
         type=str,
-        default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+        default="checkpoints/openaudio-s1-mini/codec.pth",
     )
     parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
     parser.add_argument("--device", type=str, default="cuda")

+ 0 - 5
tools/server/model_manager.py

@@ -15,7 +15,6 @@ class ModelManager:
         device: str,
         half: bool,
         compile: bool,
-        asr_enabled: bool,
         llama_checkpoint_path: str,
         decoder_checkpoint_path: str,
         decoder_config_name: str,
@@ -36,10 +35,6 @@ class ModelManager:
             self.device = "cpu"
             logger.info("CUDA is not available, running on CPU.")
 
-        # Load the ASR model if enabled
-        if asr_enabled:
-            self.load_asr_model(self.device)
-
         # Load the TTS models
         self.load_llama_model(
             llama_checkpoint_path, self.device, self.precision, self.compile, self.mode

+ 0 - 1
tools/server/views.py

@@ -34,7 +34,6 @@ from tools.server.api_utils import (
 from tools.server.inference import inference_wrapper as inference
 from tools.server.model_manager import ModelManager
 from tools.server.model_utils import (
-    batch_asr,
     batch_vqgan_decode,
     cached_vqgan_batch_encode,
 )

+ 2 - 2
tools/vqgan/extract_vq.py

@@ -47,7 +47,7 @@ logger.add(sys.stderr, format=logger_format)
 @lru_cache(maxsize=1)
 def get_model(
     config_name: str = "modded_dac_vq",
-    checkpoint_path: str = "checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    checkpoint_path: str = "checkpoints/openaudio-s1-mini/codec.pth",
     device: str | torch.device = "cuda",
 ):
     with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
@@ -138,7 +138,7 @@ def process_batch(files: list[Path], model) -> float:
 @click.option("--config-name", default="modded_dac_vq")
 @click.option(
     "--checkpoint-path",
-    default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    default="checkpoints/openaudio-s1-mini/codec.pth",
 )
 @click.option("--batch-size", default=64)
 @click.option("--filelist", default=None, type=Path)