api_server.py 13 KB

  1. import gc
  2. import io
  3. import time
  4. import traceback
  5. from http import HTTPStatus
  6. from threading import Lock
  7. from typing import Annotated, Literal, Optional
  8. import librosa
  9. import numpy as np
  10. import soundfile as sf
  11. import torch
  12. import torch.nn.functional as F
  13. from hydra import compose, initialize
  14. from hydra.utils import instantiate
  15. from kui.wsgi import (
  16. Body,
  17. HTTPException,
  18. HttpView,
  19. JSONResponse,
  20. Kui,
  21. OpenAPI,
  22. Path,
  23. StreamResponse,
  24. allow_cors,
  25. )
  26. from kui.wsgi.routing import MultimethodRoutes, Router
  27. from loguru import logger
  28. from pydantic import BaseModel
  29. from transformers import AutoTokenizer
  30. import tools.llama.generate
  31. from fish_speech.models.vqgan.utils import sequence_mask
  32. from tools.llama.generate import encode_tokens, generate, load_model
  33. # Define utils for web server
  34. def http_execption_handler(exc: HTTPException):
  35. return JSONResponse(
  36. dict(
  37. statusCode=exc.status_code,
  38. message=exc.content,
  39. error=HTTPStatus(exc.status_code).phrase,
  40. ),
  41. exc.status_code,
  42. exc.headers,
  43. )
  44. def other_exception_handler(exc: "Exception"):
  45. traceback.print_exc()
  46. status = HTTPStatus.INTERNAL_SERVER_ERROR
  47. return JSONResponse(
  48. dict(statusCode=status, message=str(exc), error=status.phrase),
  49. status,
  50. )
# Route table; HttpView lets one class serve multiple HTTP methods per path.
routes = MultimethodRoutes(base_class=HttpView)

# Define models
# Registry of loaded model bundles keyed by user-chosen name. Each value is a
# dict {"llama": LlamaModel, "vqgan": VQGANModel, "lock": threading.Lock}
# (see api_load_model).
MODELS = {}
  54. class LlamaModel:
  55. def __init__(
  56. self,
  57. config_name: str,
  58. checkpoint_path: str,
  59. device,
  60. precision: str,
  61. tokenizer_path: str,
  62. compile: bool,
  63. ):
  64. self.device = device
  65. self.compile = compile
  66. self.t0 = time.time()
  67. self.precision = torch.bfloat16 if precision == "bfloat16" else torch.float16
  68. self.model = load_model(config_name, checkpoint_path, device, self.precision)
  69. self.model_size = sum(
  70. p.numel() for p in self.model.parameters() if p.requires_grad
  71. )
  72. torch.cuda.synchronize()
  73. logger.info(f"Time to load model: {time.time() - self.t0:.02f} seconds")
  74. self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
  75. if self.compile:
  76. logger.info("Compiling model ...")
  77. tools.llama.generate.decode_one_token = torch.compile(
  78. tools.llama.generate.decode_one_token,
  79. mode="reduce-overhead",
  80. fullgraph=True,
  81. )
  82. def __del__(self):
  83. self.model = None
  84. self.tokenizer = None
  85. gc.collect()
  86. if torch.cuda.is_available():
  87. torch.cuda.empty_cache()
  88. logger.info("The llama is removed from memory.")
