# inference.py
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. import traceback
  6. from contextlib import nullcontext
  7. from dataclasses import dataclass
  8. from pathlib import Path
  9. from typing import Literal, Optional, Tuple, Union
  10. import click
  11. import numpy as np
  12. import torch
  13. import torch._inductor.config
  14. from loguru import logger
  15. from tqdm import tqdm
  16. from transformers import AutoTokenizer
  17. from fish_speech.content_sequence import (
  18. ContentSequence,
  19. TextPart,
  20. VQPart,
  21. )
  22. from fish_speech.text import split_text
  23. from fish_speech.tokenizer import IM_END_TOKEN
# Disable HF tokenizers' internal parallelism (avoids fork/thread warnings,
# since generation here runs inside worker threads).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Inductor tuning knobs for faster generated kernels under torch.compile.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

# Guarded because older torch builds don't expose this knob.
if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
  30. from torch.nn.attention import SDPBackend, sdpa_kernel
  31. from fish_speech.models.text2semantic.llama import (
  32. BaseTransformer,
  33. DualARTransformer,
  34. NaiveTransformer,
  35. )
  36. def multinomial_sample_one_no_sync(
  37. probs_sort,
  38. ): # Does multinomial sampling without a cuda synchronization
  39. q = torch.empty_like(probs_sort).exponential_(1)
  40. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  41. def logits_to_probs(
  42. logits,
  43. temperature: torch.Tensor,
  44. top_p: torch.Tensor,
  45. repetition_penalty: torch.Tensor,
  46. previous_tokens: Optional[torch.Tensor] = None,
  47. ) -> torch.Tensor:
  48. # Apply repetition penalty
  49. if previous_tokens is not None:
  50. previous_tokens = previous_tokens.long()
  51. score = torch.gather(logits, dim=-1, index=previous_tokens)
  52. score = torch.where(
  53. score < 0, score * repetition_penalty, score / repetition_penalty
  54. )
  55. logits.scatter_(dim=-1, index=previous_tokens, src=score)
  56. # Apply top-p sampling
  57. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  58. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  59. sorted_indices_to_remove = cum_probs > top_p
  60. sorted_indices_to_remove[0] = False # keep at least one option
  61. indices_to_remove = sorted_indices_to_remove.scatter(
  62. dim=-1, index=sorted_indices, src=sorted_indices_to_remove
  63. )
  64. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  65. logits = logits / torch.clip(temperature, min=1e-5)
  66. probs = torch.nn.functional.softmax(logits, dim=-1)
  67. return probs
  68. def sample(
  69. logits,
  70. temperature: torch.Tensor,
  71. top_p: torch.Tensor,
  72. repetition_penalty: torch.Tensor,
  73. previous_tokens: Optional[torch.Tensor] = None,
  74. ) -> Tuple[torch.Tensor, torch.Tensor]:
  75. probs = logits_to_probs(
  76. logits=logits[0, -1],
  77. temperature=temperature,
  78. top_p=top_p,
  79. repetition_penalty=repetition_penalty,
  80. previous_tokens=previous_tokens,
  81. )
  82. idx_next = multinomial_sample_one_no_sync(probs)
  83. return idx_next, probs
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    repetition_penalty: torch.Tensor,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    previous_tokens: torch.Tensor = None,
) -> torch.Tensor:
    """Decode one timestep with the dual-AR model.

    Runs the slow transformer once for the text/semantic token, then runs
    the fast transformer autoregressively over the codebooks, sampling one
    code per codebook. Returns a tensor of shape (1, num_codebooks + 1)
    (codebooks stacked along dim 1, then transposed).
    """
    # print(x, torch.count_nonzero(vq_masks))
    x = model.forward_generate(
        x,
        input_pos,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
    )

    logits = x.logits  # [:, -1:]
    hidden_states = x.hidden_states  # [:, -1:]

    # Sample the slow-head token first.
    # NOTE(review): this call uses previous_tokens[:, 0] (one timestep across
    # all codebooks) while the fast loop below indexes rows
    # previous_tokens[codebook_idx + 1] (one codebook across the window) —
    # confirm this asymmetry is intended.
    codebooks = [
        sample(
            logits,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            previous_tokens=(
                previous_tokens[:, 0] if previous_tokens is not None else None
            ),
        )[0]
    ]

    # Cleanup the cache: the fast transformer's KV cache is zeroed before
    # each per-codebook pass, since its sequence restarts every timestep.
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    # Prime the fast transformer at position 0 with the slow hidden state.
    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
    model.forward_generate_fast(hidden_states, input_pos)

    # Rebase the sampled token into codebook index space; clamp below zero
    # (non-semantic tokens such as <im_end> map to code 0).
    a = codebooks[0] - model.tokenizer.semantic_begin_id
    a[a < 0] = 0
    hidden_states = model.fast_embeddings(a)
    codebooks.append(a)

    # Autoregressively sample the remaining codebooks with the fast head.
    for codebook_idx in range(1, model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)

        # Restrict to the first 1024 entries — presumably the codebook
        # vocabulary size; TODO confirm against the model config.
        short_logits = logits[:, :, :1024]

        # Convert logits to probs
        a = sample(
            short_logits,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=1)
    return codebooks.T
  147. def decode_n_tokens(
  148. model: NaiveTransformer,
  149. cur_token: torch.Tensor,
  150. input_pos: torch.Tensor,
  151. num_new_tokens: int,
  152. temperature: torch.Tensor,
  153. top_p: torch.Tensor,
  154. repetition_penalty: torch.Tensor,
  155. audio_masks: torch.Tensor,
  156. audio_parts: torch.Tensor,
  157. decode_one_token=decode_one_token_ar,
  158. ):
  159. previous_tokens = torch.zeros(
  160. (model.config.num_codebooks + 1, model.config.max_seq_len),
  161. dtype=torch.int,
  162. device=cur_token.device,
  163. )
  164. for i in tqdm(range(num_new_tokens)):
  165. # We need to get windowed repeat penalty
  166. win_size = 16
  167. if i < win_size:
  168. window = previous_tokens[:, :win_size]
  169. else:
  170. window = previous_tokens[:, i - win_size : i]
  171. with sdpa_kernel(
  172. SDPBackend.MATH
  173. ): # Actually better for Inductor to codegen attention here
  174. next_token = decode_one_token(
  175. model=model,
  176. x=cur_token,
  177. input_pos=input_pos,
  178. previous_tokens=window,
  179. temperature=temperature,
  180. top_p=top_p,
  181. repetition_penalty=repetition_penalty,
  182. audio_masks=audio_masks,
  183. audio_parts=audio_parts,
  184. ).clone()
  185. input_pos += 1
  186. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  187. previous_tokens[:, i : i + 1] = next_token.view(
  188. model.config.num_codebooks + 1, -1
  189. )
  190. if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
  191. break
  192. return previous_tokens[:, : i + 1]
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: BaseTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    decode_one_token=decode_one_token_ar,
    num_samples: int = 1,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    `prompt` is a (num_codebooks + 1, T) token tensor; `sampling_kwargs`
    must contain "temperature", "top_p" and "repetition_penalty". Returns
    the prompt concatenated with the generated tokens.
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    prompt = prompt[None].repeat(num_samples, 1, 1)

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    # Clamp the generation budget to the cache size.
    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
        T_new = T + max_new_tokens
    else:
        # max_new_tokens == 0 means "fill the context window".
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    # (Re)allocate KV caches sized for this batch.
    with torch.device(device):
        model.setup_caches(
            max_batch_size=num_samples,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )

    codebook_dim = 1 + model.config.num_codebooks
    input_pos = torch.arange(0, T, device=device)

    # Output buffer for prompt + generated tokens.
    empty = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
    )
    # NOTE(review): RHS is (num_samples, codebook_dim, T); this assignment
    # only broadcasts into the 2-D buffer when num_samples == 1 — confirm.
    empty[:, :T] = prompt
    seq = empty

    # Sampling params as device tensors so parameter changes don't retrigger
    # torch.compile recompilation.
    temperature = torch.tensor(
        sampling_kwargs["temperature"], device=device, dtype=torch.bfloat16
    )
    top_p = torch.tensor(sampling_kwargs["top_p"], device=device, dtype=torch.bfloat16)
    repetition_penalty = torch.tensor(
        sampling_kwargs["repetition_penalty"], device=device, dtype=torch.bfloat16
    )

    # Prefill always uses the eager (uncompiled) decode function.
    prefill_decode = decode_one_token_ar

    first_token = prefill_decode(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        temperature,
        top_p,
        repetition_penalty,
        audio_masks,
        audio_parts,
    )
    seq[:, T : T + 1] = first_token

    # Continue decoding from position T with the (possibly compiled) fn.
    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        first_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
        decode_one_token=decode_one_token,
    )
    # Trim the buffer to what was actually generated (early <im_end> stop).
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
  272. def init_model(checkpoint_path, device, precision, compile=False):
  273. model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
  274. model = model.to(device=device, dtype=precision)
  275. logger.info(f"Restored model from checkpoint")
  276. if isinstance(model, DualARTransformer):
  277. decode_one_token = decode_one_token_ar
  278. prefill_n_tokens = decode_one_token_ar
  279. logger.info("Using DualARTransformer")
  280. else:
  281. raise ValueError("Unsupported model type")
  282. # Initialize cache
  283. with torch.device(device):
  284. model.setup_caches(
  285. max_batch_size=1,
  286. max_seq_len=model.config.max_seq_len,
  287. dtype=next(model.parameters()).dtype,
  288. )
  289. if compile:
  290. logger.info("Compiling function...")
  291. decode_one_token = torch.compile(
  292. decode_one_token,
  293. # mode="max-autotune-no-cudagraphs",
  294. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  295. mode="reduce-overhead" if torch.cuda.is_available() else None,
  296. fullgraph=True,
  297. )
  298. return model.eval(), decode_one_token
