Parcourir la source

Optimize api server

Lengyue il y a 2 ans
Parent
commit
95d90c8452
2 fichiers modifiés avec 38 ajouts et 21 suppressions
  1. 2 1
      pyproject.toml
  2. 36 20
      tools/api_server.py

+ 2 - 1
pyproject.toml

@@ -34,7 +34,8 @@ dependencies = [
     "wandb",
     "tensorboard",
     "grpcio>=1.58.0",
-    "kui>=1.6.0"
+    "kui>=1.6.0",
+    "zibai-server>=0.9.0"
 ]
 
 [build-system]

+ 36 - 20
tools/api_server.py

@@ -3,7 +3,8 @@ import io
 import time
 import traceback
 from http import HTTPStatus
-from typing import Annotated, Any, Literal, Optional
+from threading import Lock
+from typing import Annotated, Literal, Optional
 
 import numpy as np
 import soundfile as sf
@@ -82,9 +83,7 @@ class LlamaModel:
 
         torch.cuda.synchronize()
         logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")
-
-        if self.tokenizer is None:
-            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
 
         if self.compile:
             logger.info("Compiling model ...")
@@ -106,10 +105,9 @@ class LlamaModel:
 
 
 class VQGANModel:
-    def __init__(self, config_name: str, checkpoint_path: str):
-        if self.cfg is None:
-            with initialize(version_base="1.3", config_path="../fish_speech/configs"):
-                self.cfg = compose(config_name=config_name)
+    def __init__(self, config_name: str, checkpoint_path: str, device: str):
+        with initialize(version_base="1.3", config_path="../fish_speech/configs"):
+            self.cfg = compose(config_name=config_name)
 
         self.model = instantiate(self.cfg.model)
         state_dict = torch.load(
@@ -120,8 +118,9 @@ class VQGANModel:
             state_dict = state_dict["state_dict"]
         self.model.load_state_dict(state_dict, strict=True)
         self.model.eval()
-        self.model.cuda()
-        logger.info("Restored model from checkpoint")
+        self.model.to(device)
+
+        logger.info("Restored VQGAN model from checkpoint")
 
     def __del__(self):
         self.cfg = None
@@ -175,7 +174,6 @@ class VQGANModel:
 class LoadLlamaModelRequest(BaseModel):
     config_name: str = "text2semantic_finetune"
     checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
-    device: str = "cuda"
     precision: Literal["float16", "bfloat16"] = "bfloat16"
     tokenizer: str = "fishaudio/speech-lm-v1"
     compile: bool = True
@@ -186,15 +184,20 @@ class LoadVQGANModelRequest(BaseModel):
     checkpoint_path: str = "checkpoints/vqgan-v1.pth"
 
 
+class LoadModelRequest(BaseModel):
+    device: str = "cuda"
+    llama: LoadLlamaModelRequest
+    vqgan: LoadVQGANModelRequest
+
+
 class LoadModelResponse(BaseModel):
     name: str
 
 
 @routes.http.put("/models/{name}")
-def load_model(
+def api_load_model(
     name: Annotated[str, Path("default")],
-    llama: Annotated[LoadLlamaModelRequest, Body()],
-    vqgan: Annotated[LoadVQGANModelRequest, Body()],
+    req: Annotated[LoadModelRequest, Body(exclusive=True)],
 ) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
     """
     Load model
@@ -203,12 +206,15 @@ def load_model(
     if name in MODELS:
         del MODELS[name]
 
+    llama = req.llama
+    vqgan = req.vqgan
+
     logger.info("Loading model ...")
     new_model = {
         "llama": LlamaModel(
             config_name=llama.config_name,
             checkpoint_path=llama.checkpoint_path,
-            device=llama.device,
+            device=req.device,
             precision=llama.precision,
             tokenizer_path=llama.tokenizer,
             compile=llama.compile,
@@ -216,7 +222,9 @@ def load_model(
         "vqgan": VQGANModel(
             config_name=vqgan.config_name,
             checkpoint_path=vqgan.checkpoint_path,
+            device=req.device,
         ),
+        "lock": Lock(),
     }
 
     MODELS[name] = new_model
@@ -225,7 +233,7 @@ def load_model(
 
 
 @routes.http.delete("/models/{name}")
-def delete_model(
+def api_delete_model(
     name: Annotated[str, Path("default")],
 ) -> JSONResponse[200, {}, dict]:
     """
@@ -238,6 +246,8 @@ def delete_model(
             content="Model not found.",
         )
 
+    del MODELS[name]
+
     return JSONResponse(
         dict(message="Model deleted."),
         200,
@@ -245,7 +255,7 @@ def delete_model(
 
 
 @routes.http.get("/models")
-def list_models() -> JSONResponse[200, {}, dict]:
+def api_list_models() -> JSONResponse[200, {}, dict]:
     """
     List models
     """
@@ -271,7 +281,7 @@ class InvokeRequest(BaseModel):
 
 
 @routes.http.post("/models/{name}/invoke")
-def invoke_model(
+def api_invoke_model(
     name: Annotated[str, Path("default")],
     req: Annotated[InvokeRequest, Body(exclusive=True)],
 ):
@@ -289,6 +299,9 @@ def invoke_model(
     llama_model_manager = model["llama"]
     vqgan_model_manager = model["vqgan"]
 
+    # Lock
+    model["lock"].acquire()
+
     device = llama_model_manager.device
     seed = req.seed
     prompt_tokens = req.prompt_tokens
@@ -348,6 +361,9 @@ def invoke_model(
     codes = codes - 2
     assert (codes >= 0).all(), "Codes should be >= 0"
 
+    # Release lock
+    model["lock"].release()
+
     # --------------- llama end ------------
     audio, sr = vqgan_model_manager.sematic_to_wav(codes)
     # --------------- vqgan end ------------
@@ -358,8 +374,8 @@ def invoke_model(
     return StreamResponse(
         iterable=[buffer.getvalue()],
         headers={
-            "Content-Disposition": "attachment; filename=generated.wav",
-            "Content-Type": "audio/wav",
+            "Content-Disposition": "attachment; filename=audio.wav",
+            "Content-Type": "application/octet-stream",
         },
     )