
add fastapi for inference (#12)

* fastapi for infer

* fastapi for infer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rm unused code & move server

* Clean up code & better api server

* update api server

* Add http server deps

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lengyue <lengyue@lengyue.me>
spicysama 2 years ago
parent
commit
06a35aef53

+ 3 - 1
fish_speech/models/text2semantic/llama.py

@@ -110,7 +110,9 @@ class Transformer(nn.Module):
         self.max_batch_size = -1
         self.max_seq_len = -1
 
-    def setup_caches(self, max_batch_size, max_seq_len, dtype=torch.bfloat16):
+    def setup_caches(
+        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+    ):
         if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
             return
 

+ 2 - 1
pyproject.toml

@@ -33,7 +33,8 @@ dependencies = [
     "pyopenjtalk",
     "wandb",
     "tensorboard",
-    "grpcio>=1.58.0"
+    "grpcio>=1.58.0",
+    "kui>=1.6.0"
 ]
 
 [build-system]

+ 384 - 0
tools/api_server.py

@@ -0,0 +1,384 @@
+import gc
+import io
+import time
+import traceback
+from http import HTTPStatus
+from typing import Annotated, Any, Literal, Optional
+
+import numpy as np
+import soundfile as sf
+import torch
+import torch.nn.functional as F
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from kui.wsgi import (
+    Body,
+    HTTPException,
+    HttpView,
+    JSONResponse,
+    Kui,
+    OpenAPI,
+    Path,
+    StreamResponse,
+    allow_cors,
+)
+from kui.wsgi.routing import MultimethodRoutes, Router
+from loguru import logger
+from pydantic import BaseModel
+from transformers import AutoTokenizer
+
+import tools.llama.generate
+from tools.llama.generate import encode_tokens, generate, load_model
+
+
+# Define utils for web server
+def http_exception_handler(exc: HTTPException):
+    return JSONResponse(
+        dict(
+            statusCode=exc.status_code,
+            message=exc.content,
+            error=HTTPStatus(exc.status_code).phrase,
+        ),
+        exc.status_code,
+        exc.headers,
+    )
+
+
+def other_exception_handler(exc: "Exception"):
+    traceback.print_exc()
+
+    status = HTTPStatus.INTERNAL_SERVER_ERROR
+    return JSONResponse(
+        dict(statusCode=status, message=str(exc), error=status.phrase),
+        status,
+    )
+
+
+routes = MultimethodRoutes(base_class=HttpView)
+
+# Define models
+MODELS = {}
+
+
+class LlamaModel:
+    def __init__(
+        self,
+        config_name: str,
+        checkpoint_path: str,
+        device,
+        precision: str,
+        tokenizer_path: str,
+        compile: bool,
+    ):
+        self.device = device
+        self.compile = compile
+
+        self.t0 = time.time()
+        self.precision = torch.bfloat16 if precision == "bfloat16" else torch.float16
+        self.model = load_model(config_name, checkpoint_path, device, self.precision)
+        self.model_size = sum(
+            p.numel() for p in self.model.parameters() if p.requires_grad
+        )
+
+        torch.cuda.synchronize()
+        logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")
+
+        if getattr(self, "tokenizer", None) is None:
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+        if self.compile:
+            logger.info("Compiling model ...")
+            tools.llama.generate.decode_one_token = torch.compile(
+                tools.llama.generate.decode_one_token,
+                mode="reduce-overhead",
+                fullgraph=True,
+            )
+
+    def __del__(self):
+        self.model = None
+        self.tokenizer = None
+
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        logger.info("The llama is removed from memory.")
+
+
+class VQGANModel:
+    def __init__(self, config_name: str, checkpoint_path: str):
+        if getattr(self, "cfg", None) is None:
+            with initialize(version_base="1.3", config_path="../fish_speech/configs"):
+                self.cfg = compose(config_name=config_name)
+
+        self.model = instantiate(self.cfg.model)
+        state_dict = torch.load(
+            checkpoint_path,
+            map_location=self.model.device,
+        )
+        if "state_dict" in state_dict:
+            state_dict = state_dict["state_dict"]
+        self.model.load_state_dict(state_dict, strict=True)
+        self.model.eval()
+        self.model.cuda()
+        logger.info("Restored model from checkpoint")
+
+    def __del__(self):
+        self.cfg = None
+        self.model = None
+
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        logger.info("The vqgan model is removed from memory.")
+
+    @torch.no_grad()
+    @torch.autocast(device_type="cuda", enabled=True)
+    def semantic_to_wav(self, indices):
+        model = self.model
+        indices = indices.to(model.device).long()
+        indices = indices.unsqueeze(1).unsqueeze(-1)
+
+        mel_lengths = indices.shape[2] * (
+            model.downsample.total_strides if model.downsample is not None else 1
+        )
+        mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
+        mel_masks = torch.ones(
+            (1, 1, mel_lengths), device=model.device, dtype=torch.float32
+        )
+
+        text_features = model.vq_encoder.decode(indices)
+
+        logger.info(
+            f"VQ Encoded, indices: {indices.shape} equivalent to "
+            + f"{1 / (mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
+        )
+
+        text_features = F.interpolate(
+            text_features, size=mel_lengths[0], mode="nearest"
+        )
+
+        # Sample mels
+        decoded_mels = model.decoder(text_features, mel_masks)
+        fake_audios = model.generator(decoded_mels)
+        logger.info(
+            f"Generated audio of shape {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
+        )
+
+        # Save audio
+        fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
+
+        return fake_audio, model.sampling_rate
+
+
+class LoadLlamaModelRequest(BaseModel):
+    config_name: str = "text2semantic_finetune"
+    checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
+    device: str = "cuda"
+    precision: Literal["float16", "bfloat16"] = "bfloat16"
+    tokenizer: str = "fishaudio/speech-lm-v1"
+    compile: bool = True
+
+
+class LoadVQGANModelRequest(BaseModel):
+    config_name: str = "vqgan_pretrain"
+    checkpoint_path: str = "checkpoints/vqgan-v1.pth"
+
+
+class LoadModelResponse(BaseModel):
+    name: str
+
+
+@routes.http.put("/models/{name}")
+def api_load_model(
+    name: Annotated[str, Path("default")],
+    llama: Annotated[LoadLlamaModelRequest, Body()],
+    vqgan: Annotated[LoadVQGANModelRequest, Body()],
+) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
+    """
+    Load model
+    """
+
+    if name in MODELS:
+        del MODELS[name]
+
+    logger.info("Loading model ...")
+    new_model = {
+        "llama": LlamaModel(
+            config_name=llama.config_name,
+            checkpoint_path=llama.checkpoint_path,
+            device=llama.device,
+            precision=llama.precision,
+            tokenizer_path=llama.tokenizer,
+            compile=llama.compile,
+        ),
+        "vqgan": VQGANModel(
+            config_name=vqgan.config_name,
+            checkpoint_path=vqgan.checkpoint_path,
+        ),
+    }
+
+    MODELS[name] = new_model
+
+    return LoadModelResponse(name=name)
+
+
+@routes.http.delete("/models/{name}")
+def delete_model(
+    name: Annotated[str, Path("default")],
+) -> JSONResponse[200, {}, dict]:
+    """
+    Delete model
+    """
+
+    if name not in MODELS:
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            content="Model not found.",
+        )
+    del MODELS[name]
+    return JSONResponse(
+        dict(message="Model deleted."),
+        200,
+    )
+
+
+@routes.http.get("/models")
+def list_models() -> JSONResponse[200, {}, dict]:
+    """
+    List models
+    """
+
+    return JSONResponse(
+        dict(models=list(MODELS.keys())),
+        200,
+    )
+
+
+class InvokeRequest(BaseModel):
+    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
+    prompt_text: Optional[str] = None
+    prompt_tokens: Optional[str] = None
+    max_new_tokens: int = 0
+    top_k: Optional[int] = None
+    top_p: float = 0.5
+    repetition_penalty: float = 1.5
+    temperature: float = 0.7
+    use_g2p: bool = True
+    seed: Optional[int] = None
+    speaker: Optional[str] = None
+
+
+@routes.http.post("/models/{name}/invoke")
+def invoke_model(
+    name: Annotated[str, Path("default")],
+    req: Annotated[InvokeRequest, Body(exclusive=True)],
+):
+    """
+    Invoke model and generate audio
+    """
+
+    if name not in MODELS:
+        raise HTTPException(
+            status_code=HTTPStatus.NOT_FOUND,
+            content="Cannot find model.",
+        )
+
+    model = MODELS[name]
+    llama_model_manager = model["llama"]
+    vqgan_model_manager = model["vqgan"]
+
+    device = llama_model_manager.device
+    seed = req.seed
+    prompt_tokens = req.prompt_tokens
+    logger.info(f"Device: {device}")
+
+    prompt_tokens = (
+        torch.from_numpy(np.load(prompt_tokens)).to(device)
+        if prompt_tokens is not None
+        else None
+    )
+    encoded = encode_tokens(
+        llama_model_manager.tokenizer,
+        req.text,
+        prompt_text=req.prompt_text,
+        prompt_tokens=prompt_tokens,
+        bos=True,
+        device=device,
+        use_g2p=req.use_g2p,
+        speaker=req.speaker,
+    )
+    prompt_length = encoded.size(1)
+    logger.info(f"Encoded prompt shape: {encoded.shape}")
+
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    torch.cuda.synchronize()
+
+    t0 = time.perf_counter()
+    y = generate(
+        model=llama_model_manager.model,
+        prompt=encoded,
+        max_new_tokens=req.max_new_tokens,
+        eos_token_id=llama_model_manager.tokenizer.eos_token_id,
+        precision=llama_model_manager.precision,
+        temperature=req.temperature,
+        top_k=req.top_k,
+        top_p=req.top_p,
+        repetition_penalty=req.repetition_penalty,
+    )
+
+    torch.cuda.synchronize()
+    t = time.perf_counter() - t0
+
+    tokens_generated = y.size(1) - prompt_length
+    tokens_sec = tokens_generated / t
+    logger.info(
+        f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+    )
+    logger.info(
+        f"Bandwidth achieved: {llama_model_manager.model_size * tokens_sec / 1e9:.02f} GB/s"
+    )
+    logger.info(f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+
+    codes = y[1:, prompt_length:-1]
+    codes = codes - 2
+    assert (codes >= 0).all(), "Codes should be >= 0"
+
+    # --------------- llama end ------------
+    audio, sr = vqgan_model_manager.semantic_to_wav(codes)
+    # --------------- vqgan end ------------
+
+    buffer = io.BytesIO()
+    sf.write(buffer, audio, sr, format="wav")
+
+    return StreamResponse(
+        iterable=[buffer.getvalue()],
+        headers={
+            "Content-Disposition": "attachment; filename=generated.wav",
+            "Content-Type": "audio/wav",
+        },
+    )
+
+
+# Define Kui app
+app = Kui(
+    exception_handlers={
+        HTTPException: http_exception_handler,
+        Exception: other_exception_handler,
+    },
+)
+app.router = Router(
+    [],
+    http_middlewares=[
+        app.exception_middleware,
+        allow_cors(),
+    ],
+)
+
+# Swagger UI & routes
+app.router << ("/v1" // routes)
+app.router << ("/docs" // OpenAPI().routes)
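
A minimal client-side sketch of how these endpoints might be exercised. It assumes the kui WSGI `app` defined above is already being served at http://127.0.0.1:8000 (host, port, and the choice of WSGI server are assumptions, not part of this commit), and that the two Body() parameters of the load route are embedded under their parameter names in the request JSON.

# Hypothetical usage sketch for the endpoints added in tools/api_server.py.
# Assumes the WSGI app is served locally; field values below are illustrative.
import requests

BASE = "http://127.0.0.1:8000/v1"

# Load a llama + vqgan model pair under the name "default".
# Omitted fields fall back to the defaults declared in LoadLlamaModelRequest
# and LoadVQGANModelRequest.
resp = requests.put(
    f"{BASE}/models/default",
    json={"llama": {"compile": False}, "vqgan": {}},
)
resp.raise_for_status()

# Invoke the model; the response body is a WAV stream.
resp = requests.post(
    f"{BASE}/models/default/invoke",
    json={"text": "Hello from the API server.", "temperature": 0.7},
)
resp.raise_for_status()
with open("generated.wav", "wb") as f:
    f.write(resp.content)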

+ 6 - 4
tools/llama/generate.py

@@ -180,10 +180,10 @@ def decode_n_tokens(
             enable_flash=False, enable_mem_efficient=False, enable_math=True
         ):  # Actually better for Inductor to codegen attention here
             next_token = decode_one_token(
-                model,
-                cur_token,
-                input_pos,
-                window,
+                model=model,
+                x=cur_token,
+                input_pos=input_pos,
+                previous_tokens=window,
                 **sampling_kwargs,
             )
 
@@ -434,6 +434,8 @@ def main(
     logger.info(f"Encoded prompt shape: {encoded.shape}")
 
     torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
     if compile:
         global decode_one_token
         decode_one_token = torch.compile(