class VQGANModel:
    """VQ-GAN vocoder wrapper: converts between waveforms and semantic codes."""

    def __init__(self, config_name: str, checkpoint_path: str, device: str):
        # Hydra resolves config_path relative to this source file's directory.
        with initialize(version_base="1.3", config_path="../fish_speech/configs"):
            self.cfg = compose(config_name=config_name)
        self.model = instantiate(self.cfg.model)
        state_dict = torch.load(
            checkpoint_path,
            map_location=self.model.device,
        )
        # Lightning-style checkpoints nest the weights under "state_dict".
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        self.model.load_state_dict(state_dict, strict=True)
        self.model.eval()
        self.model.to(device)
        logger.info("Restored VQGAN model from checkpoint")

    def __del__(self):
        # Drop references and flush the CUDA cache so a replacement model can
        # be loaded without exhausting GPU memory.
        self.cfg = None
        self.model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("The vqgan model is removed from memory.")

    @torch.no_grad()
    # NOTE(review): "sematic" is a typo for "semantic", but callers
    # (api_invoke_model) use this exact name, so it is kept as-is.
    def sematic_to_wav(self, indices):
        """Decode semantic token indices into audio.

        Returns (audio, sampling_rate) where audio is a float32 numpy array.
        The input is reshaped with unsqueeze(1).unsqueeze(-1) before decoding,
        so `indices` is presumably a 2D (codebooks, time) LongTensor as
        produced by api_invoke_model — TODO confirm against vq_encoder.decode.
        """
        model = self.model
        indices = indices.to(model.device).long()
        indices = indices.unsqueeze(1).unsqueeze(-1)
        # One code frame expands to `total_strides` mel frames when a
        # downsampler is present.
        mel_lengths = indices.shape[2] * (
            model.downsample.total_strides if model.downsample is not None else 1
        )
        mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
        mel_masks = torch.ones(
            (1, 1, mel_lengths), device=model.device, dtype=torch.float32
        )
        text_features = model.vq_encoder.decode(indices)
        logger.info(
            f"VQ Encoded, indices: {indices.shape} equivalent to "
            + f"{1 / (mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
        )
        # Upsample code features to the mel frame rate before decoding.
        text_features = F.interpolate(
            text_features, size=mel_lengths[0], mode="nearest"
        )

        # Sample mels
        decoded_mels = model.decoder(text_features, mel_masks)
        fake_audios = model.generator(decoded_mels)
        logger.info(
            f"Generated audio of shape {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
        )

        # Save audio
        fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
        return fake_audio, model.sampling_rate

    @torch.no_grad()
    def wav_to_semantic(self, audio):
        """Encode an audio file into semantic VQ indices.

        `audio` is anything librosa.load accepts (path or file-like). Returns
        a 2D (batch, time) LongTensor of indices, or None when the VQ output
        has an unexpected shape (the error is logged).
        """
        model = self.model

        # Load audio
        audio, _ = librosa.load(
            audio,
            sr=model.sampling_rate,
            mono=True,
        )
        audios = torch.from_numpy(audio).to(model.device)[None, None, :]
        logger.info(
            f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=model.device, dtype=torch.long
        )
        features = gt_mels = model.mel_transform(
            audios, sample_rate=model.sampling_rate
        )
        if model.downsample is not None:
            features = model.downsample(features)
        # NOTE(review): mel_lengths is computed but never used below.
        mel_lengths = audio_lengths // model.hop_length
        feature_lengths = (
            audio_lengths
            / model.hop_length
            / (model.downsample.total_strides if model.downsample is not None else 1)
        ).long()
        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)

        # vq_features is 50 hz, need to convert to true mel size
        text_features = model.mel_encoder(features, feature_masks)
        _, indices, _ = model.vq_encoder(text_features, feature_masks)

        # Expect (batch, 1, time, 1); squeeze the singleton axes away.
        if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1:
            indices = indices[:, 0, :, 0]
        else:
            logger.error(f"Unknown indices shape: {indices.shape}")
            return

        logger.info(f"Generated indices of shape {indices.shape}")
        return indices
class LoadLlamaModelRequest(BaseModel):
    """Request body describing which llama checkpoint and tokenizer to load."""

    # Hydra config name for the text2semantic model.
    config_name: str = "text2semantic_finetune"
    # Path to the model weights checkpoint.
    checkpoint_path: str = "checkpoints/text2semantic-400m-v0.2-4k.pth"
    # Inference dtype (mapped to torch.bfloat16 / torch.float16).
    precision: Literal["float16", "bfloat16"] = "bfloat16"
    # HuggingFace tokenizer repo id or local path.
    tokenizer: str = "fishaudio/speech-lm-v1"
    # Whether to torch.compile the decode step (slower first call, faster after).
    compile: bool = True
