# api_server.py

import gc
import io
import time
import traceback
from http import HTTPStatus
from threading import Lock
from typing import Annotated, Literal, Optional

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from hydra import compose, initialize
from hydra.utils import instantiate
from kui.wsgi import (
    Body,
    HTTPException,
    HttpView,
    JSONResponse,
    Kui,
    OpenAPI,
    Path,
    StreamResponse,
    allow_cors,
)
from kui.wsgi.routing import MultimethodRoutes, Router
from loguru import logger
from pydantic import BaseModel
from transformers import AutoTokenizer

import tools.llama.generate
from fish_speech.models.vqgan.utils import sequence_mask
from tools.llama.generate import encode_tokens, generate, load_model


# Define utils for web server
def http_exception_handler(exc: HTTPException):
    return JSONResponse(
        dict(
            statusCode=exc.status_code,
            message=exc.content,
            error=HTTPStatus(exc.status_code).phrase,
        ),
        exc.status_code,
        exc.headers,
    )


def other_exception_handler(exc: Exception):
    traceback.print_exc()

    status = HTTPStatus.INTERNAL_SERVER_ERROR
    return JSONResponse(
        dict(statusCode=status, message=str(exc), error=status.phrase),
        status,
    )


routes = MultimethodRoutes(base_class=HttpView)

# Define models
MODELS = {}


class LlamaModel:
    def __init__(
        self,
        config_name: str,
        checkpoint_path: str,
        device,
        precision: str,
        tokenizer_path: str,
        compile: bool,
    ):
        self.device = device
        self.compile = compile

        self.t0 = time.time()
        self.precision = torch.bfloat16 if precision == "bfloat16" else torch.float16
        self.model = load_model(config_name, checkpoint_path, device, self.precision)
        self.model_size = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )

        torch.cuda.synchronize()
        logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        if self.compile:
            logger.info("Compiling model ...")
            tools.llama.generate.decode_one_token = torch.compile(
                tools.llama.generate.decode_one_token,
                mode="reduce-overhead",
                fullgraph=True,
            )

    def __del__(self):
        self.model = None
        self.tokenizer = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("The llama model is removed from memory.")


class VQGANModel:
    def __init__(self, config_name: str, checkpoint_path: str, device: str):
        with initialize(version_base="1.3", config_path="../fish_speech/configs"):
            self.cfg = compose(config_name=config_name)

        self.model = instantiate(self.cfg.model)
        state_dict = torch.load(
            checkpoint_path,
            map_location=self.model.device,
        )

        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        self.model.load_state_dict(state_dict, strict=True)
        self.model.eval()
        self.model.to(device)

        logger.info("Restored VQGAN model from checkpoint")

    def __del__(self):
        self.cfg = None
        self.model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("The VQGAN model is removed from memory.")

    @torch.no_grad()
    def semantic_to_wav(self, indices):
        model = self.model
        indices = indices.to(model.device).long()
        feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
        decoded = model.decode(indices=indices[None], feature_lengths=feature_lengths)

        # Convert the decoded audio to a mono float32 numpy array
        fake_audio = decoded.audios[0, 0].cpu().numpy().astype(np.float32)
        return fake_audio, model.sampling_rate

    @torch.no_grad()
    def wav_to_semantic(self, audio):
        model = self.model

        # Load audio
        audio, _ = librosa.load(
            audio,
            sr=model.sampling_rate,
            mono=True,
        )
        audios = torch.from_numpy(audio).to(model.device)[None, None, :]
        logger.info(
            f"Loaded {audios.shape[2] / model.sampling_rate:.2f} seconds of audio"
        )

        # VQ encoder
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=model.device, dtype=torch.long
        )
        encoded = model.encode(audios, audio_lengths)
        indices = encoded.indices[0]
        logger.info(f"Generated indices of shape {indices.shape}")

        return indices
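
    # Note: wav_to_semantic and semantic_to_wav are approximate inverses. The
    # indices produced here are what the /invoke endpoint below uses as
    # prompt_tokens when a .wav reference is supplied.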


class LoadLlamaModelRequest(BaseModel):
    config_name: str = "text2semantic_finetune"
    checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
    precision: Literal["float16", "bfloat16"] = "bfloat16"
    tokenizer: str = "fishaudio/speech-lm-v1"
    compile: bool = True


class LoadVQGANModelRequest(BaseModel):
    config_name: str = "vqgan_pretrain"
    checkpoint_path: str = "checkpoints/vqgan-v1.pth"


class LoadModelRequest(BaseModel):
    device: str = "cuda"
    llama: LoadLlamaModelRequest
    vqgan: LoadVQGANModelRequest


class LoadModelResponse(BaseModel):
    name: str


@routes.http.put("/models/{name}")
def api_load_model(
    name: Annotated[str, Path("default")],
    req: Annotated[LoadModelRequest, Body(exclusive=True)],
) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
    """
    Load (or replace) a named llama + VQGAN model pair
    """
    if name in MODELS:
        del MODELS[name]

    llama = req.llama
    vqgan = req.vqgan
    logger.info("Loading model ...")

    new_model = {
        "llama": LlamaModel(
            config_name=llama.config_name,
            checkpoint_path=llama.checkpoint_path,
            device=req.device,
            precision=llama.precision,
            tokenizer_path=llama.tokenizer,
            compile=llama.compile,
        ),
        "vqgan": VQGANModel(
            config_name=vqgan.config_name,
            checkpoint_path=vqgan.checkpoint_path,
            device=req.device,
        ),
        "lock": Lock(),
    }

    MODELS[name] = new_model
    return LoadModelResponse(name=name)
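
# A hedged usage sketch (assumptions: the server listens on
# http://127.0.0.1:8000 and, per the routing at the bottom of this file, the
# endpoints are mounted under /v1). Loading the default model pair could look
# like:
#
#   curl -X PUT http://127.0.0.1:8000/v1/models/default \
#     -H "Content-Type: application/json" \
#     -d '{"device": "cuda", "llama": {}, "vqgan": {}}'
#
# Empty "llama"/"vqgan" objects fall back to the defaults declared on
# LoadLlamaModelRequest and LoadVQGANModelRequest above.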


@routes.http.delete("/models/{name}")
def api_delete_model(
    name: Annotated[str, Path("default")],
) -> JSONResponse[200, {}, dict]:
    """
    Delete model
    """
    if name not in MODELS:
        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND,
            content="Model not found.",
        )

    del MODELS[name]
    return JSONResponse(
        dict(message="Model deleted."),
        200,
    )


@routes.http.get("/models")
def api_list_models() -> JSONResponse[200, {}, dict]:
    """
    List models
    """
    return JSONResponse(
        dict(models=list(MODELS.keys())),
        200,
    )


class InvokeRequest(BaseModel):
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    prompt_text: Optional[str] = None
    prompt_tokens: Optional[str] = None
    max_new_tokens: int = 0
    top_k: Optional[int] = None
    top_p: float = 0.5
    repetition_penalty: float = 1.5
    temperature: float = 0.7
    order: str = "zh,jp,en"
    use_g2p: bool = True
    seed: Optional[int] = None
    speaker: Optional[str] = None


@routes.http.post("/models/{name}/invoke")
def api_invoke_model(
    name: Annotated[str, Path("default")],
    req: Annotated[InvokeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """
    if name not in MODELS:
        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND,
            content="Cannot find model.",
        )

    model = MODELS[name]
    llama_model_manager = model["llama"]
    vqgan_model_manager = model["vqgan"]
    device = llama_model_manager.device
    seed = req.seed
    prompt_tokens = req.prompt_tokens

    logger.info(f"Device: {device}")

    # Resolve the prompt: either precomputed semantic tokens (.npy) or a
    # reference wav that is encoded on the fly.
    if prompt_tokens is not None and prompt_tokens.endswith(".npy"):
        prompt_tokens = torch.from_numpy(np.load(prompt_tokens)).to(device)
    elif prompt_tokens is not None and prompt_tokens.endswith(".wav"):
        prompt_tokens = vqgan_model_manager.wav_to_semantic(prompt_tokens)
    elif prompt_tokens is not None:
        logger.error(f"Unknown prompt tokens: {prompt_tokens}")
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST,
            content="Unknown prompt tokens: expected a .npy or a .wav file.",
        )

    # Hold the lock while the llama model generates, so concurrent requests
    # don't interleave; the context manager releases it even if generation fails.
    with model["lock"]:
        encoded = encode_tokens(
            llama_model_manager.tokenizer,
            req.text,
            prompt_text=req.prompt_text,
            prompt_tokens=prompt_tokens,
            bos=True,
            device=device,
            use_g2p=req.use_g2p,
            speaker=req.speaker,
            order=req.order,
        )
        prompt_length = encoded.size(1)
        logger.info(f"Encoded prompt shape: {encoded.shape}")

        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

        torch.cuda.synchronize()
        t0 = time.perf_counter()
        y = generate(
            model=llama_model_manager.model,
            prompt=encoded,
            max_new_tokens=req.max_new_tokens,
            eos_token_id=llama_model_manager.tokenizer.eos_token_id,
            precision=llama_model_manager.precision,
            temperature=req.temperature,
            top_k=req.top_k,
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
        )
        torch.cuda.synchronize()

        t = time.perf_counter() - t0
        tokens_generated = y.size(1) - prompt_length
        tokens_sec = tokens_generated / t
        logger.info(
            f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
        )
        logger.info(
            f"Bandwidth achieved: {llama_model_manager.model_size * tokens_sec / 1e9:.02f} GB/s"
        )
        logger.info(f"GPU memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

        # Keep only the codebook rows, strip the prompt and the trailing EOS
        # position, then undo the special-token offset (hence the assert).
        codes = y[1:, prompt_length:-1]
        codes = codes - 2
        assert (codes >= 0).all(), "Codes should be >= 0"
    # --------------- llama end ------------

    audio, sr = vqgan_model_manager.semantic_to_wav(codes)
    # --------------- vqgan end ------------

    buffer = io.BytesIO()
    sf.write(buffer, audio, sr, format="wav")

    return StreamResponse(
        iterable=[buffer.getvalue()],
        headers={
            "Content-Disposition": "attachment; filename=audio.wav",
            "Content-Type": "application/octet-stream",
        },
    )
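
# A hedged example (same host/port assumption as above): invoke the model and
# save the streamed WAV response to disk:
#
#   curl -X POST http://127.0.0.1:8000/v1/models/default/invoke \
#     -H "Content-Type: application/json" \
#     -d '{"text": "你好, 世界."}' \
#     -o audio.wav
#
# Any other InvokeRequest field (temperature, top_p, prompt_tokens, ...) can be
# included in the JSON body; omitted fields use the defaults above.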


# Define Kui app
app = Kui(
    exception_handlers={
        HTTPException: http_exception_handler,
        Exception: other_exception_handler,
    },
)

app.router = Router(
    [],
    http_middlewares=[
        app.exception_middleware,
        allow_cors(),
    ],
)

# Swagger UI & routes
app.router << ("/v1" // routes)
app.router << ("/docs" // OpenAPI().routes)
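
# This module only builds the WSGI app; it does not start a server itself. A
# minimal sketch, assuming gunicorn is installed and this file is importable
# as tools.api_server:
#
#   gunicorn tools.api_server:app --bind 127.0.0.1:8000
#
# The API should then be reachable under /v1 and the OpenAPI docs under /docs.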