
Update api server & dockerfile

Lengyue 1 year ago
Parent
Commit
dd403f4d98
5 changed files with 273 additions and 428 deletions
  1. dockerfile (+4, -22)
  2. docs/en/inference.md (+1, -1)
  3. docs/zh/inference.md (+2, -2)
  4. tools/api.py (+266, -0)
  5. tools/api_server.py (+0, -403)

+ 4 - 22
dockerfile

@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/pytorch:24.02-py3
+FROM python:3.10.14-bookworm
 
 # Install system dependencies
 ENV DEBIAN_FRONTEND=noninteractive
@@ -14,29 +14,11 @@ RUN chsh -s /usr/bin/zsh
 ENV SHELL=/usr/bin/zsh
 
 # Setup torchaudio
-RUN git clone https://github.com/pytorch/audio --recursive --depth 1 && \
-    cd audio && pip install -v --no-use-pep517 . && \
-    cd .. && rm -rf audio && python -c "import torchaudio; print(torchaudio.__version__)"
-
-# Setup flash-attn
-RUN pip3 install --upgrade pip && \
-    pip3 install ninja packaging && \
-    FLASH_ATTENTION_FORCE_BUILD=TRUE pip3 install git+https://github.com/Dao-AILab/flash-attention.git
-
-# Test flash-attn
-RUN python3 -c "from flash_attn import flash_attn_varlen_func"
+RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
 
 # Project Env
 WORKDIR /exp
-COPY pyproject.toml ./
-COPY data_server ./data_server
-COPY fish_speech ./fish_speech
-
-# Setup rust-data-server
-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
-    cd data_server && $HOME/.cargo/bin/cargo build --release && cp target/release/data_server /usr/local/bin/ && \
-    cd .. && rm -rf data_server && data_server --help
-
-RUN pip3 install -e . && pip uninstall -y fish-speech && rm -rf fish_speech
+COPY . .
+RUN pip3 install -e .
 
 CMD /bin/zsh
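
Since the new image drops the NVIDIA PyTorch base for plain `python:3.10.14-bookworm` and installs CUDA-enabled wheels from the cu121 index, it is worth sanity-checking the GPU stack inside the container. A minimal sketch, assuming the container is started with GPU access (e.g. `docker run --gpus all ...`):

```python
# Sanity check for the cu121 wheels installed by the new dockerfile.
# Assumes the container was started with GPU access (e.g. --gpus all);
# on a CPU-only host this still runs, but CUDA will report unavailable.
import torch
import torchaudio

print("torch:", torch.__version__)
print("torchaudio:", torchaudio.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```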

+ 1 - 1
docs/en/inference.md

@@ -67,7 +67,7 @@ python tools/vqgan/inference.py \
 We provide a HTTP API for inference. You can use the following command to start the server:
 
 ```bash
-python -m zibai tools.api_server:app --listen 127.0.0.1:8000
+python -m tools.api --listen 0.0.0.0:8000
 ```
 
 After that, you can view and test the API at http://127.0.0.1:8000/docs.  
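
The Swagger UI is the quickest way to explore the API, but the `/v1/invoke` endpoint added in `tools/api.py` (see its diff below) can also be called directly. A minimal client sketch, assuming the server is running locally on the default port:

```python
# Minimal client for the new HTTP API; assumes a server started with
# `python -m tools.api --listen 0.0.0.0:8000` is reachable on localhost.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/invoke",
    json={"text": "Hello, world!", "format": "wav"},
)
resp.raise_for_status()

# The endpoint streams the encoded audio back as an attachment.
with open("output.wav", "wb") as f:
    f.write(resp.content)
```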

+ 2 - 2
docs/zh/inference.md

@@ -72,9 +72,9 @@ python tools/vqgan/inference.py \
 运行以下命令来启动 HTTP 服务:
 
 ```bash
-python -m zibai tools.api_server:app --listen 127.0.0.1:8000
+python -m tools.api --listen 0.0.0.0:8000
 # 推荐中国大陆用户运行以下命令来启动 HTTP 服务:
-HF_ENDPOINT=https://hf-mirror.com python -m zibai tools.api_server:app --listen 127.0.0.1:8000
+HF_ENDPOINT=https://hf-mirror.com python -m tools.api --listen 0.0.0.0:8000
 ```
 
 随后, 你可以在 `http://127.0.0.1:8000/docs` 中查看并测试 API.  

+ 266 - 0
tools/api.py

@@ -0,0 +1,266 @@
+import base64
+import io
+import traceback
+from argparse import ArgumentParser
+from http import HTTPStatus
+from threading import Lock
+from typing import Annotated, Literal, Optional
+
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+from kui.wsgi import (
+    Body,
+    HTTPException,
+    HttpView,
+    JSONResponse,
+    Kui,
+    OpenAPI,
+    StreamResponse,
+    allow_cors,
+)
+from kui.wsgi.routing import MultimethodRoutes
+from loguru import logger
+from pydantic import BaseModel
+from transformers import AutoTokenizer
+
+from tools.llama.generate import generate_long
+from tools.llama.generate import load_model as load_llama_model
+from tools.vqgan.inference import load_model as load_vqgan_model
+from tools.webui import inference
+
+lock = Lock()
+
+
+# Define utils for web server
+def http_execption_handler(exc: HTTPException):
+    return JSONResponse(
+        dict(
+            statusCode=exc.status_code,
+            message=exc.content,
+            error=HTTPStatus(exc.status_code).phrase,
+        ),
+        exc.status_code,
+        exc.headers,
+    )
+
+
+def other_exception_handler(exc: "Exception"):
+    traceback.print_exc()
+
+    status = HTTPStatus.INTERNAL_SERVER_ERROR
+    return JSONResponse(
+        dict(statusCode=status, message=str(exc), error=status.phrase),
+        status,
+    )
+
+
+routes = MultimethodRoutes(base_class=HttpView)
+
+
+class InvokeRequest(BaseModel):
+    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
+    reference_text: Optional[str] = None
+    reference_audio: Optional[str] = None
+    max_new_tokens: int = 0
+    chunk_length: int = 30
+    top_k: int = 0
+    top_p: float = 0.7
+    repetition_penalty: float = 1.5
+    temperature: float = 0.7
+    speaker: Optional[str] = None
+    format: Literal["wav", "mp3", "flac"] = "wav"
+
+
+def inference(req: InvokeRequest):
+    # Parse reference audio aka prompt
+    prompt_tokens = None
+    if req.reference_audio is not None:
+        buffer = io.BytesIO(base64.b64decode(req.reference_audio))
+        reference_audio_content, _ = librosa.load(
+            buffer, sr=vqgan_model.sampling_rate, mono=True
+        )
+        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
+            None, None, :
+        ]
+
+        logger.info(
+            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
+        )
+
+        # VQ Encoder
+        audio_lengths = torch.tensor(
+            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
+        )
+        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
+
+    # LLAMA Inference
+    result = generate_long(
+        model=llama_model,
+        tokenizer=llama_tokenizer,
+        device=vqgan_model.device,
+        decode_one_token=decode_one_token,
+        max_new_tokens=req.max_new_tokens,
+        text=req.text,
+        top_k=int(req.top_k) if req.top_k > 0 else None,
+        top_p=req.top_p,
+        repetition_penalty=req.repetition_penalty,
+        temperature=req.temperature,
+        compile=args.compile,
+        iterative_prompt=req.chunk_length > 0,
+        chunk_length=req.chunk_length,
+        max_length=args.max_length,
+        speaker=req.speaker,
+        prompt_tokens=prompt_tokens,
+        prompt_text=req.reference_text,
+    )
+
+    codes = next(result)
+
+    # VQGAN Inference
+    feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
+    fake_audios = vqgan_model.decode(
+        indices=codes[None], feature_lengths=feature_lengths, return_audios=True
+    )[0, 0]
+
+    fake_audios = fake_audios.float().cpu().numpy()
+
+    return fake_audios
+
+
+@routes.http.post("/invoke")
+def api_invoke_model(
+    req: Annotated[InvokeRequest, Body(exclusive=True)],
+):
+    """
+    Invoke model and generate audio
+    """
+
+    if args.max_gradio_length > 0 and len(req.text) > args.max_gradio_length:
+        raise HTTPException(
+            HTTPStatus.BAD_REQUEST,
+            f"Text is too long, max length is {args.max_gradio_length}",
+        )
+
+    try:
+        # Lock, avoid interrupting the inference process
+        lock.acquire()
+        fake_audios = inference(req)
+    except Exception as e:
+        raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
+    finally:
+        # Release lock
+        lock.release()
+
+    buffer = io.BytesIO()
+    sf.write(buffer, fake_audios, vqgan_model.sampling_rate, format=req.format)
+
+    return StreamResponse(
+        iterable=[buffer.getvalue()],
+        headers={
+            "Content-Disposition": f"attachment; filename=audio.{req.format}",
+            "Content-Type": "application/octet-stream",
+        },
+    )
+
+
+@routes.http.post("/health")
+def api_health():
+    """
+    Health check
+    """
+
+    return JSONResponse({"status": "ok"})
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--llama-checkpoint-path",
+        type=str,
+        default="checkpoints/text2semantic-medium-v1-2k.pth",
+    )
+    parser.add_argument(
+        "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
+    )
+    parser.add_argument(
+        "--vqgan-checkpoint-path",
+        type=str,
+        default="checkpoints/vq-gan-group-fsq-2x1024.pth",
+    )
+    parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
+    parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--half", action="store_true")
+    parser.add_argument("--max-length", type=int, default=2048)
+    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--max-gradio-length", type=int, default=0)
+    parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
+
+    return parser.parse_args()
+
+
+# Define Kui app
+app = Kui(
+    exception_handlers={
+        HTTPException: http_execption_handler,
+        Exception: other_exception_handler,
+    },
+    cors_config={},
+)
+
+# Swagger UI & routes
+app.router << ("/v1" // routes) << ("/docs" // OpenAPI().routes)
+args = parse_args()
+
+
+if __name__ == "__main__":
+    from zibai import Options, main
+
+    options = Options(
+        app="tools.api:app",
+        listen=[args.listen],
+    )
+    main(options)
+else:
+    args.precision = torch.half if args.half else torch.bfloat16
+
+    logger.info("Loading Llama model...")
+    llama_model, decode_one_token = load_llama_model(
+        config_name=args.llama_config_name,
+        checkpoint_path=args.llama_checkpoint_path,
+        device=args.device,
+        precision=args.precision,
+        max_length=args.max_length,
+        compile=args.compile,
+    )
+    llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
+    logger.info("Llama model loaded, loading VQ-GAN model...")
+
+    vqgan_model = load_vqgan_model(
+        config_name=args.vqgan_config_name,
+        checkpoint_path=args.vqgan_checkpoint_path,
+        device=args.device,
+    )
+
+    logger.info("VQ-GAN model loaded, warming up...")
+
+    # Dry run to check if the model is loaded correctly and avoid the first-time latency
+    inference(
+        InvokeRequest(
+            text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+            reference_text=None,
+            reference_audio=None,
+            max_new_tokens=0,
+            chunk_length=30,
+            top_k=0,
+            top_p=0.7,
+            repetition_penalty=1.5,
+            temperature=0.7,
+            speaker=None,
+            format="wav",
+        )
+    )
+
+    logger.info("Warming up done.")
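
Unlike the old server, which accepted `.npy`/`.wav` paths for prompt tokens, the new `InvokeRequest` expects `reference_audio` as a base64-encoded string, so the client encodes the prompt audio itself. A hedged sketch of a voice-cloning request against the schema above (file names are illustrative):

```python
# Sketch of a /v1/invoke call with base64-encoded reference audio, matching
# the InvokeRequest schema above. File names here are illustrative only.
import base64

import requests

with open("reference.wav", "rb") as f:
    reference_audio = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "text": "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
    "reference_text": "Transcript of the reference audio goes here.",
    "reference_audio": reference_audio,
    "format": "wav",
}

resp = requests.post("http://127.0.0.1:8000/v1/invoke", json=payload)
resp.raise_for_status()

with open("cloned.wav", "wb") as f:
    f.write(resp.content)
```

Note that the server serializes requests behind a global `Lock`, so concurrent calls queue for the GPU rather than interleave.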

+ 0 - 403
tools/api_server.py

@@ -1,403 +0,0 @@
-import gc
-import io
-import time
-import traceback
-from http import HTTPStatus
-from threading import Lock
-from typing import Annotated, Literal, Optional
-
-import librosa
-import numpy as np
-import soundfile as sf
-import torch
-import torch.nn.functional as F
-from hydra import compose, initialize
-from hydra.utils import instantiate
-from kui.wsgi import (
-    Body,
-    HTTPException,
-    HttpView,
-    JSONResponse,
-    Kui,
-    OpenAPI,
-    Path,
-    StreamResponse,
-    allow_cors,
-)
-from kui.wsgi.routing import MultimethodRoutes, Router
-from loguru import logger
-from pydantic import BaseModel
-from transformers import AutoTokenizer
-
-import tools.llama.generate
-from fish_speech.models.vqgan.utils import sequence_mask
-from tools.llama.generate import encode_tokens, generate, load_model
-
-
-# Define utils for web server
-def http_execption_handler(exc: HTTPException):
-    return JSONResponse(
-        dict(
-            statusCode=exc.status_code,
-            message=exc.content,
-            error=HTTPStatus(exc.status_code).phrase,
-        ),
-        exc.status_code,
-        exc.headers,
-    )
-
-
-def other_exception_handler(exc: "Exception"):
-    traceback.print_exc()
-
-    status = HTTPStatus.INTERNAL_SERVER_ERROR
-    return JSONResponse(
-        dict(statusCode=status, message=str(exc), error=status.phrase),
-        status,
-    )
-
-
-routes = MultimethodRoutes(base_class=HttpView)
-
-# Define models
-MODELS = {}
-
-
-class LlamaModel:
-    def __init__(
-        self,
-        config_name: str,
-        checkpoint_path: str,
-        device,
-        precision: str,
-        tokenizer_path: str,
-        compile: bool,
-    ):
-        self.device = device
-        self.compile = compile
-
-        self.t0 = time.time()
-        self.precision = torch.bfloat16 if precision == "bfloat16" else torch.float16
-        self.model = load_model(config_name, checkpoint_path, device, self.precision)
-        self.model_size = sum(
-            p.numel() for p in self.model.parameters() if p.requires_grad
-        )
-
-        torch.cuda.synchronize()
-        logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
-
-        if self.compile:
-            logger.info("Compiling model ...")
-            tools.llama.generate.decode_one_token = torch.compile(
-                tools.llama.generate.decode_one_token,
-                mode="reduce-overhead",
-                fullgraph=True,
-            )
-
-    def __del__(self):
-        self.model = None
-        self.tokenizer = None
-
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        logger.info("The llama is removed from memory.")
-
-
-class VQGANModel:
-    def __init__(self, config_name: str, checkpoint_path: str, device: str):
-        with initialize(version_base="1.3", config_path="../fish_speech/configs"):
-            self.cfg = compose(config_name=config_name)
-
-        self.model = instantiate(self.cfg.model)
-        state_dict = torch.load(
-            checkpoint_path,
-            map_location=self.model.device,
-        )
-        if "state_dict" in state_dict:
-            state_dict = state_dict["state_dict"]
-        self.model.load_state_dict(state_dict, strict=True)
-        self.model.eval()
-        self.model.to(device)
-
-        logger.info("Restored VQGAN model from checkpoint")
-
-    def __del__(self):
-        self.cfg = None
-        self.model = None
-
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        logger.info("The vqgan model is removed from memory.")
-
-    @torch.no_grad()
-    def sematic_to_wav(self, indices):
-        model = self.model
-        indices = indices.to(model.device).long()
-        feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
-        decoded = model.decode(indices=indices[None], feature_lengths=feature_lengths)
-
-        # Save audio
-        fake_audio = decoded.audios[0, 0].cpu().numpy().astype(np.float32)
-
-        return fake_audio, model.sampling_rate
-
-    @torch.no_grad()
-    def wav_to_semantic(self, audio):
-        model = self.model
-        # Load audio
-        audio, _ = librosa.load(
-            audio,
-            sr=model.sampling_rate,
-            mono=True,
-        )
-        audios = torch.from_numpy(audio).to(model.device)[None, None, :]
-        logger.info(
-            f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
-        )
-
-        # VQ Encoder
-        audio_lengths = torch.tensor(
-            [audios.shape[2]], device=model.device, dtype=torch.long
-        )
-        encoded = model.encode(audios, audio_lengths)
-        indices = encoded.indices[0]
-
-        logger.info(f"Generated indices of shape {indices.shape}")
-        return indices
-
-
-class LoadLlamaModelRequest(BaseModel):
-    config_name: str = "text2semantic_finetune"
-    checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
-    precision: Literal["float16", "bfloat16"] = "bfloat16"
-    tokenizer: str = "fishaudio/speech-lm-v1"
-    compile: bool = True
-
-
-class LoadVQGANModelRequest(BaseModel):
-    config_name: str = "vqgan_pretrain"
-    checkpoint_path: str = "checkpoints/vqgan-v1.pth"
-
-
-class LoadModelRequest(BaseModel):
-    device: str = "cuda"
-    llama: LoadLlamaModelRequest
-    vqgan: LoadVQGANModelRequest
-
-
-class LoadModelResponse(BaseModel):
-    name: str
-
-
-@routes.http.put("/models/{name}")
-def api_load_model(
-    name: Annotated[str, Path("default")],
-    req: Annotated[LoadModelRequest, Body(exclusive=True)],
-) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
-    """
-    Load model
-    """
-
-    if name in MODELS:
-        del MODELS[name]
-
-    llama = req.llama
-    vqgan = req.vqgan
-
-    logger.info("Loading model ...")
-    new_model = {
-        "llama": LlamaModel(
-            config_name=llama.config_name,
-            checkpoint_path=llama.checkpoint_path,
-            device=req.device,
-            precision=llama.precision,
-            tokenizer_path=llama.tokenizer,
-            compile=llama.compile,
-        ),
-        "vqgan": VQGANModel(
-            config_name=vqgan.config_name,
-            checkpoint_path=vqgan.checkpoint_path,
-            device=req.device,
-        ),
-        "lock": Lock(),
-    }
-
-    MODELS[name] = new_model
-
-    return LoadModelResponse(name=name)
-
-
-@routes.http.delete("/models/{name}")
-def api_delete_model(
-    name: Annotated[str, Path("default")],
-) -> JSONResponse[200, {}, dict]:
-    """
-    Delete model
-    """
-
-    if name not in MODELS:
-        raise HTTPException(
-            status_code=HTTPStatus.BAD_REQUEST,
-            content="Model not found.",
-        )
-
-    del MODELS[name]
-
-    return JSONResponse(
-        dict(message="Model deleted."),
-        200,
-    )
-
-
-@routes.http.get("/models")
-def api_list_models() -> JSONResponse[200, {}, dict]:
-    """
-    List models
-    """
-
-    return JSONResponse(
-        dict(models=list(MODELS.keys())),
-        200,
-    )
-
-
-class InvokeRequest(BaseModel):
-    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
-    prompt_text: Optional[str] = None
-    prompt_tokens: Optional[str] = None
-    max_new_tokens: int = 0
-    top_k: Optional[int] = None
-    top_p: float = 0.5
-    repetition_penalty: float = 1.5
-    temperature: float = 0.7
-    order: str = "zh,jp,en"
-    use_g2p: bool = True
-    seed: Optional[int] = None
-    speaker: Optional[str] = None
-
-
-@routes.http.post("/models/{name}/invoke")
-def api_invoke_model(
-    name: Annotated[str, Path("default")],
-    req: Annotated[InvokeRequest, Body(exclusive=True)],
-):
-    """
-    Invoke model and generate audio
-    """
-
-    if name not in MODELS:
-        raise HTTPException(
-            status_code=HTTPStatus.NOT_FOUND,
-            content="Cannot find model.",
-        )
-
-    model = MODELS[name]
-    llama_model_manager = model["llama"]
-    vqgan_model_manager = model["vqgan"]
-
-    device = llama_model_manager.device
-    seed = req.seed
-    prompt_tokens = req.prompt_tokens
-    logger.info(f"Device: {device}")
-
-    if prompt_tokens is not None and prompt_tokens.endswith(".npy"):
-        prompt_tokens = torch.from_numpy(np.load(prompt_tokens)).to(device)
-    elif prompt_tokens is not None and prompt_tokens.endswith(".wav"):
-        prompt_tokens = vqgan_model_manager.wav_to_semantic(prompt_tokens)
-    elif prompt_tokens is not None:
-        logger.error(f"Unknown prompt tokens: {prompt_tokens}")
-        raise HTTPException(
-            status_code=HTTPStatus.BAD_REQUEST,
-            content="Unknown prompt tokens, it should be either .npy or .wav file.",
-        )
-    else:
-        prompt_tokens = None
-
-    # Lock
-    model["lock"].acquire()
-
-    encoded = encode_tokens(
-        llama_model_manager.tokenizer,
-        req.text,
-        prompt_text=req.prompt_text,
-        prompt_tokens=prompt_tokens,
-        bos=True,
-        device=device,
-        use_g2p=req.use_g2p,
-        speaker=req.speaker,
-        order=req.order,
-    )
-    prompt_length = encoded.size(1)
-    logger.info(f"Encoded prompt shape: {encoded.shape}")
-
-    if seed is not None:
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
-
-    torch.cuda.synchronize()
-
-    t0 = time.perf_counter()
-    y = generate(
-        model=llama_model_manager.model,
-        prompt=encoded,
-        max_new_tokens=req.max_new_tokens,
-        eos_token_id=llama_model_manager.tokenizer.eos_token_id,
-        precision=llama_model_manager.precision,
-        temperature=req.temperature,
-        top_k=req.top_k,
-        top_p=req.top_p,
-        repetition_penalty=req.repetition_penalty,
-    )
-
-    torch.cuda.synchronize()
-    t = time.perf_counter() - t0
-
-    tokens_generated = y.size(1) - prompt_length
-    tokens_sec = tokens_generated / t
-    logger.info(
-        f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
-    )
-    logger.info(
-        f"Bandwidth achieved: {llama_model_manager.model_size * tokens_sec / 1e9:.02f} GB/s"
-    )
-    logger.info(f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-
-    codes = y[1:, prompt_length:-1]
-    codes = codes - 2
-    assert (codes >= 0).all(), "Codes should be >= 0"
-
-    # Release lock
-    model["lock"].release()
-
-    # --------------- llama end ------------
-    audio, sr = vqgan_model_manager.sematic_to_wav(codes)
-    # --------------- vqgan end ------------
-
-    buffer = io.BytesIO()
-    sf.write(buffer, audio, sr, format="wav")
-
-    return StreamResponse(
-        iterable=[buffer.getvalue()],
-        headers={
-            "Content-Disposition": "attachment; filename=audio.wav",
-            "Content-Type": "application/octet-stream",
-        },
-    )
-
-
-# Define Kui app
-app = Kui(
-    exception_handlers={
-        HTTPException: http_execption_handler,
-        Exception: other_exception_handler,
-    },
-    cors_config={},
-)
-
-# Swagger UI & routes
-app.router << ("/v1" // routes) << ("/docs" // OpenAPI().routes)
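
For callers migrating from the removed `tools/api_server.py`: the model-management endpoints (`PUT`/`DELETE /v1/models/{name}`, `GET /v1/models`) are gone, since the new server loads a single Llama + VQ-GAN pair at startup from CLI flags. A rough before/after sketch of the invoke path (the old payload fields are abbreviated):

```python
# Migration sketch: old two-step flow vs. the new single endpoint.
import requests

BASE = "http://127.0.0.1:8000/v1"

# Old flow (tools/api_server.py, removed): load a named model, then invoke it.
#   requests.put(f"{BASE}/models/default", json={"device": "cuda", ...})
#   requests.post(f"{BASE}/models/default/invoke", json={"text": "..."})

# New flow (tools/api.py): models are loaded at startup; one call suffices.
resp = requests.post(f"{BASE}/invoke", json={"text": "Hello"})
resp.raise_for_status()
```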