@@ -0,0 +1,384 @@
+import gc
+import io
+import time
+import traceback
+from http import HTTPStatus
+from typing import Annotated, Literal, Optional
+
+import numpy as np
+import soundfile as sf
+import torch
+import torch.nn.functional as F
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from kui.wsgi import (
+    Body,
+    HTTPException,
+    HttpView,
+    JSONResponse,
+    Kui,
+    OpenAPI,
+    Path,
+    StreamResponse,
+    allow_cors,
+)
+from kui.wsgi.routing import MultimethodRoutes, Router
+from loguru import logger
+from pydantic import BaseModel
+from transformers import AutoTokenizer
+
+import tools.llama.generate
+from tools.llama.generate import encode_tokens, generate, load_model
+
+
+# Define utils for the web server
+def http_exception_handler(exc: HTTPException):
+    return JSONResponse(
+        dict(
+            statusCode=exc.status_code,
+            message=exc.content,
+            error=HTTPStatus(exc.status_code).phrase,
+        ),
+        exc.status_code,
+        exc.headers,
+    )
+
+
+def other_exception_handler(exc: Exception):
+    traceback.print_exc()
+
+    status = HTTPStatus.INTERNAL_SERVER_ERROR
+    return JSONResponse(
+        dict(statusCode=status, message=str(exc), error=status.phrase),
+        status,
+    )
+
+
+routes = MultimethodRoutes(base_class=HttpView)
+
+# Define models
+MODELS = {}
+
+
+class LlamaModel:
+    def __init__(
+        self,
+        config_name: str,
+        checkpoint_path: str,
+        device,
+        precision: str,
+        tokenizer_path: str,
+        compile: bool,
+    ):
+        self.device = device
+        self.compile = compile
+
+        self.t0 = time.time()
+        self.precision = torch.bfloat16 if precision == "bfloat16" else torch.float16
+        self.model = load_model(config_name, checkpoint_path, device, self.precision)
+        self.model_size = sum(
+            p.numel() for p in self.model.parameters() if p.requires_grad
+        )
+
+        torch.cuda.synchronize()
+        logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")
+
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+
+        if self.compile:
+            logger.info("Compiling model ...")
+            tools.llama.generate.decode_one_token = torch.compile(
+                tools.llama.generate.decode_one_token,
+                mode="reduce-overhead",
+                fullgraph=True,
+            )
+
+    def __del__(self):
+        self.model = None
+        self.tokenizer = None
+
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        logger.info("The Llama model has been removed from memory.")
+
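+# Note: the torch.compile call above patches the module-level
+# tools.llama.generate.decode_one_token, so the compiled function is shared
+# by (and affects) every LlamaModel instance, not only the one being built.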
+
+class VQGANModel:
+    def __init__(self, config_name: str, checkpoint_path: str):
+        with initialize(version_base="1.3", config_path="../fish_speech/configs"):
+            self.cfg = compose(config_name=config_name)
+
+        self.model = instantiate(self.cfg.model)
+        state_dict = torch.load(
+            checkpoint_path,
+            map_location=self.model.device,
+        )
+        if "state_dict" in state_dict:
+            state_dict = state_dict["state_dict"]
+        self.model.load_state_dict(state_dict, strict=True)
+        self.model.eval()
+        self.model.cuda()
+        logger.info("Restored model from checkpoint")
+
+    def __del__(self):
+        self.cfg = None
+        self.model = None
+
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        logger.info("The VQGAN model has been removed from memory.")
+
+    @torch.no_grad()
+    @torch.autocast(device_type="cuda", enabled=True)
+    def semantic_to_wav(self, indices):
+        model = self.model
+        # indices: (num_codebooks, T) -> (num_codebooks, 1, T, 1)
+        indices = indices.to(model.device).long()
+        indices = indices.unsqueeze(1).unsqueeze(-1)
+
+        # Keep the mel length as a plain int so it can be used in shape tuples
+        mel_length = indices.shape[2] * (
+            model.downsample.total_strides if model.downsample is not None else 1
+        )
+        mel_lengths = torch.tensor([mel_length], device=model.device, dtype=torch.long)
+        mel_masks = torch.ones(
+            (1, 1, mel_length), device=model.device, dtype=torch.float32
+        )
+
+        text_features = model.vq_encoder.decode(indices)
+
+        logger.info(
+            f"VQ encoded, indices: {indices.shape} equivalent to "
+            + f"{1 / (mel_length * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
+        )
+
+        text_features = F.interpolate(
+            text_features, size=mel_length, mode="nearest"
+        )
+
+        # Sample mels
+        decoded_mels = model.decoder(text_features, mel_masks)
+        fake_audios = model.generator(decoded_mels)
+        logger.info(
+            f"Generated audio of shape {fake_audios.shape}, equivalent to "
+            + f"{fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
+        )
+
+        # Convert to a float32 numpy array for saving
+        fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
+
+        return fake_audio, model.sampling_rate
+
+
+class LoadLlamaModelRequest(BaseModel):
+    config_name: str = "text2semantic_finetune"
+    checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
+    device: str = "cuda"
+    precision: Literal["float16", "bfloat16"] = "bfloat16"
+    tokenizer: str = "fishaudio/speech-lm-v1"
+    compile: bool = True
+
+
+class LoadVQGANModelRequest(BaseModel):
+    config_name: str = "vqgan_pretrain"
+    checkpoint_path: str = "checkpoints/vqgan-v1.pth"
+
+
+class LoadModelResponse(BaseModel):
+    name: str
+
+
+# Named api_load_model so it does not shadow tools.llama.generate.load_model,
+# which LlamaModel.__init__ resolves at call time
+@routes.http.put("/models/{name}")
+def api_load_model(
+    name: Annotated[str, Path("default")],
+    llama: Annotated[LoadLlamaModelRequest, Body()],
+    vqgan: Annotated[LoadVQGANModelRequest, Body()],
+) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
+    """
+    Load model
+    """
+
+    if name in MODELS:
+        del MODELS[name]
+
+    logger.info("Loading model ...")
+    new_model = {
+        "llama": LlamaModel(
+            config_name=llama.config_name,
+            checkpoint_path=llama.checkpoint_path,
+            device=llama.device,
+            precision=llama.precision,
+            tokenizer_path=llama.tokenizer,
+            compile=llama.compile,
+        ),
+        "vqgan": VQGANModel(
+            config_name=vqgan.config_name,
+            checkpoint_path=vqgan.checkpoint_path,
+        ),
+    }
+
+    MODELS[name] = new_model
+
+    return LoadModelResponse(name=name)
+
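+# A hypothetical usage sketch (host/port and the exact body layout for the
+# two Body(...) parameters are assumptions; routes are mounted under /v1 at
+# the bottom of this file, and all request fields have defaults):
+#
+#   curl -X PUT http://127.0.0.1:8000/v1/models/default \
+#       -H "Content-Type: application/json" \
+#       -d '{"llama": {"compile": false}, "vqgan": {}}'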
+
+@routes.http.delete("/models/{name}")
+def delete_model(
+    name: Annotated[str, Path("default")],
+) -> JSONResponse[200, {}, dict]:
+    """
+    Delete model
+    """
+
+    if name not in MODELS:
+        raise HTTPException(
+            status_code=HTTPStatus.NOT_FOUND,
+            content="Model not found.",
+        )
+
+    # Actually drop the model so __del__ can release GPU memory
+    del MODELS[name]
+
+    return JSONResponse(
+        dict(message="Model deleted."),
+        200,
+    )
+
+
+@routes.http.get("/models")
+def list_models() -> JSONResponse[200, {}, dict]:
+    """
+    List models
+    """
+
+    return JSONResponse(
+        dict(models=list(MODELS.keys())),
+        200,
+    )
+
+
+class InvokeRequest(BaseModel):
+    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
+    prompt_text: Optional[str] = None
+    prompt_tokens: Optional[str] = None
+    max_new_tokens: int = 0
+    top_k: Optional[int] = None
+    top_p: float = 0.5
+    repetition_penalty: float = 1.5
+    temperature: float = 0.7
+    use_g2p: bool = True
+    seed: Optional[int] = None
+    speaker: Optional[str] = None
+
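+# Note: prompt_tokens above is a server-side path to an .npy file of
+# pre-encoded prompt tokens; invoke_model loads it with np.load below.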
+
+@routes.http.post("/models/{name}/invoke")
+def invoke_model(
+    name: Annotated[str, Path("default")],
+    req: Annotated[InvokeRequest, Body(exclusive=True)],
+):
+    """
+    Invoke model and generate audio
+    """
+
+    if name not in MODELS:
+        raise HTTPException(
+            status_code=HTTPStatus.NOT_FOUND,
+            content="Cannot find model.",
+        )
+
+    model = MODELS[name]
+    llama_model_manager = model["llama"]
+    vqgan_model_manager = model["vqgan"]
+
+    device = llama_model_manager.device
+    seed = req.seed
+    prompt_tokens = req.prompt_tokens
+    logger.info(f"Device: {device}")
+
+    prompt_tokens = (
+        torch.from_numpy(np.load(prompt_tokens)).to(device)
+        if prompt_tokens is not None
+        else None
+    )
+    encoded = encode_tokens(
+        llama_model_manager.tokenizer,
+        req.text,
+        prompt_text=req.prompt_text,
+        prompt_tokens=prompt_tokens,
+        bos=True,
+        device=device,
+        use_g2p=req.use_g2p,
+        speaker=req.speaker,
+    )
+    prompt_length = encoded.size(1)
+    logger.info(f"Encoded prompt shape: {encoded.shape}")
+
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    torch.cuda.synchronize()
+
+    t0 = time.perf_counter()
+    y = generate(
+        model=llama_model_manager.model,
+        prompt=encoded,
+        max_new_tokens=req.max_new_tokens,
+        eos_token_id=llama_model_manager.tokenizer.eos_token_id,
+        precision=llama_model_manager.precision,
+        temperature=req.temperature,
+        top_k=req.top_k,
+        top_p=req.top_p,
+        repetition_penalty=req.repetition_penalty,
+    )
+
+    torch.cuda.synchronize()
+    t = time.perf_counter() - t0
+
+    tokens_generated = y.size(1) - prompt_length
+    tokens_sec = tokens_generated / t
+    logger.info(
+        f"Generated {tokens_generated} tokens in {t:.02f} seconds, "
+        + f"{tokens_sec:.02f} tokens/sec"
+    )
+    logger.info(
+        f"Bandwidth achieved: {llama_model_manager.model_size * tokens_sec / 1e9:.02f} GB/s"
+    )
+    logger.info(f"GPU memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+
+    # Extract the generated codebook indices (skipping the prompt and the
+    # trailing token) and undo the +2 token-id offset
+    codes = y[1:, prompt_length:-1]
+    codes = codes - 2
+    assert (codes >= 0).all(), "Codes should be >= 0"
+
+    # --------------- llama end ------------
+    audio, sr = vqgan_model_manager.semantic_to_wav(codes)
+    # --------------- vqgan end ------------
+
+    buffer = io.BytesIO()
+    sf.write(buffer, audio, sr, format="wav")
+
+    return StreamResponse(
+        iterable=[buffer.getvalue()],
+        headers={
+            "Content-Disposition": "attachment; filename=generated.wav",
+            "Content-Type": "audio/wav",
+        },
+    )
+
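+# A hypothetical usage sketch (host/port are assumptions; the model must be
+# loaded first via PUT /v1/models/default):
+#
+#   curl -X POST http://127.0.0.1:8000/v1/models/default/invoke \
+#       -H "Content-Type: application/json" \
+#       -d '{"text": "Hello world", "temperature": 0.7, "seed": 42}' \
+#       -o generated.wav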
+
+# Define Kui app
+app = Kui(
+    exception_handlers={
+        HTTPException: http_exception_handler,
+        Exception: other_exception_handler,
+    },
+)
+app.router = Router(
+    [],
+    http_middlewares=[
+        app.exception_middleware,
+        allow_cors(),
+    ],
+)
+
+# Swagger UI & routes
+app.router << ("/v1" // routes)
+app.router << ("/docs" // OpenAPI().routes)
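+
+# A minimal sketch for serving the app: `app` is a WSGI application (kui.wsgi),
+# so any WSGI server should work. The module path below is an assumption and
+# may differ depending on where this file lives in the repo:
+#
+#   gunicorn -w 1 -b 127.0.0.1:8000 tools.api:app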