inference.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. import traceback
  6. from contextlib import nullcontext
  7. from dataclasses import dataclass
  8. from pathlib import Path
  9. from typing import Literal, Optional, Tuple, Union
  10. import click
  11. import numpy as np
  12. import torch
  13. import torch._inductor.config
  14. from loguru import logger
  15. from tqdm import tqdm
  16. from transformers import AutoTokenizer
  17. from fish_speech.content_sequence import (
  18. ContentSequence,
  19. TextPart,
  20. VQPart,
  21. )
  22. from fish_speech.tokenizer import IM_END_TOKEN
  23. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  24. torch._inductor.config.coordinate_descent_tuning = True
  25. torch._inductor.config.triton.unique_kernel_names = True
  26. if hasattr(torch._inductor.config, "fx_graph_cache"):
  27. # Experimental feature to reduce compilation times, will be on by default in future
  28. torch._inductor.config.fx_graph_cache = True
  29. from torch.nn.attention import SDPBackend, sdpa_kernel
  30. from fish_speech.models.text2semantic.llama import (
  31. BaseTransformer,
  32. DualARTransformer,
  33. NaiveTransformer,
  34. )
  35. def multinomial_sample_one_no_sync(
  36. probs_sort,
  37. ): # Does multinomial sampling without a cuda synchronization
  38. q = torch.empty_like(probs_sort).exponential_(1)
  39. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  40. def logits_to_probs(
  41. logits,
  42. temperature: torch.Tensor,
  43. top_p: torch.Tensor,
  44. repetition_penalty: torch.Tensor,
  45. previous_tokens: Optional[torch.Tensor] = None,
  46. ) -> torch.Tensor:
  47. # Apply repetition penalty
  48. if previous_tokens is not None:
  49. previous_tokens = previous_tokens.long()
  50. score = torch.gather(logits, dim=-1, index=previous_tokens)
  51. score = torch.where(
  52. score < 0, score * repetition_penalty, score / repetition_penalty
  53. )
  54. logits.scatter_(dim=-1, index=previous_tokens, src=score)
  55. # Apply top-p sampling
  56. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  57. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  58. sorted_indices_to_remove = cum_probs > top_p
  59. sorted_indices_to_remove[0] = False # keep at least one option
  60. indices_to_remove = sorted_indices_to_remove.scatter(
  61. dim=-1, index=sorted_indices, src=sorted_indices_to_remove
  62. )
  63. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  64. logits = logits / torch.clip(temperature, min=1e-5)
  65. probs = torch.nn.functional.softmax(logits, dim=-1)
  66. return probs
  67. def sample(
  68. logits,
  69. temperature: torch.Tensor,
  70. top_p: torch.Tensor,
  71. repetition_penalty: torch.Tensor,
  72. previous_tokens: Optional[torch.Tensor] = None,
  73. ) -> Tuple[torch.Tensor, torch.Tensor]:
  74. probs = logits_to_probs(
  75. logits=logits[0, -1],
  76. temperature=temperature,
  77. top_p=top_p,
  78. repetition_penalty=repetition_penalty,
  79. previous_tokens=previous_tokens,
  80. )
  81. idx_next = multinomial_sample_one_no_sync(probs)
  82. return idx_next, probs
  83. def decode_one_token_ar(
  84. model: DualARTransformer,
  85. x: torch.Tensor,
  86. input_pos: torch.Tensor,
  87. temperature: torch.Tensor,
  88. top_p: torch.Tensor,
  89. repetition_penalty: torch.Tensor,
  90. audio_masks: torch.Tensor,
  91. audio_parts: torch.Tensor,
  92. previous_tokens: torch.Tensor = None,
  93. ) -> torch.Tensor:
  94. # print(x, torch.count_nonzero(vq_masks))
  95. x = model.forward_generate(
  96. x,
  97. input_pos,
  98. audio_masks=audio_masks,
  99. audio_parts=audio_parts,
  100. )
  101. logits = x.logits # [:, -1:]
  102. hidden_states = x.hidden_states # [:, -1:]
  103. codebooks = [
  104. sample(
  105. logits,
  106. temperature=temperature,
  107. top_p=top_p,
  108. repetition_penalty=repetition_penalty,
  109. previous_tokens=(
  110. previous_tokens[:, 0] if previous_tokens is not None else None
  111. ),
  112. )[0]
  113. ]
  114. # Cleanup the cache
  115. for layer in model.fast_layers:
  116. layer.attention.kv_cache.k_cache.fill_(0)
  117. layer.attention.kv_cache.v_cache.fill_(0)
  118. input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
  119. model.forward_generate_fast(hidden_states, input_pos)
  120. a = codebooks[0] - model.tokenizer.semantic_begin_id
  121. a[a < 0] = 0
  122. hidden_states = model.fast_embeddings(a)
  123. codebooks.append(a)
  124. for codebook_idx in range(1, model.config.num_codebooks):
  125. input_pos = torch.tensor(
  126. [codebook_idx], device=hidden_states.device, dtype=torch.long
  127. )
  128. logits = model.forward_generate_fast(hidden_states, input_pos)
  129. short_logits = logits[:, :, :1024]
  130. # Convert logits to probs
  131. a = sample(
  132. short_logits,
  133. temperature=temperature,
  134. top_p=top_p,
  135. repetition_penalty=repetition_penalty,
  136. previous_tokens=(
  137. previous_tokens[codebook_idx + 1]
  138. if previous_tokens is not None
  139. else None
  140. ),
  141. )[0]
  142. hidden_states = model.fast_embeddings(a)
  143. codebooks.append(a)
  144. codebooks = torch.stack(codebooks, dim=1)
  145. return codebooks.T
  146. def decode_n_tokens(
  147. model: NaiveTransformer,
  148. cur_token: torch.Tensor,
  149. input_pos: torch.Tensor,
  150. num_new_tokens: int,
  151. temperature: torch.Tensor,
  152. top_p: torch.Tensor,
  153. repetition_penalty: torch.Tensor,
  154. audio_masks: torch.Tensor,
  155. audio_parts: torch.Tensor,
  156. decode_one_token=decode_one_token_ar,
  157. ):
  158. previous_tokens = torch.zeros(
  159. (model.config.num_codebooks + 1, model.config.max_seq_len),
  160. dtype=torch.int,
  161. device=cur_token.device,
  162. )
  163. for i in tqdm(range(num_new_tokens)):
  164. # We need to get windowed repeat penalty
  165. win_size = 16
  166. if i < win_size:
  167. window = previous_tokens[:, :win_size]
  168. else:
  169. window = previous_tokens[:, i - win_size : i]
  170. with sdpa_kernel(
  171. SDPBackend.MATH
  172. ): # Actually better for Inductor to codegen attention here
  173. next_token = decode_one_token(
  174. model=model,
  175. x=cur_token,
  176. input_pos=input_pos,
  177. previous_tokens=window,
  178. temperature=temperature,
  179. top_p=top_p,
  180. repetition_penalty=repetition_penalty,
  181. audio_masks=audio_masks,
  182. audio_parts=audio_parts,
  183. ).clone()
  184. input_pos += 1
  185. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  186. previous_tokens[:, i : i + 1] = next_token.view(
  187. model.config.num_codebooks + 1, -1
  188. )
  189. if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
  190. break
  191. return previous_tokens[:, : i + 1]
  192. @torch.no_grad()
  193. @torch.inference_mode()
  194. def generate(
  195. *,
  196. model: BaseTransformer,
  197. prompt: torch.Tensor,
  198. max_new_tokens: int,
  199. audio_masks: torch.Tensor,
  200. audio_parts: torch.Tensor,
  201. decode_one_token=decode_one_token_ar,
  202. num_samples: int = 1,
  203. **sampling_kwargs,
  204. ):
  205. """
  206. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  207. """
  208. # create an empty tensor of the expected final shape and fill in the current tokens
  209. T = prompt.size(1)
  210. prompt = prompt[None].repeat(num_samples, 1, 1)
  211. if T >= model.config.max_seq_len:
  212. raise ValueError(
  213. f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
  214. )
  215. if max_new_tokens:
  216. if T + max_new_tokens > model.config.max_seq_len:
  217. max_new_tokens = model.config.max_seq_len - T
  218. T_new = T + max_new_tokens
  219. else:
  220. T_new = model.config.max_seq_len
  221. max_new_tokens = T_new - T
  222. device, dtype = prompt.device, prompt.dtype
  223. with torch.device(device):
  224. model.setup_caches(
  225. max_batch_size=num_samples,
  226. max_seq_len=model.config.max_seq_len,
  227. dtype=next(model.parameters()).dtype,
  228. )
  229. codebook_dim = 1 + model.config.num_codebooks
  230. input_pos = torch.arange(0, T, device=device)
  231. empty = torch.empty(
  232. (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
  233. )
  234. empty[:, :T] = prompt
  235. seq = empty
  236. temperature = torch.tensor(
  237. sampling_kwargs["temperature"], device=device, dtype=torch.bfloat16
  238. )
  239. top_p = torch.tensor(sampling_kwargs["top_p"], device=device, dtype=torch.bfloat16)
  240. repetition_penalty = torch.tensor(
  241. sampling_kwargs["repetition_penalty"], device=device, dtype=torch.bfloat16
  242. )
  243. prefill_decode = decode_one_token_ar
  244. first_token = prefill_decode(
  245. model,
  246. prompt.view(1, codebook_dim, -1),
  247. input_pos,
  248. temperature,
  249. top_p,
  250. repetition_penalty,
  251. audio_masks,
  252. audio_parts,
  253. )
  254. seq[:, T : T + 1] = first_token
  255. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  256. x = decode_n_tokens(
  257. model,
  258. first_token.view(1, codebook_dim, -1),
  259. input_pos,
  260. max_new_tokens - 1,
  261. temperature=temperature,
  262. top_p=top_p,
  263. repetition_penalty=repetition_penalty,
  264. audio_masks=audio_masks,
  265. audio_parts=audio_parts,
  266. decode_one_token=decode_one_token,
  267. )
  268. seq = seq[:, : T + 1 + x.size(1)]
  269. seq[:, T + 1 :] = x
  270. return seq
  271. def init_model(checkpoint_path, device, precision, compile=False):
  272. model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
  273. model = model.to(device=device, dtype=precision)
  274. logger.info(f"Restored model from checkpoint")
  275. if isinstance(model, DualARTransformer):
  276. decode_one_token = decode_one_token_ar
  277. prefill_n_tokens = decode_one_token_ar
  278. logger.info("Using DualARTransformer")
  279. else:
  280. raise ValueError("Unsupported model type")
  281. # Initialize cache
  282. with torch.device(device):
  283. model.setup_caches(
  284. max_batch_size=1,
  285. max_seq_len=model.config.max_seq_len,
  286. dtype=next(model.parameters()).dtype,
  287. )
  288. if compile:
  289. logger.info("Compiling function...")
  290. decode_one_token = torch.compile(
  291. decode_one_token,
  292. # mode="max-autotune-no-cudagraphs",
  293. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  294. mode="reduce-overhead" if torch.cuda.is_available() else None,
  295. fullgraph=True,
  296. )
  297. return model.eval(), decode_one_token
  298. @dataclass
  299. class GenerateResponse:
  300. action: Literal["sample", "next"]
  301. codes: Optional[torch.Tensor] = None
  302. text: Optional[str] = None
  303. def generate_long(
  304. *,
  305. model,
  306. device: str | torch.device,
  307. decode_one_token: callable,
  308. text: str,
  309. num_samples: int = 1,
  310. max_new_tokens: int = 0,
  311. top_p: int = 0.8,
  312. repetition_penalty: float = 1.1,
  313. temperature: float = 0.8,
  314. compile: bool = False,
  315. iterative_prompt: bool = True,
  316. chunk_length: int = 512,
  317. prompt_text: Optional[str | list[str]] = None,
  318. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  319. ):
  320. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  321. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  322. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  323. use_prompt = prompt_text is not None and prompt_tokens is not None
  324. if use_prompt and isinstance(prompt_text, str):
  325. prompt_text = [prompt_text]
  326. prompt_tokens = [prompt_tokens]
  327. assert use_prompt is False or len(prompt_text) == len(
  328. prompt_tokens
  329. ), "Prompt text and tokens must have the same length"
  330. prompt_tokens = [i.cpu() for i in prompt_tokens]
  331. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  332. tokenizer = model.tokenizer
  333. base_content_sequence = ContentSequence(modality="interleave")
  334. max_length = model.config.max_seq_len
  335. if use_prompt:
  336. for t, c in zip(prompt_text, prompt_tokens):
  337. base_content_sequence.append(
  338. [
  339. TextPart(text=t),
  340. VQPart(codes=c),
  341. ],
  342. add_end=True,
  343. speaker=0,
  344. )
  345. base_content_sequence.append(
  346. [
  347. TextPart(text=text),
  348. ],
  349. add_end=False,
  350. speaker=0,
  351. )
  352. encoded, audio_masks, audio_parts = base_content_sequence.encode_for_inference(
  353. tokenizer, num_codebooks=model.config.num_codebooks
  354. )
  355. if encoded.size(1) > max_length - 2048:
  356. raise ValueError(f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}")
  357. encoded = encoded.to(device=device)
  358. logger.info(f"Encoded text: {text}")
  359. # Move temperature, top_p, repetition_penalty to device
  360. # This is important so that changing params doesn't trigger recompile
  361. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  362. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  363. repetition_penalty = torch.tensor(
  364. repetition_penalty, device=device, dtype=torch.float
  365. )
  366. for sample_idx in range(num_samples):
  367. if torch.cuda.is_available():
  368. torch.cuda.synchronize()
  369. global_encoded = []
  370. seg_idx = 0
  371. prompt_length = encoded.size(1)
  372. t0 = time.perf_counter()
  373. y = generate(
  374. model=model,
  375. prompt=encoded,
  376. max_new_tokens=max_new_tokens,
  377. audio_masks=audio_masks,
  378. audio_parts=audio_parts,
  379. decode_one_token=decode_one_token,
  380. temperature=temperature,
  381. top_p=top_p,
  382. repetition_penalty=repetition_penalty,
  383. )
  384. if sample_idx == 0 and seg_idx == 0 and compile:
  385. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  386. if torch.cuda.is_available():
  387. torch.cuda.synchronize()
  388. t = time.perf_counter() - t0
  389. tokens_generated = y.size(1) - prompt_length
  390. tokens_sec = tokens_generated / t
  391. logger.info(
  392. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  393. )
  394. logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
  395. if torch.cuda.is_available():
  396. logger.info(
  397. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  398. )
  399. # Put the generated tokens
  400. # since there is <im_end>, we remove last token
  401. codes = y[1:, prompt_length:-1].clone()
  402. assert (codes >= 0).all(), f"Negative code found"
  403. decoded = y[:, prompt_length:].clone()
  404. # But for global encoding, we should keep the <im_end> token
  405. global_encoded.append(decoded.cpu())
  406. assert (codes >= 0).all(), f"Negative code found: {codes}"
  407. yield GenerateResponse(action="sample", codes=codes, text=text)
  408. seg_idx += 1
  409. # This indicates the end of the current sample
  410. yield GenerateResponse(action="next")
  411. @dataclass
  412. class WrappedGenerateResponse:
  413. status: Literal["success", "error"]
  414. response: Optional[GenerateResponse | Exception] = None
  415. @dataclass
  416. class GenerateRequest:
  417. request: dict
  418. response_queue: queue.Queue
  419. def launch_thread_safe_queue(
  420. checkpoint_path,
  421. device,
  422. precision,
  423. compile: bool = False,
  424. ):
  425. input_queue = queue.Queue()
  426. init_event = threading.Event()
  427. def worker():
  428. model, decode_one_token = init_model(
  429. checkpoint_path, device, precision, compile=compile
  430. )
  431. with torch.device(device):
  432. model.setup_caches(
  433. max_batch_size=1,
  434. max_seq_len=model.config.max_seq_len,
  435. dtype=next(model.parameters()).dtype,
  436. )
  437. init_event.set()
  438. while True:
  439. item: GenerateRequest | None = input_queue.get()
  440. if item is None:
  441. break
  442. kwargs = item.request
  443. response_queue = item.response_queue
  444. try:
  445. for chunk in generate_long(
  446. model=model, decode_one_token=decode_one_token, **kwargs
  447. ):
  448. response_queue.put(
  449. WrappedGenerateResponse(status="success", response=chunk)
  450. )
  451. except Exception as e:
  452. logger.error(traceback.format_exc())
  453. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  454. threading.Thread(target=worker, daemon=True).start()
  455. init_event.wait()
  456. return input_queue
  457. @click.command()
  458. @click.option(
  459. "--text",
  460. type=str,
  461. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  462. )
  463. @click.option("--prompt-text", type=str, default=None, multiple=True)
  464. @click.option(
  465. "--prompt-tokens",
  466. type=click.Path(path_type=Path, exists=True),
  467. default=None,
  468. multiple=True,
  469. )
  470. @click.option("--num-samples", type=int, default=1)
  471. @click.option("--max-new-tokens", type=int, default=0)
  472. @click.option("--top-p", type=float, default=0.8)
  473. @click.option("--repetition-penalty", type=float, default=1.1)
  474. @click.option("--temperature", type=float, default=0.8)
  475. @click.option(
  476. "--checkpoint-path",
  477. type=click.Path(path_type=Path, exists=True),
  478. default="checkpoints/openaudio-s1-mini",
  479. )
  480. @click.option("--device", type=str, default="cuda")
  481. @click.option("--compile/--no-compile", default=False)
  482. @click.option("--seed", type=int, default=42)
  483. @click.option("--half/--no-half", default=False)
  484. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  485. @click.option("--chunk-length", type=int, default=300)
  486. @click.option("--output-dir", type=Path, default="temp")
  487. def main(
  488. text: str,
  489. prompt_text: Optional[list[str]],
  490. prompt_tokens: Optional[list[Path]],
  491. num_samples: int,
  492. max_new_tokens: int,
  493. top_p: int,
  494. repetition_penalty: float,
  495. temperature: float,
  496. checkpoint_path: Path,
  497. device: str,
  498. compile: bool,
  499. seed: int,
  500. half: bool,
  501. iterative_prompt: bool,
  502. chunk_length: int,
  503. output_dir: Path,
  504. ) -> None:
  505. os.makedirs(output_dir, exist_ok=True)
  506. precision = torch.half if half else torch.bfloat16
  507. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  508. raise ValueError(
  509. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  510. )
  511. logger.info("Loading model ...")
  512. t0 = time.time()
  513. model, decode_one_token = init_model(
  514. checkpoint_path, device, precision, compile=compile
  515. )
  516. with torch.device(device):
  517. model.setup_caches(
  518. max_batch_size=1,
  519. max_seq_len=model.config.max_seq_len,
  520. dtype=next(model.parameters()).dtype,
  521. )
  522. if torch.cuda.is_available():
  523. torch.cuda.synchronize()
  524. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  525. if prompt_tokens is not None:
  526. prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
  527. torch.manual_seed(seed)
  528. if torch.cuda.is_available():
  529. torch.cuda.manual_seed(seed)
  530. generator = generate_long(
  531. model=model,
  532. device=device,
  533. decode_one_token=decode_one_token,
  534. text=text,
  535. num_samples=num_samples,
  536. max_new_tokens=max_new_tokens,
  537. top_p=top_p,
  538. repetition_penalty=repetition_penalty,
  539. temperature=temperature,
  540. compile=compile,
  541. iterative_prompt=iterative_prompt,
  542. chunk_length=chunk_length,
  543. prompt_text=prompt_text,
  544. prompt_tokens=prompt_tokens,
  545. )
  546. idx = 0
  547. codes = []
  548. for response in generator:
  549. if response.action == "sample":
  550. codes.append(response.codes)
  551. logger.info(f"Sampled text: {response.text}")
  552. elif response.action == "next":
  553. if codes:
  554. codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
  555. np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
  556. logger.info(f"Saved codes to {codes_npy_path}")
  557. logger.info(f"Next sample")
  558. codes = []
  559. idx += 1
  560. else:
  561. logger.error(f"Error: {response}")
  562. if __name__ == "__main__":
  563. main()