@dataclass
class GenerateResponse:
    """One message yielded by `generate_long`.

    `action == "sample"` carries the codes/text of one finished segment;
    `action == "next"` marks the end of the current sample (no payload).
    """

    action: Literal["sample", "next"]
    # Generated semantic codes, shape (num_codebooks, T); set for "sample".
    codes: Optional[torch.Tensor] = None
    # The text that was synthesized; set for "sample".
    text: Optional[str] = None
  304. def generate_long(
  305. *,
  306. model,
  307. device: str | torch.device,
  308. decode_one_token: callable,
  309. text: str,
  310. num_samples: int = 1,
  311. max_new_tokens: int = 0,
  312. top_p: int = 0.8,
  313. repetition_penalty: float = 1.1,
  314. temperature: float = 0.8,
  315. compile: bool = False,
  316. iterative_prompt: bool = True,
  317. chunk_length: int = 512,
  318. prompt_text: Optional[str | list[str]] = None,
  319. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  320. ):
  321. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  322. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  323. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  324. use_prompt = prompt_text is not None and prompt_tokens is not None
  325. if use_prompt and isinstance(prompt_text, str):
  326. prompt_text = [prompt_text]
  327. prompt_tokens = [prompt_tokens]
  328. assert use_prompt is False or len(prompt_text) == len(
  329. prompt_tokens
  330. ), "Prompt text and tokens must have the same length"
  331. prompt_tokens = [i.cpu() for i in prompt_tokens]
  332. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  333. tokenizer = model.tokenizer
  334. base_content_sequence = ContentSequence(modality="interleave")
  335. max_length = model.config.max_seq_len
  336. if use_prompt:
  337. for t, c in zip(prompt_text, prompt_tokens):
  338. base_content_sequence.append(
  339. [
  340. TextPart(text=t),
  341. VQPart(codes=c),
  342. ],
  343. add_end=True,
  344. speaker=0,
  345. )
  346. base_content_sequence.append(
  347. [
  348. TextPart(text=text),
  349. ],
  350. add_end=False,
  351. speaker=0,
  352. )
  353. encoded, audio_masks, audio_parts = base_content_sequence.encode_for_inference(
  354. tokenizer, num_codebooks=model.config.num_codebooks
  355. )
  356. if encoded.size(1) > max_length - 2048:
  357. raise ValueError(f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}")
  358. encoded = encoded.to(device=device)
  359. logger.info(f"Encoded text: {text}")
  360. # Move temperature, top_p, repetition_penalty to device
  361. # This is important so that changing params doesn't trigger recompile
  362. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  363. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  364. repetition_penalty = torch.tensor(
  365. repetition_penalty, device=device, dtype=torch.float
  366. )
  367. for sample_idx in range(num_samples):
  368. if torch.cuda.is_available():
  369. torch.cuda.synchronize()
  370. global_encoded = []
  371. seg_idx = 0
  372. prompt_length = encoded.size(1)
  373. t0 = time.perf_counter()
  374. y = generate(
  375. model=model,
  376. prompt=encoded,
  377. max_new_tokens=max_new_tokens,
  378. audio_masks=audio_masks,
  379. audio_parts=audio_parts,
  380. decode_one_token=decode_one_token,
  381. temperature=temperature,
  382. top_p=top_p,
  383. repetition_penalty=repetition_penalty,
  384. )
  385. if sample_idx == 0 and seg_idx == 0 and compile:
  386. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  387. if torch.cuda.is_available():
  388. torch.cuda.synchronize()
  389. t = time.perf_counter() - t0
  390. tokens_generated = y.size(1) - prompt_length
  391. tokens_sec = tokens_generated / t
  392. logger.info(
  393. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  394. )
  395. logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
  396. if torch.cuda.is_available():
  397. logger.info(
  398. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  399. )
  400. # Put the generated tokens
  401. # since there is <im_end>, we remove last token
  402. codes = y[1:, prompt_length:-1].clone()
  403. assert (codes >= 0).all(), f"Negative code found"
  404. decoded = y[:, prompt_length:].clone()
  405. # But for global encoding, we should keep the <im_end> token
  406. global_encoded.append(decoded.cpu())
  407. assert (codes >= 0).all(), f"Negative code found: {codes}"
  408. yield GenerateResponse(action="sample", codes=codes, text=text)
  409. seg_idx += 1
  410. # This indicates the end of the current sample
  411. yield GenerateResponse(action="next")
