inference.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import numpy as np
  11. import torch
  12. import torch._inductor.config
  13. from loguru import logger
  14. from tqdm import tqdm
  15. from transformers import AutoTokenizer
  16. from fish_speech.content_sequence import (
  17. ContentSequence,
  18. TextPart,
  19. VQPart,
  20. )
  21. from fish_speech.text import split_text
  22. from fish_speech.tokenizer import IM_END_TOKEN
  23. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  24. torch._inductor.config.coordinate_descent_tuning = True
  25. torch._inductor.config.triton.unique_kernel_names = True
  26. if hasattr(torch._inductor.config, "fx_graph_cache"):
  27. # Experimental feature to reduce compilation times, will be on by default in future
  28. torch._inductor.config.fx_graph_cache = True
  29. from torch.nn.attention import SDPBackend, sdpa_kernel
  30. from fish_speech.models.text2semantic.llama import (
  31. DualARTransformer,
  32. NaiveTransformer,
  33. )
  34. def multinomial_sample_one_no_sync(
  35. probs_sort,
  36. ): # Does multinomial sampling without a cuda synchronization
  37. q = torch.empty_like(probs_sort).exponential_(1)
  38. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  39. def logits_to_probs(
  40. logits,
  41. previous_tokens: Optional[torch.Tensor] = None,
  42. temperature: torch.Tensor = 1.0,
  43. top_p: torch.Tensor = 1.0,
  44. repetition_penalty: torch.Tensor = 1.0,
  45. ) -> torch.Tensor:
  46. # Apply repetition penalty
  47. if previous_tokens is not None:
  48. previous_tokens = previous_tokens.long()
  49. score = torch.gather(logits, dim=0, index=previous_tokens)
  50. score = torch.where(
  51. score < 0, score * repetition_penalty, score / repetition_penalty
  52. )
  53. logits.scatter_(dim=0, index=previous_tokens, src=score)
  54. # Apply top-p sampling
  55. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  56. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  57. sorted_indices_to_remove = cum_probs > top_p
  58. sorted_indices_to_remove[0] = False # keep at least one option
  59. indices_to_remove = sorted_indices_to_remove.scatter(
  60. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  61. )
  62. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  63. logits = logits / max(temperature, 1e-5)
  64. probs = torch.nn.functional.softmax(logits, dim=-1)
  65. return probs
  66. def sample(
  67. logits,
  68. previous_tokens: Optional[torch.Tensor] = None,
  69. **sampling_kwargs,
  70. ) -> Tuple[torch.Tensor, torch.Tensor]:
  71. probs = logits_to_probs(
  72. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  73. )
  74. idx_next = multinomial_sample_one_no_sync(probs)
  75. return idx_next, probs
  76. def decode_one_token_ar(
  77. model: DualARTransformer,
  78. x: torch.Tensor,
  79. input_pos: torch.Tensor,
  80. previous_tokens: torch.Tensor = None,
  81. **sampling_kwargs,
  82. ) -> torch.Tensor:
  83. """
  84. Generate one token using dual autoregressive transformer for text-to-speech.
  85. First generates semantic tokens, then generates acoustic codebook tokens sequentially.
  86. Args:
  87. x: Input token tensor (1, num_codebooks+1, seq_len)
  88. input_pos: Position indices for input tokens (seq_len,)
  89. temperature/top_p/repetition_penalty: Sampling parameters (1, 1)
  90. previous_tokens: Previous tokens for repetition penalty (1, num_codebooks+1, history_seq_len)
  91. audio_masks/audio_parts: Audio conditioning tensors (num_codebooks, seq_len)
  92. Returns:
  93. Generated tokens tensor (num_codebooks+1, 1) - one token per codebook
  94. """
  95. x = model.forward_generate(x, input_pos)
  96. sampling_kwargs_main = sampling_kwargs.copy()
  97. codebooks = [
  98. sample(
  99. x.logits,
  100. previous_tokens=(
  101. previous_tokens[0] if previous_tokens is not None else None
  102. ), # Disable repetition penalty for the token codebook
  103. **sampling_kwargs_main,
  104. )[0]
  105. ]
  106. hidden_states = x.hidden_states
  107. # Cleanup the cache
  108. for layer in model.fast_layers:
  109. layer.attention.kv_cache.k_cache.fill_(0)
  110. layer.attention.kv_cache.v_cache.fill_(0)
  111. input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
  112. model.forward_generate_fast(hidden_states, input_pos)
  113. a = codebooks[0] - model.tokenizer.semantic_begin_id
  114. a[a < 0] = 0
  115. hidden_states = model.fast_embeddings(a)
  116. codebooks.append(a)
  117. for codebook_idx in range(1, model.config.num_codebooks):
  118. input_pos = torch.tensor(
  119. [codebook_idx], device=hidden_states.device, dtype=torch.long
  120. )
  121. logits = model.forward_generate_fast(hidden_states, input_pos)
  122. chunked_logits = logits[..., :1024]
  123. a = sample(
  124. chunked_logits,
  125. previous_tokens=(
  126. previous_tokens[codebook_idx + 1]
  127. if previous_tokens is not None
  128. else None
  129. ),
  130. **sampling_kwargs,
  131. )[0]
  132. hidden_states = model.fast_embeddings(a)
  133. codebooks.append(a)
  134. codebooks = torch.stack(codebooks, dim=0)
  135. return codebooks
  136. def decode_n_tokens(
  137. model: NaiveTransformer,
  138. cur_token: torch.Tensor,
  139. input_pos: torch.Tensor,
  140. num_new_tokens: int,
  141. decode_one_token=decode_one_token_ar,
  142. **sampling_kwargs,
  143. ):
  144. """
  145. Generate n tokens iteratively using the model.
  146. Args:
  147. model: The transformer model
  148. cur_token: Current token tensor of shape (1, num_codebooks+1, seq_len)
  149. input_pos: Current input position tensor
  150. num_new_tokens: Number of new tokens to generate
  151. semantic_ids: List of semantic token IDs
  152. decode_one_token: Function to decode one token
  153. **sampling_kwargs: Additional sampling parameters
  154. Returns:
  155. Generated tokens tensor of shape (num_codebooks+1, generated_len)
  156. """
  157. previous_tokens = torch.zeros(
  158. (model.config.num_codebooks + 1, model.config.max_seq_len),
  159. dtype=torch.int,
  160. device=cur_token.device,
  161. )
  162. for i in tqdm(range(num_new_tokens)):
  163. # We need to get windowed repeat penalty
  164. win_size = 16
  165. if i < win_size:
  166. window = previous_tokens[:, :win_size]
  167. else:
  168. window = previous_tokens[:, i - win_size : i]
  169. with sdpa_kernel(SDPBackend.MATH):
  170. next_token = decode_one_token(
  171. model=model,
  172. x=cur_token,
  173. input_pos=input_pos,
  174. previous_tokens=window,
  175. **sampling_kwargs,
  176. ).clone()
  177. input_pos += 1
  178. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  179. previous_tokens[:, i : i + 1] = next_token.view(
  180. model.config.num_codebooks + 1, -1
  181. )
  182. if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
  183. break
  184. return previous_tokens[:, : i + 1]
  185. @torch.no_grad()
  186. @torch.inference_mode()
  187. def generate(
  188. *,
  189. model: NaiveTransformer,
  190. prompt: torch.Tensor,
  191. max_new_tokens: int,
  192. decode_one_token=decode_one_token_ar,
  193. **sampling_kwargs,
  194. ) -> torch.Tensor:
  195. """
  196. Generate tokens from text prompt using the transformer model.
  197. Args:
  198. model: The transformer model for generation
  199. prompt: Input token tensor of shape (num_codebooks+1, seq_len)
  200. max_new_tokens: Maximum number of new tokens to generate
  201. decode_one_token: Function to decode one token at a time
  202. **sampling_kwargs: Additional sampling parameters (temperature, top_p, repetition_penalty)
  203. Returns:
  204. Generated sequence tensor of shape (num_codebooks+1, total_seq_len)
  205. where total_seq_len = original_seq_len + generated_tokens_len
  206. """
  207. T = prompt.size(1)
  208. if max_new_tokens:
  209. if T + max_new_tokens > model.config.max_seq_len:
  210. max_new_tokens = model.config.max_seq_len - T
  211. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  212. T_new = T + max_new_tokens
  213. else:
  214. T_new = model.config.max_seq_len
  215. max_new_tokens = T_new - T
  216. device, dtype = prompt.device, prompt.dtype
  217. codebook_dim = 1 + model.config.num_codebooks
  218. empty = torch.empty(
  219. (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
  220. )
  221. empty[:, :T] = prompt
  222. seq = empty
  223. input_pos = torch.arange(0, T, device=device)
  224. # Use non-accelerated version for now, to avoid compilation overhead
  225. prefill_decode = decode_one_token_ar
  226. first_token = prefill_decode(
  227. model,
  228. prompt.view(1, codebook_dim, -1),
  229. input_pos,
  230. **sampling_kwargs,
  231. )
  232. seq[:, T : T + 1] = first_token
  233. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  234. x = decode_n_tokens(
  235. model,
  236. first_token.view(1, codebook_dim, -1),
  237. input_pos,
  238. max_new_tokens - 1,
  239. decode_one_token=decode_one_token,
  240. **sampling_kwargs,
  241. )
  242. seq = seq[:, : T + 1 + x.size(1)]
  243. seq[:, T + 1 :] = x
  244. return seq
  245. def init_model(checkpoint_path, device, precision, compile=False):
  246. model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
  247. model = model.to(device=device, dtype=precision)
  248. logger.info(f"Restored model from checkpoint")
  249. if isinstance(model, DualARTransformer):
  250. decode_one_token = decode_one_token_ar
  251. logger.info("Using DualARTransformer")
  252. else:
  253. raise ValueError("Model is not a DualARTransformer")
  254. if compile:
  255. logger.info("Compiling function...")
  256. decode_one_token = torch.compile(
  257. decode_one_token,
  258. fullgraph=True,
  259. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  260. mode="reduce-overhead" if torch.cuda.is_available() else None,
  261. )
  262. return model.eval(), decode_one_token
  263. @dataclass
  264. class GenerateResponse:
  265. action: Literal["sample", "next"]
  266. codes: Optional[torch.Tensor] = None
  267. text: Optional[str] = None
  268. def generate_long(
  269. *,
  270. model,
  271. device: str | torch.device,
  272. decode_one_token: callable,
  273. text: str,
  274. num_samples: int = 1,
  275. max_new_tokens: int = 0,
  276. top_p: int = 0.8,
  277. repetition_penalty: float = 1.1,
  278. temperature: float = 0.8,
  279. compile: bool = False,
  280. iterative_prompt: bool = True,
  281. chunk_length: int = 150,
  282. prompt_text: Optional[str | list[str]] = None,
  283. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  284. ):
  285. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  286. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  287. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  288. use_prompt = prompt_text is not None and prompt_tokens is not None
  289. if use_prompt and isinstance(prompt_text, str):
  290. prompt_text = [prompt_text]
  291. prompt_tokens = [prompt_tokens]
  292. assert use_prompt is False or len(prompt_text) == len(
  293. prompt_tokens
  294. ), "Prompt text and tokens must have the same length"
  295. prompt_tokens = [i.cpu() for i in prompt_tokens]
  296. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  297. tokenizer = model.tokenizer
  298. base_content_sequence = ContentSequence(modality="interleave")
  299. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  300. max_length = model.config.max_seq_len
  301. if use_prompt:
  302. for t, c in zip(prompt_text, prompt_tokens):
  303. base_content_sequence.append(
  304. [
  305. TextPart(text=t),
  306. VQPart(codes=c),
  307. ],
  308. add_end=True,
  309. )
  310. encoded_prompts = base_content_sequence.encode_for_inference(
  311. tokenizer, num_codebooks=model.config.num_codebooks
  312. )
  313. if encoded_prompts.size(1) > max_length - 2048:
  314. raise ValueError(
  315. f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}"
  316. )
  317. encoded = []
  318. for text in texts:
  319. content_sequence = ContentSequence(modality=None)
  320. content_sequence.append(TextPart(text=text))
  321. encoded.append(
  322. content_sequence.encode_for_inference(
  323. tokenizer, num_codebooks=model.config.num_codebooks
  324. )
  325. )
  326. logger.info(f"Encoded text: {text}")
  327. # Move temperature, top_p, repetition_penalty to device
  328. # This is important so that changing params doesn't trigger recompile
  329. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  330. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  331. repetition_penalty = torch.tensor(
  332. repetition_penalty, device=device, dtype=torch.float
  333. )
  334. for sample_idx in range(num_samples):
  335. if torch.cuda.is_available():
  336. torch.cuda.synchronize()
  337. global_encoded = []
  338. seg_idx = 0
  339. while seg_idx < len(encoded):
  340. logger.info(
  341. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  342. )
  343. seg = encoded[seg_idx]
  344. global_encoded.append(seg)
  345. if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
  346. cat_encoded = torch.cat(
  347. [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
  348. )
  349. else:
  350. cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
  351. cat_encoded = cat_encoded.to(device=device)
  352. prompt_length = cat_encoded.size(1)
  353. t0 = time.perf_counter()
  354. y = generate(
  355. model=model,
  356. prompt=cat_encoded,
  357. max_new_tokens=max_new_tokens,
  358. decode_one_token=decode_one_token,
  359. temperature=temperature,
  360. top_p=top_p,
  361. repetition_penalty=repetition_penalty,
  362. )
  363. if sample_idx == 0 and seg_idx == 0 and compile:
  364. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  365. if torch.cuda.is_available():
  366. torch.cuda.synchronize()
  367. t = time.perf_counter() - t0
  368. tokens_generated = y.size(1) - prompt_length
  369. tokens_sec = tokens_generated / t
  370. logger.info(
  371. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  372. )
  373. logger.info(
  374. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  375. )
  376. if torch.cuda.is_available():
  377. logger.info(
  378. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  379. )
  380. # Put the generated tokens
  381. # since there is <im_end>, we remove last token
  382. codes = y[1:, prompt_length:-1].clone()
  383. assert (codes >= 0).all(), f"Negative code found"
  384. decoded = y[:, prompt_length:].clone()
  385. # But for global encoding, we should keep the <im_end> token
  386. global_encoded.append(decoded.cpu())
  387. assert (codes >= 0).all(), f"Negative code found: {codes}"
  388. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  389. seg_idx += 1
  390. # This indicates the end of the current sample
  391. yield GenerateResponse(action="next")
  392. @dataclass
  393. class WrappedGenerateResponse:
  394. status: Literal["success", "error"]
  395. response: Optional[GenerateResponse | Exception] = None
  396. @dataclass
  397. class GenerateRequest:
  398. request: dict
  399. response_queue: queue.Queue
  400. def launch_thread_safe_queue(
  401. checkpoint_path,
  402. device,
  403. precision,
  404. compile: bool = False,
  405. ):
  406. input_queue = queue.Queue()
  407. init_event = threading.Event()
  408. def worker():
  409. model, decode_one_token = init_model(
  410. checkpoint_path, device, precision, compile=compile
  411. )
  412. with torch.device(device):
  413. model.setup_caches(
  414. max_batch_size=1,
  415. max_seq_len=model.config.max_seq_len,
  416. dtype=next(model.parameters()).dtype,
  417. )
  418. init_event.set()
  419. while True:
  420. item: GenerateRequest | None = input_queue.get()
  421. if item is None:
  422. break
  423. kwargs = item.request
  424. response_queue = item.response_queue
  425. try:
  426. for chunk in generate_long(
  427. model=model, decode_one_token=decode_one_token, **kwargs
  428. ):
  429. response_queue.put(
  430. WrappedGenerateResponse(status="success", response=chunk)
  431. )
  432. except Exception as e:
  433. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  434. threading.Thread(target=worker, daemon=True).start()
  435. init_event.wait()
  436. return input_queue
  437. @click.command()
  438. @click.option(
  439. "--text",
  440. type=str,
  441. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  442. )
  443. @click.option("--prompt-text", type=str, default=None, multiple=True)
  444. @click.option(
  445. "--prompt-tokens",
  446. type=click.Path(path_type=Path, exists=True),
  447. default=None,
  448. multiple=True,
  449. )
  450. @click.option("--num-samples", type=int, default=1)
  451. @click.option("--max-new-tokens", type=int, default=0)
  452. @click.option("--top-p", type=float, default=0.8)
  453. @click.option("--repetition-penalty", type=float, default=1.1)
  454. @click.option("--temperature", type=float, default=0.8)
  455. @click.option(
  456. "--checkpoint-path",
  457. type=click.Path(path_type=Path, exists=True),
  458. default="checkpoints/openaudio-s1-mini",
  459. )
  460. @click.option("--device", type=str, default="cuda")
  461. @click.option("--compile/--no-compile", default=False)
  462. @click.option("--seed", type=int, default=42)
  463. @click.option("--half/--no-half", default=False)
  464. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  465. @click.option("--chunk-length", type=int, default=300)
  466. @click.option("--output-dir", type=Path, default="temp")
  467. def main(
  468. text: str,
  469. prompt_text: Optional[list[str]],
  470. prompt_tokens: Optional[list[Path]],
  471. num_samples: int,
  472. max_new_tokens: int,
  473. top_p: int,
  474. repetition_penalty: float,
  475. temperature: float,
  476. checkpoint_path: Path,
  477. device: str,
  478. compile: bool,
  479. seed: int,
  480. half: bool,
  481. iterative_prompt: bool,
  482. chunk_length: int,
  483. output_dir: Path,
  484. ) -> None:
  485. os.makedirs(output_dir, exist_ok=True)
  486. precision = torch.half if half else torch.bfloat16
  487. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  488. raise ValueError(
  489. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  490. )
  491. logger.info("Loading model ...")
  492. t0 = time.time()
  493. model, decode_one_token = init_model(
  494. checkpoint_path, device, precision, compile=compile
  495. )
  496. with torch.device(device):
  497. model.setup_caches(
  498. max_batch_size=1,
  499. max_seq_len=model.config.max_seq_len,
  500. dtype=next(model.parameters()).dtype,
  501. )
  502. if torch.cuda.is_available():
  503. torch.cuda.synchronize()
  504. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  505. if prompt_tokens is not None:
  506. prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
  507. torch.manual_seed(seed)
  508. if torch.cuda.is_available():
  509. torch.cuda.manual_seed(seed)
  510. generator = generate_long(
  511. model=model,
  512. device=device,
  513. decode_one_token=decode_one_token,
  514. text=text,
  515. num_samples=num_samples,
  516. max_new_tokens=max_new_tokens,
  517. top_p=top_p,
  518. repetition_penalty=repetition_penalty,
  519. temperature=temperature,
  520. compile=compile,
  521. iterative_prompt=iterative_prompt,
  522. chunk_length=chunk_length,
  523. prompt_text=prompt_text,
  524. prompt_tokens=prompt_tokens,
  525. )
  526. idx = 0
  527. codes = []
  528. for response in generator:
  529. if response.action == "sample":
  530. codes.append(response.codes)
  531. logger.info(f"Sampled text: {response.text}")
  532. elif response.action == "next":
  533. if codes:
  534. codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
  535. np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
  536. logger.info(f"Saved codes to {codes_npy_path}")
  537. logger.info(f"Next sample")
  538. codes = []
  539. idx += 1
  540. else:
  541. logger.error(f"Error: {response}")
  542. if __name__ == "__main__":
  543. main()