# api_server.py — HTTP API for llama (text2semantic) + VQGAN text-to-speech serving.
  1. import gc
  2. import io
  3. import time
  4. import traceback
  5. from http import HTTPStatus
  6. from typing import Annotated, Any, Literal, Optional
  7. import numpy as np
  8. import soundfile as sf
  9. import torch
  10. import torch.nn.functional as F
  11. from hydra import compose, initialize
  12. from hydra.utils import instantiate
  13. from kui.wsgi import (
  14. Body,
  15. HTTPException,
  16. HttpView,
  17. JSONResponse,
  18. Kui,
  19. OpenAPI,
  20. Path,
  21. StreamResponse,
  22. allow_cors,
  23. )
  24. from kui.wsgi.routing import MultimethodRoutes, Router
  25. from loguru import logger
  26. from pydantic import BaseModel
  27. from transformers import AutoTokenizer
  28. import tools.llama.generate
  29. from tools.llama.generate import encode_tokens, generate, load_model
  30. # Define utils for web server
  31. def http_execption_handler(exc: HTTPException):
  32. return JSONResponse(
  33. dict(
  34. statusCode=exc.status_code,
  35. message=exc.content,
  36. error=HTTPStatus(exc.status_code).phrase,
  37. ),
  38. exc.status_code,
  39. exc.headers,
  40. )
  41. def other_exception_handler(exc: "Exception"):
  42. traceback.print_exc()
  43. status = HTTPStatus.INTERNAL_SERVER_ERROR
  44. return JSONResponse(
  45. dict(statusCode=status, message=str(exc), error=status.phrase),
  46. status,
  47. )
# Route table; handlers are registered via the decorators below.
routes = MultimethodRoutes(base_class=HttpView)
# Define models
# Registry of loaded model bundles keyed by user-supplied name; each value is
# a dict with "llama" (LlamaModel) and "vqgan" (VQGANModel) entries.
MODELS = {}
  51. class LlamaModel:
  52. def __init__(
  53. self,
  54. config_name: str,
  55. checkpoint_path: str,
  56. device,
  57. precision: str,
  58. tokenizer_path: str,
  59. compile: bool,
  60. ):
  61. self.device = device
  62. self.compile = compile
  63. self.t0 = time.time()
  64. self.precision = torch.bfloat16 if precision == "bfloat16" else torch.float16
  65. self.model = load_model(config_name, checkpoint_path, device, self.precision)
  66. self.model_size = sum(
  67. p.numel() for p in self.model.parameters() if p.requires_grad
  68. )
  69. torch.cuda.synchronize()
  70. logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")
  71. if self.tokenizer is None:
  72. self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
  73. if self.compile:
  74. logger.info("Compiling model ...")
  75. tools.llama.generate.decode_one_token = torch.compile(
  76. tools.llama.generate.decode_one_token,
  77. mode="reduce-overhead",
  78. fullgraph=True,
  79. )
  80. def __del__(self):
  81. self.model = None
  82. self.tokenizer = None
  83. gc.collect()
  84. if torch.cuda.is_available():
  85. torch.cuda.empty_cache()
  86. logger.info("The llama is removed from memory.")
  87. class VQGANModel:
  88. def __init__(self, config_name: str, checkpoint_path: str):
  89. if self.cfg is None:
  90. with initialize(version_base="1.3", config_path="../fish_speech/configs"):
  91. self.cfg = compose(config_name=config_name)
  92. self.model = instantiate(self.cfg.model)
  93. state_dict = torch.load(
  94. checkpoint_path,
  95. map_location=self.model.device,
  96. )
  97. if "state_dict" in state_dict:
  98. state_dict = state_dict["state_dict"]
  99. self.model.load_state_dict(state_dict, strict=True)
  100. self.model.eval()
  101. self.model.cuda()
  102. logger.info("Restored model from checkpoint")
  103. def __del__(self):
  104. self.cfg = None
  105. self.model = None
  106. gc.collect()
  107. if torch.cuda.is_available():
  108. torch.cuda.empty_cache()
  109. logger.info("The vqgan model is removed from memory.")
  110. @torch.no_grad()
  111. @torch.autocast(device_type="cuda", enabled=True)
  112. def sematic_to_wav(self, indices):
  113. model = self.model
  114. indices = indices.to(model.device).long()
  115. indices = indices.unsqueeze(1).unsqueeze(-1)
  116. mel_lengths = indices.shape[2] * (
  117. model.downsample.total_strides if model.downsample is not None else 1
  118. )
  119. mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
  120. mel_masks = torch.ones(
  121. (1, 1, mel_lengths), device=model.device, dtype=torch.float32
  122. )
  123. text_features = model.vq_encoder.decode(indices)
  124. logger.info(
  125. f"VQ Encoded, indices: {indices.shape} equivalent to "
  126. + f"{1 / (mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
  127. )
  128. text_features = F.interpolate(
  129. text_features, size=mel_lengths[0], mode="nearest"
  130. )
  131. # Sample mels
  132. decoded_mels = model.decoder(text_features, mel_masks)
  133. fake_audios = model.generator(decoded_mels)
  134. logger.info(
  135. f"Generated audio of shape {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
  136. )
  137. # Save audio
  138. fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
  139. return fake_audio, model.sampling_rate
