# api_server.py
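"""HTTP API server for fish-speech text-to-speech.

Exposes a small REST API: PUT /v1/models/{name} loads a llama
text2semantic model plus a VQGAN vocoder, GET /v1/models lists loaded
models, DELETE /v1/models/{name} unloads one, and
POST /v1/models/{name}/invoke synthesizes a WAV file from text.
Swagger UI is served under /docs.
"""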

import gc
import io
import time
import traceback
from http import HTTPStatus
from threading import Lock
from typing import Annotated, Literal, Optional

import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from hydra import compose, initialize
from hydra.utils import instantiate
from kui.wsgi import (
    Body,
    HTTPException,
    HttpView,
    JSONResponse,
    Kui,
    OpenAPI,
    Path,
    StreamResponse,
    allow_cors,
)
from kui.wsgi.routing import MultimethodRoutes, Router
from loguru import logger
from pydantic import BaseModel
from transformers import AutoTokenizer

import tools.llama.generate
from tools.llama.generate import encode_tokens, generate, load_model

# Define utils for web server
def http_exception_handler(exc: HTTPException):
    return JSONResponse(
        dict(
            statusCode=exc.status_code,
            message=exc.content,
            error=HTTPStatus(exc.status_code).phrase,
        ),
        exc.status_code,
        exc.headers,
    )


def other_exception_handler(exc: Exception):
    traceback.print_exc()

    status = HTTPStatus.INTERNAL_SERVER_ERROR
    return JSONResponse(
        dict(statusCode=status, message=str(exc), error=status.phrase),
        status,
    )

routes = MultimethodRoutes(base_class=HttpView)

# Define models
# Each entry maps a model name to {"llama": LlamaModel, "vqgan": VQGANModel, "lock": Lock}
MODELS = {}

class LlamaModel:
    def __init__(
        self,
        config_name: str,
        checkpoint_path: str,
        device,
        precision: str,
        tokenizer_path: str,
        compile: bool,
    ):
        self.device = device
        self.compile = compile

        self.t0 = time.time()
        self.precision = torch.bfloat16 if precision == "bfloat16" else torch.float16
        self.model = load_model(config_name, checkpoint_path, device, self.precision)
        self.model_size = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        torch.cuda.synchronize()
        logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        if self.compile:
            logger.info("Compiling model ...")
            # Note: this patches decode_one_token at module scope, so the
            # compiled function is shared process-wide, not per instance.
            tools.llama.generate.decode_one_token = torch.compile(
                tools.llama.generate.decode_one_token,
                mode="reduce-overhead",
                fullgraph=True,
            )

    def __del__(self):
        self.model = None
        self.tokenizer = None

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("The llama model is removed from memory.")

class VQGANModel:
    def __init__(self, config_name: str, checkpoint_path: str, device: str):
        with initialize(version_base="1.3", config_path="../fish_speech/configs"):
            self.cfg = compose(config_name=config_name)

        self.model = instantiate(self.cfg.model)
        state_dict = torch.load(
            checkpoint_path,
            map_location=self.model.device,
        )

        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        self.model.load_state_dict(state_dict, strict=True)
        self.model.eval()
        self.model.to(device)

        logger.info("Restored VQGAN model from checkpoint")

    def __del__(self):
        self.cfg = None
        self.model = None

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("The VQGAN model is removed from memory.")

    @torch.no_grad()
    @torch.autocast(device_type="cuda", enabled=True)
    def semantic_to_wav(self, indices):
        model = self.model

        indices = indices.to(model.device).long()
        indices = indices.unsqueeze(1).unsqueeze(-1)

        # Upsample the semantic token length to the mel length (plain int,
        # since torch.ones and F.interpolate expect integer sizes)
        mel_length = indices.shape[2] * (
            model.downsample.total_strides if model.downsample is not None else 1
        )
        mel_masks = torch.ones(
            (1, 1, mel_length), device=model.device, dtype=torch.float32
        )

        text_features = model.vq_encoder.decode(indices)

        logger.info(
            f"VQ decoded, indices: {indices.shape} equivalent to "
            f"{1 / (mel_length * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
        )

        text_features = F.interpolate(text_features, size=mel_length, mode="nearest")

        # Sample mels
        decoded_mels = model.decoder(text_features, mel_masks)
        fake_audios = model.generator(decoded_mels)

        logger.info(
            f"Generated audio of shape {fake_audios.shape}, equivalent to "
            f"{fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
        )

        # Return the first (and only) audio in the batch as float32 numpy
        fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
        return fake_audio, model.sampling_rate

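# A minimal usage sketch (assumes `codes` is the 2-D (codebook, time)
# LongTensor of semantic indices produced by the llama stage, as in
# api_invoke_model below):
#
#   vqgan = VQGANModel("vqgan_pretrain", "checkpoints/vqgan-v1.pth", "cuda")
#   audio, sr = vqgan.semantic_to_wav(codes)
#   sf.write("out.wav", audio, sr)
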
class LoadLlamaModelRequest(BaseModel):
    config_name: str = "text2semantic_finetune"
    checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
    precision: Literal["float16", "bfloat16"] = "bfloat16"
    tokenizer: str = "fishaudio/speech-lm-v1"
    compile: bool = True


class LoadVQGANModelRequest(BaseModel):
    config_name: str = "vqgan_pretrain"
    checkpoint_path: str = "checkpoints/vqgan-v1.pth"


class LoadModelRequest(BaseModel):
    device: str = "cuda"
    llama: LoadLlamaModelRequest
    vqgan: LoadVQGANModelRequest


class LoadModelResponse(BaseModel):
    name: str

@routes.http.put("/models/{name}")
def api_load_model(
    name: Annotated[str, Path("default")],
    req: Annotated[LoadModelRequest, Body(exclusive=True)],
) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
    """
    Load model
    """

    if name in MODELS:
        del MODELS[name]

    llama = req.llama
    vqgan = req.vqgan

    logger.info("Loading model ...")
    new_model = {
        "llama": LlamaModel(
            config_name=llama.config_name,
            checkpoint_path=llama.checkpoint_path,
            device=req.device,
            precision=llama.precision,
            tokenizer_path=llama.tokenizer,
            compile=llama.compile,
        ),
        "vqgan": VQGANModel(
            config_name=vqgan.config_name,
            checkpoint_path=vqgan.checkpoint_path,
            device=req.device,
        ),
        "lock": Lock(),
    }
    MODELS[name] = new_model

    return LoadModelResponse(name=name)

@routes.http.delete("/models/{name}")
def api_delete_model(
    name: Annotated[str, Path("default")],
) -> JSONResponse[200, {}, dict]:
    """
    Delete model
    """

    if name not in MODELS:
        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND,
            content="Model not found.",
        )

    del MODELS[name]

    return JSONResponse(
        dict(message="Model deleted."),
        200,
    )

@routes.http.get("/models")
def api_list_models() -> JSONResponse[200, {}, dict]:
    """
    List models
    """

    return JSONResponse(
        dict(models=list(MODELS.keys())),
        200,
    )

class InvokeRequest(BaseModel):
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    prompt_text: Optional[str] = None
    prompt_tokens: Optional[str] = None
    max_new_tokens: int = 0
    top_k: Optional[int] = None
    top_p: float = 0.5
    repetition_penalty: float = 1.5
    temperature: float = 0.7
    use_g2p: bool = True
    seed: Optional[int] = None
    speaker: Optional[str] = None

@routes.http.post("/models/{name}/invoke")
def api_invoke_model(
    name: Annotated[str, Path("default")],
    req: Annotated[InvokeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    if name not in MODELS:
        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND,
            content="Cannot find model.",
        )

    model = MODELS[name]
    llama_model_manager = model["llama"]
    vqgan_model_manager = model["vqgan"]

    # Serialize access to the llama model; the context manager also releases
    # the lock if generation raises
    with model["lock"]:
        device = llama_model_manager.device
        seed = req.seed
        prompt_tokens = req.prompt_tokens

        logger.info(f"Device: {device}")

        # prompt_tokens, if given, is a path to a .npy file of saved tokens
        prompt_tokens = (
            torch.from_numpy(np.load(prompt_tokens)).to(device)
            if prompt_tokens is not None
            else None
        )
        encoded = encode_tokens(
            llama_model_manager.tokenizer,
            req.text,
            prompt_text=req.prompt_text,
            prompt_tokens=prompt_tokens,
            bos=True,
            device=device,
            use_g2p=req.use_g2p,
            speaker=req.speaker,
        )
        prompt_length = encoded.size(1)
        logger.info(f"Encoded prompt shape: {encoded.shape}")

        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

        torch.cuda.synchronize()
        t0 = time.perf_counter()
        y = generate(
            model=llama_model_manager.model,
            prompt=encoded,
            max_new_tokens=req.max_new_tokens,
            eos_token_id=llama_model_manager.tokenizer.eos_token_id,
            precision=llama_model_manager.precision,
            temperature=req.temperature,
            top_k=req.top_k,
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
        )
        torch.cuda.synchronize()

        t = time.perf_counter() - t0
        tokens_generated = y.size(1) - prompt_length
        tokens_sec = tokens_generated / t
        logger.info(
            f"Generated {tokens_generated} tokens in {t:.02f} seconds, "
            f"{tokens_sec:.02f} tokens/sec"
        )
        logger.info(
            f"Bandwidth achieved: {llama_model_manager.model_size * tokens_sec / 1e9:.02f} GB/s"
        )
        logger.info(f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

        # Drop the first row plus the prompt and final (EOS) positions, then
        # undo the +2 token-id offset so codes are zero-based
        codes = y[1:, prompt_length:-1]
        codes = codes - 2
        assert (codes >= 0).all(), "Codes should be >= 0"
    # --------------- llama end ------------

    audio, sr = vqgan_model_manager.semantic_to_wav(codes)
    # --------------- vqgan end ------------

    buffer = io.BytesIO()
    sf.write(buffer, audio, sr, format="wav")

    return StreamResponse(
        iterable=[buffer.getvalue()],
        headers={
            "Content-Disposition": "attachment; filename=audio.wav",
            "Content-Type": "application/octet-stream",
        },
    )

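# Example client session (a sketch; assumes the server is reachable at
# http://localhost:8000 and the default checkpoint paths above exist):
#
#   import requests
#
#   # Load the models under the name "default" (empty dicts use the defaults)
#   requests.put(
#       "http://localhost:8000/v1/models/default",
#       json={"device": "cuda", "llama": {}, "vqgan": {}},
#   ).raise_for_status()
#
#   # Synthesize audio and save the returned WAV bytes
#   resp = requests.post(
#       "http://localhost:8000/v1/models/default/invoke",
#       json={"text": "你好, 世界."},
#   )
#   with open("audio.wav", "wb") as f:
#       f.write(resp.content)
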
# Define Kui app
app = Kui(
    exception_handlers={
        HTTPException: http_exception_handler,
        Exception: other_exception_handler,
    },
)
app.router = Router(
    [],
    http_middlewares=[
        app.exception_middleware,
        allow_cors(),
    ],
)

# Swagger UI & routes
app.router << ("/v1" // routes)
app.router << ("/docs" // OpenAPI().routes)

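# A minimal way to run the server locally (a sketch; assumes the standard
# library's single-threaded wsgiref server, which is fine for testing but
# a production deployment would sit behind a real WSGI server instead):
if __name__ == "__main__":
    from wsgiref.simple_server import make_server

    with make_server("127.0.0.1", 8000, app) as httpd:
        logger.info("Serving on http://127.0.0.1:8000 ...")
        httpd.serve_forever()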