class LoadVQGANModelRequest(BaseModel):
    """Request body describing which VQGAN vocoder checkpoint to load."""

    # Hydra config name for the VQGAN model.
    config_name: str = "vqgan_pretrain"
    # Path to the vocoder weights checkpoint.
    checkpoint_path: str = "checkpoints/vqgan-v1.pth"
class LoadModelRequest(BaseModel):
    """Combined request to load one llama + VQGAN pair onto a device."""

    # torch device string shared by both models (e.g. "cuda", "cpu").
    device: str = "cuda"
    llama: LoadLlamaModelRequest
    vqgan: LoadVQGANModelRequest
class LoadModelResponse(BaseModel):
    """Response body: the name the loaded model was registered under."""

    name: str
  196. @routes.http.put("/models/{name}")
  197. def api_load_model(
  198. name: Annotated[str, Path("default")],
  199. req: Annotated[LoadModelRequest, Body(exclusive=True)],
  200. ) -> Annotated[LoadModelResponse, JSONResponse[200, {}, LoadModelResponse]]:
  201. """
  202. Load model
  203. """
  204. if name in MODELS:
  205. del MODELS[name]
  206. llama = req.llama
  207. vqgan = req.vqgan
  208. logger.info("Loading model ...")
  209. new_model = {
  210. "llama": LlamaModel(
  211. config_name=llama.config_name,
  212. checkpoint_path=llama.checkpoint_path,
  213. device=req.device,
  214. precision=llama.precision,
  215. tokenizer_path=llama.tokenizer,
  216. compile=llama.compile,
  217. ),
  218. "vqgan": VQGANModel(
  219. config_name=vqgan.config_name,
  220. checkpoint_path=vqgan.checkpoint_path,
  221. device=req.device,
  222. ),
  223. "lock": Lock(),
  224. }
  225. MODELS[name] = new_model
  226. return LoadModelResponse(name=name)
  227. @routes.http.delete("/models/{name}")
  228. def api_delete_model(
  229. name: Annotated[str, Path("default")],
  230. ) -> JSONResponse[200, {}, dict]:
  231. """
  232. Delete model
  233. """
  234. if name not in MODELS:
  235. raise HTTPException(
  236. status_code=HTTPStatus.BAD_REQUEST,
  237. content="Model not found.",
  238. )
  239. del MODELS[name]
  240. return JSONResponse(
  241. dict(message="Model deleted."),
  242. 200,
  243. )
  244. @routes.http.get("/models")
  245. def api_list_models() -> JSONResponse[200, {}, dict]:
  246. """
  247. List models
  248. """
  249. return JSONResponse(
  250. dict(models=list(MODELS.keys())),
  251. 200,
  252. )