class LoadLlamaModelRequest(BaseModel):
    """Request body describing which llama checkpoint/tokenizer to load."""

    # Name of the llama model config to load.
    config_name: str = "text2semantic_finetune"
    checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
    # Target device string, e.g. "cuda" or "cpu".
    device: str = "cuda"
    precision: Literal["float16", "bfloat16"] = "bfloat16"
    # HuggingFace tokenizer repo id or local path.
    tokenizer: str = "fishaudio/speech-lm-v1"
    # Whether to torch.compile the decode step (slower first call, faster after).
    compile: bool = True
class LoadVQGANModelRequest(BaseModel):
    """Request body describing which VQGAN vocoder checkpoint to load."""

    # Hydra config name under fish_speech/configs.
    config_name: str = "vqgan_pretrain"
    checkpoint_path: str = "checkpoints/vqgan-v1.pth"
class LoadModelResponse(BaseModel):
    """Response for PUT /models/{name}: echoes the registered model name."""

    name: str
  152. @routes.http.put("/models/{name}")
  153. def load_model(
  154. name: Annotated[str, Path("default")],
  155. llama: Annotated[LoadLlamaModelRequest, Body()],
  156. vqgan: Annotated[LoadVQGANModelRequest, Body()],
  157. ) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
  158. """
  159. Load model
  160. """
  161. if name in MODELS:
  162. del MODELS[name]
  163. logger.info("Loading model ...")
  164. new_model = {
  165. "llama": LlamaModel(
  166. config_name=llama.config_name,
  167. checkpoint_path=llama.checkpoint_path,
  168. device=llama.device,
  169. precision=llama.precision,
  170. tokenizer_path=llama.tokenizer,
  171. compile=llama.compile,
  172. ),
  173. "vqgan": VQGANModel(
  174. config_name=vqgan.config_name,
  175. checkpoint_path=vqgan.checkpoint_path,
  176. ),
  177. }
  178. MODELS[name] = new_model
  179. return LoadModelResponse(name=name)
  180. @routes.http.delete("/models/{name}")
  181. def delete_model(
  182. name: Annotated[str, Path("default")],
  183. ) -> JSONResponse[200, {}, dict]:
  184. """
  185. Delete model
  186. """
  187. if name not in MODELS:
  188. raise HTTPException(
  189. status_code=HTTPStatus.BAD_REQUEST,
  190. content="Model not found.",
  191. )
  192. return JSONResponse(
  193. dict(message="Model deleted."),
  194. 200,
  195. )
  196. @routes.http.get("/models")
  197. def list_models() -> JSONResponse[200, {}, dict]:
  198. """
  199. List models
  200. """
  201. return JSONResponse(
  202. dict(models=list(MODELS.keys())),
  203. 200,
  204. )
