|
|
@@ -3,7 +3,8 @@ import io
|
|
|
import time
|
|
|
import traceback
|
|
|
from http import HTTPStatus
|
|
|
-from typing import Annotated, Any, Literal, Optional
|
|
|
+from threading import Lock
|
|
|
+from typing import Annotated, Literal, Optional
|
|
|
|
|
|
import numpy as np
|
|
|
import soundfile as sf
|
|
|
@@ -82,9 +83,7 @@ class LlamaModel:
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")
|
|
|
-
|
|
|
- if self.tokenizer is None:
|
|
|
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
|
|
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
|
|
|
|
|
if self.compile:
|
|
|
logger.info("Compiling model ...")
|
|
|
@@ -106,10 +105,9 @@ class LlamaModel:
|
|
|
|
|
|
|
|
|
class VQGANModel:
|
|
|
- def __init__(self, config_name: str, checkpoint_path: str):
|
|
|
- if self.cfg is None:
|
|
|
- with initialize(version_base="1.3", config_path="../fish_speech/configs"):
|
|
|
- self.cfg = compose(config_name=config_name)
|
|
|
+ def __init__(self, config_name: str, checkpoint_path: str, device: str):
|
|
|
+ with initialize(version_base="1.3", config_path="../fish_speech/configs"):
|
|
|
+ self.cfg = compose(config_name=config_name)
|
|
|
|
|
|
self.model = instantiate(self.cfg.model)
|
|
|
state_dict = torch.load(
|
|
|
@@ -120,8 +118,9 @@ class VQGANModel:
|
|
|
state_dict = state_dict["state_dict"]
|
|
|
self.model.load_state_dict(state_dict, strict=True)
|
|
|
self.model.eval()
|
|
|
- self.model.cuda()
|
|
|
- logger.info("Restored model from checkpoint")
|
|
|
+ self.model.to(device)
|
|
|
+
|
|
|
+ logger.info("Restored VQGAN model from checkpoint")
|
|
|
|
|
|
def __del__(self):
|
|
|
self.cfg = None
|
|
|
@@ -175,7 +174,6 @@ class VQGANModel:
|
|
|
class LoadLlamaModelRequest(BaseModel):
|
|
|
config_name: str = "text2semantic_finetune"
|
|
|
checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
|
|
|
- device: str = "cuda"
|
|
|
precision: Literal["float16", "bfloat16"] = "bfloat16"
|
|
|
tokenizer: str = "fishaudio/speech-lm-v1"
|
|
|
compile: bool = True
|
|
|
@@ -186,15 +184,20 @@ class LoadVQGANModelRequest(BaseModel):
|
|
|
checkpoint_path: str = "checkpoints/vqgan-v1.pth"
|
|
|
|
|
|
|
|
|
+class LoadModelRequest(BaseModel):
|
|
|
+ device: str = "cuda"
|
|
|
+ llama: LoadLlamaModelRequest
|
|
|
+ vqgan: LoadVQGANModelRequest
|
|
|
+
|
|
|
+
|
|
|
class LoadModelResponse(BaseModel):
|
|
|
name: str
|
|
|
|
|
|
|
|
|
@routes.http.put("/models/{name}")
|
|
|
-def load_model(
|
|
|
+def api_load_model(
|
|
|
name: Annotated[str, Path("default")],
|
|
|
- llama: Annotated[LoadLlamaModelRequest, Body()],
|
|
|
- vqgan: Annotated[LoadVQGANModelRequest, Body()],
|
|
|
+ req: Annotated[LoadModelRequest, Body(exclusive=True)],
|
|
|
) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
|
|
|
"""
|
|
|
Load model
|
|
|
@@ -203,12 +206,15 @@ def load_model(
|
|
|
if name in MODELS:
|
|
|
del MODELS[name]
|
|
|
|
|
|
+ llama = req.llama
|
|
|
+ vqgan = req.vqgan
|
|
|
+
|
|
|
logger.info("Loading model ...")
|
|
|
new_model = {
|
|
|
"llama": LlamaModel(
|
|
|
config_name=llama.config_name,
|
|
|
checkpoint_path=llama.checkpoint_path,
|
|
|
- device=llama.device,
|
|
|
+ device=req.device,
|
|
|
precision=llama.precision,
|
|
|
tokenizer_path=llama.tokenizer,
|
|
|
compile=llama.compile,
|
|
|
@@ -216,7 +222,9 @@ def load_model(
|
|
|
"vqgan": VQGANModel(
|
|
|
config_name=vqgan.config_name,
|
|
|
checkpoint_path=vqgan.checkpoint_path,
|
|
|
+ device=req.device,
|
|
|
),
|
|
|
+ "lock": Lock(),
|
|
|
}
|
|
|
|
|
|
MODELS[name] = new_model
|
|
|
@@ -225,7 +233,7 @@ def load_model(
|
|
|
|
|
|
|
|
|
@routes.http.delete("/models/{name}")
|
|
|
-def delete_model(
|
|
|
+def api_delete_model(
|
|
|
name: Annotated[str, Path("default")],
|
|
|
) -> JSONResponse[200, {}, dict]:
|
|
|
"""
|
|
|
@@ -238,6 +246,8 @@ def delete_model(
|
|
|
content="Model not found.",
|
|
|
)
|
|
|
|
|
|
+ del MODELS[name]
|
|
|
+
|
|
|
return JSONResponse(
|
|
|
dict(message="Model deleted."),
|
|
|
200,
|
|
|
@@ -245,7 +255,7 @@ def delete_model(
|
|
|
|
|
|
|
|
|
@routes.http.get("/models")
|
|
|
-def list_models() -> JSONResponse[200, {}, dict]:
|
|
|
+def api_list_models() -> JSONResponse[200, {}, dict]:
|
|
|
"""
|
|
|
List models
|
|
|
"""
|
|
|
@@ -271,7 +281,7 @@ class InvokeRequest(BaseModel):
|
|
|
|
|
|
|
|
|
@routes.http.post("/models/{name}/invoke")
|
|
|
-def invoke_model(
|
|
|
+def api_invoke_model(
|
|
|
name: Annotated[str, Path("default")],
|
|
|
req: Annotated[InvokeRequest, Body(exclusive=True)],
|
|
|
):
|
|
|
@@ -289,6 +299,9 @@ def invoke_model(
|
|
|
llama_model_manager = model["llama"]
|
|
|
vqgan_model_manager = model["vqgan"]
|
|
|
|
|
|
+ # Lock
|
|
|
+ model["lock"].acquire()
|
|
|
+
|
|
|
device = llama_model_manager.device
|
|
|
seed = req.seed
|
|
|
prompt_tokens = req.prompt_tokens
|
|
|
@@ -348,6 +361,9 @@ def invoke_model(
|
|
|
codes = codes - 2
|
|
|
assert (codes >= 0).all(), "Codes should be >= 0"
|
|
|
|
|
|
+ # Release lock
|
|
|
+ model["lock"].release()
|
|
|
+
|
|
|
# --------------- llama end ------------
|
|
|
audio, sr = vqgan_model_manager.sematic_to_wav(codes)
|
|
|
# --------------- vqgan end ------------
|
|
|
@@ -358,8 +374,8 @@ def invoke_model(
|
|
|
return StreamResponse(
|
|
|
iterable=[buffer.getvalue()],
|
|
|
headers={
|
|
|
- "Content-Disposition": "attachment; filename=generated.wav",
|
|
|
- "Content-Type": "audio/wav",
|
|
|
+ "Content-Disposition": "attachment; filename=audio.wav",
|
|
|
+ "Content-Type": "application/octet-stream",
|
|
|
},
|
|
|
)
|
|
|
|