@dataclass
class WrappedGenerateResponse:
    """Envelope put on a response queue by the worker thread."""

    status: Literal["success", "error"]
    # A GenerateResponse on "success", or the raised Exception on "error".
    response: Optional[GenerateResponse | Exception] = None
@dataclass
class GenerateRequest:
    """Work item consumed by `launch_thread_safe_queue`'s worker thread."""

    # Keyword arguments forwarded to generate_long (model/decode fn are
    # supplied by the worker).
    request: dict
    # Queue that receives WrappedGenerateResponse items for this request.
    response_queue: queue.Queue
  420. def launch_thread_safe_queue(
  421. checkpoint_path,
  422. device,
  423. precision,
  424. compile: bool = False,
  425. ):
  426. input_queue = queue.Queue()
  427. init_event = threading.Event()
  428. def worker():
  429. model, decode_one_token = init_model(
  430. checkpoint_path, device, precision, compile=compile
  431. )
  432. with torch.device(device):
  433. model.setup_caches(
  434. max_batch_size=1,
  435. max_seq_len=model.config.max_seq_len,
  436. dtype=next(model.parameters()).dtype,
  437. )
  438. init_event.set()
  439. while True:
  440. item: GenerateRequest | None = input_queue.get()
  441. if item is None:
  442. break
  443. kwargs = item.request
  444. response_queue = item.response_queue
  445. try:
  446. for chunk in generate_long(
  447. model=model, decode_one_token=decode_one_token, **kwargs
  448. ):
  449. response_queue.put(
  450. WrappedGenerateResponse(status="success", response=chunk)
  451. )
  452. except Exception as e:
  453. logger.error(traceback.format_exc())
  454. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  455. threading.Thread(target=worker, daemon=True).start()
  456. init_event.wait()
  457. return input_queue
  458. @click.command()
  459. @click.option(
  460. "--text",
  461. type=str,
  462. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  463. )
  464. @click.option("--prompt-text", type=str, default=None, multiple=True)
  465. @click.option(
  466. "--prompt-tokens",
  467. type=click.Path(path_type=Path, exists=True),
  468. default=None,
  469. multiple=True,
  470. )
  471. @click.option("--num-samples", type=int, default=1)
  472. @click.option("--max-new-tokens", type=int, default=0)
  473. @click.option("--top-p", type=float, default=0.8)
  474. @click.option("--repetition-penalty", type=float, default=1.1)
  475. @click.option("--temperature", type=float, default=0.8)
  476. @click.option(
  477. "--checkpoint-path",
  478. type=click.Path(path_type=Path, exists=True),
  479. default="checkpoints/openaudio-s1-mini",
  480. )
  481. @click.option("--device", type=str, default="cuda")
  482. @click.option("--compile/--no-compile", default=False)
  483. @click.option("--seed", type=int, default=42)
  484. @click.option("--half/--no-half", default=False)
  485. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  486. @click.option("--chunk-length", type=int, default=300)
  487. @click.option("--output-dir", type=Path, default="temp")
  488. def main(
  489. text: str,
  490. prompt_text: Optional[list[str]],
  491. prompt_tokens: Optional[list[Path]],
  492. num_samples: int,
  493. max_new_tokens: int,
  494. top_p: int,
  495. repetition_penalty: float,
  496. temperature: float,
  497. checkpoint_path: Path,
  498. device: str,
  499. compile: bool,
  500. seed: int,
  501. half: bool,
  502. iterative_prompt: bool,
  503. chunk_length: int,
  504. output_dir: Path,
  505. ) -> None:
  506. os.makedirs(output_dir, exist_ok=True)
  507. precision = torch.half if half else torch.bfloat16
  508. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  509. raise ValueError(
  510. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  511. )
  512. logger.info("Loading model ...")
  513. t0 = time.time()
  514. model, decode_one_token = init_model(
  515. checkpoint_path, device, precision, compile=compile
  516. )
  517. with torch.device(device):
  518. model.setup_caches(
  519. max_batch_size=1,
  520. max_seq_len=model.config.max_seq_len,
  521. dtype=next(model.parameters()).dtype,
  522. )
  523. if torch.cuda.is_available():
  524. torch.cuda.synchronize()
  525. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  526. if prompt_tokens is not None:
  527. prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
  528. torch.manual_seed(seed)
  529. if torch.cuda.is_available():
  530. torch.cuda.manual_seed(seed)
  531. generator = generate_long(
  532. model=model,
  533. device=device,
  534. decode_one_token=decode_one_token,
  535. text=text,
  536. num_samples=num_samples,
  537. max_new_tokens=max_new_tokens,
  538. top_p=top_p,
  539. repetition_penalty=repetition_penalty,
  540. temperature=temperature,
  541. compile=compile,
  542. iterative_prompt=iterative_prompt,
  543. chunk_length=chunk_length,
  544. prompt_text=prompt_text,
  545. prompt_tokens=prompt_tokens,
  546. )
  547. idx = 0
  548. codes = []
  549. for response in generator:
  550. if response.action == "sample":
  551. codes.append(response.codes)
  552. logger.info(f"Sampled text: {response.text}")
  553. elif response.action == "next":
  554. if codes:
  555. codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
  556. np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
  557. logger.info(f"Saved codes to {codes_npy_path}")
  558. logger.info(f"Next sample")
  559. codes = []
  560. idx += 1
  561. else:
  562. logger.error(f"Error: {response}")
  563. if __name__ == "__main__":
  564. main()