class InvokeRequest(BaseModel):
    """Request body for POST /models/{name}/invoke (text-to-speech)."""

    # Text to synthesize.
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    # Optional transcript of the audio prompt.
    prompt_text: Optional[str] = None
    # Path to a .npy file of prompt codes (loaded with np.load by the handler).
    prompt_tokens: Optional[str] = None
    # 0 presumably means "model default / unlimited" — confirm in generate().
    max_new_tokens: int = 0
    top_k: Optional[int] = None
    top_p: float = 0.5
    repetition_penalty: float = 1.5
    temperature: float = 0.7
    # Whether to run grapheme-to-phoneme conversion during encoding.
    use_g2p: bool = True
    # Optional RNG seed for reproducible sampling.
    seed: Optional[int] = None
    speaker: Optional[str] = None
@routes.http.post("/models/{name}/invoke")
def invoke_model(
    name: Annotated[str, Path("default")],
    req: Annotated[InvokeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio

    Runs the llama text2semantic model to produce semantic codes, decodes
    them into a waveform with the VQGAN vocoder, and streams the result back
    as a WAV attachment.
    """
    if name not in MODELS:
        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND,
            content="Cannot find model.",
        )

    model = MODELS[name]
    llama_model_manager = model["llama"]
    vqgan_model_manager = model["vqgan"]
    device = llama_model_manager.device
    seed = req.seed
    prompt_tokens = req.prompt_tokens

    logger.info(f"Device: {device}")

    # `prompt_tokens` is a filesystem path to a .npy file of prompt codes.
    prompt_tokens = (
        torch.from_numpy(np.load(prompt_tokens)).to(device)
        if prompt_tokens is not None
        else None
    )

    encoded = encode_tokens(
        llama_model_manager.tokenizer,
        req.text,
        prompt_text=req.prompt_text,
        prompt_tokens=prompt_tokens,
        bos=True,
        device=device,
        use_g2p=req.use_g2p,
        speaker=req.speaker,
    )
    # Prompt length in positions; used below to slice off the prompt from y.
    prompt_length = encoded.size(1)
    logger.info(f"Encoded prompt shape: {encoded.shape}")

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    # Synchronize before/after generation so the timer measures GPU work only.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    y = generate(
        model=llama_model_manager.model,
        prompt=encoded,
        max_new_tokens=req.max_new_tokens,
        eos_token_id=llama_model_manager.tokenizer.eos_token_id,
        precision=llama_model_manager.precision,
        temperature=req.temperature,
        top_k=req.top_k,
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
    )
    torch.cuda.synchronize()

    t = time.perf_counter() - t0
    tokens_generated = y.size(1) - prompt_length
    tokens_sec = tokens_generated / t
    logger.info(
        f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
    )
    logger.info(
        f"Bandwidth achieved: {llama_model_manager.model_size * tokens_sec / 1e9:.02f} GB/s"
    )
    logger.info(f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

    # Drop row 0 (presumably the text-token stream — TODO confirm against
    # tools.llama.generate), the prompt positions, and the final (EOS)
    # position, then shift codebook ids down past the 2 reserved tokens.
    codes = y[1:, prompt_length:-1]
    codes = codes - 2
    assert (codes >= 0).all(), "Codes should be >= 0"
    # --------------- llama end ------------

    audio, sr = vqgan_model_manager.sematic_to_wav(codes)
    # --------------- vqgan end ------------

    # Encode the waveform as WAV in memory and stream it back.
    buffer = io.BytesIO()
    sf.write(buffer, audio, sr, format="wav")

    return StreamResponse(
        iterable=[buffer.getvalue()],
        headers={
            "Content-Disposition": "attachment; filename=generated.wav",
            "Content-Type": "audio/wav",
        },
    )
# Define Kui app
# Exceptions raised anywhere in a handler are converted to JSON error bodies
# by the two handlers registered here.
app = Kui(
    exception_handlers={
        HTTPException: http_execption_handler,
        Exception: other_exception_handler,
    },
)

# Replace the default router so CORS headers are applied to every route.
app.router = Router(
    [],
    http_middlewares=[
        app.exception_middleware,
        allow_cors(),
    ],
)

# Swagger UI & routes
# Mount the API under /v1 and the generated OpenAPI docs under /docs.
app.router << ("/v1" // routes)
app.router << ("/docs" // OpenAPI().routes)