class InvokeRequest(BaseModel):
    """Request body for text-to-speech generation."""

    # Text to synthesize.
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    # Transcript of the reference prompt audio, if any.
    prompt_text: Optional[str] = None
    # Path to a .npy (pre-extracted codes) or .wav (encoded on the fly) prompt.
    prompt_tokens: Optional[str] = None
    # Generation budget; 0 presumably means "use the model's own limit" —
    # TODO confirm against tools.llama.generate.generate.
    max_new_tokens: int = 0
    # Sampling controls forwarded to generate().
    top_k: Optional[int] = None
    top_p: float = 0.5
    repetition_penalty: float = 1.5
    temperature: float = 0.7
    # Language priority for the G2P front-end.
    order: str = "zh,jp,en"
    use_g2p: bool = True
    # Seed for reproducible sampling; None keeps current RNG state.
    seed: Optional[int] = None
    speaker: Optional[str] = None
  266. @routes.http.post("/models/{name}/invoke")
  267. def api_invoke_model(
  268. name: Annotated[str, Path("default")],
  269. req: Annotated[InvokeRequest, Body(exclusive=True)],
  270. ):
  271. """
  272. Invoke model and generate audio
  273. """
  274. if name not in MODELS:
  275. raise HTTPException(
  276. status_code=HTTPStatus.NOT_FOUND,
  277. content="Cannot find model.",
  278. )
  279. model = MODELS[name]
  280. llama_model_manager = model["llama"]
  281. vqgan_model_manager = model["vqgan"]
  282. device = llama_model_manager.device
  283. seed = req.seed
  284. prompt_tokens = req.prompt_tokens
  285. logger.info(f"Device: {device}")
  286. if prompt_tokens is not None and prompt_tokens.endswith(".npy"):
  287. prompt_tokens = torch.from_numpy(np.load(prompt_tokens)).to(device)
  288. elif prompt_tokens is not None and prompt_tokens.endswith(".wav"):
  289. prompt_tokens = vqgan_model_manager.wav_to_semantic(prompt_tokens)
  290. elif prompt_tokens is not None:
  291. logger.error(f"Unknown prompt tokens: {prompt_tokens}")
  292. raise HTTPException(
  293. status_code=HTTPStatus.BAD_REQUEST,
  294. content="Unknown prompt tokens, it should be either .npy or .wav file.",
  295. )
  296. else:
  297. prompt_tokens = None
  298. # Lock
  299. model["lock"].acquire()
  300. encoded = encode_tokens(
  301. llama_model_manager.tokenizer,
  302. req.text,
  303. prompt_text=req.prompt_text,
  304. prompt_tokens=prompt_tokens,
  305. bos=True,
  306. device=device,
  307. use_g2p=req.use_g2p,
  308. speaker=req.speaker,
  309. order=req.order,
  310. )
  311. prompt_length = encoded.size(1)
  312. logger.info(f"Encoded prompt shape: {encoded.shape}")
  313. if seed is not None:
  314. torch.manual_seed(seed)
  315. torch.cuda.manual_seed(seed)
  316. torch.cuda.synchronize()
  317. t0 = time.perf_counter()
  318. y = generate(
  319. model=llama_model_manager.model,
  320. prompt=encoded,
  321. max_new_tokens=req.max_new_tokens,
  322. eos_token_id=llama_model_manager.tokenizer.eos_token_id,
  323. precision=llama_model_manager.precision,
  324. temperature=req.temperature,
  325. top_k=req.top_k,
  326. top_p=req.top_p,
  327. repetition_penalty=req.repetition_penalty,
  328. )
  329. torch.cuda.synchronize()
  330. t = time.perf_counter() - t0
  331. tokens_generated = y.size(1) - prompt_length
  332. tokens_sec = tokens_generated / t
  333. logger.info(
  334. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  335. )
  336. logger.info(
  337. f"Bandwidth achieved: {llama_model_manager.model_size * tokens_sec / 1e9:.02f} GB/s"
  338. )
  339. logger.info(f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
  340. codes = y[1:, prompt_length:-1]
  341. codes = codes - 2
  342. assert (codes >= 0).all(), "Codes should be >= 0"
  343. # Release lock
  344. model["lock"].release()
  345. # --------------- llama end ------------
  346. audio, sr = vqgan_model_manager.sematic_to_wav(codes)
  347. # --------------- vqgan end ------------
  348. buffer = io.BytesIO()
  349. sf.write(buffer, audio, sr, format="wav")
  350. return StreamResponse(
  351. iterable=[buffer.getvalue()],
  352. headers={
  353. "Content-Disposition": "attachment; filename=audio.wav",
  354. "Content-Type": "application/octet-stream",
  355. },
  356. )
# Define Kui app
# JSON-formatting handlers registered here shape every error response.
app = Kui(
    exception_handlers={
        HTTPException: http_execption_handler,
        Exception: other_exception_handler,
    },
)

# Rebuild the router so exception handling and CORS wrap every route.
app.router = Router(
    [],
    http_middlewares=[
        app.exception_middleware,
        allow_cors(),
    ],
)

# Swagger UI & routes
# All API endpoints live under /v1; interactive docs are served at /docs.
app.router << ("/v1" // routes)
app.router << ("/docs" // OpenAPI().routes)