inference.py 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706
import os
import queue
import threading
import time
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Literal, Optional, Tuple, Union

import click
import numpy as np
import torch
import torch._inductor.config
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer

from fish_speech.content_sequence import (
    ContentSequence,
    TextPart,
    VQPart,
)
from fish_speech.tokenizer import IM_END_TOKEN

# Silence HF tokenizers' fork/parallelism warning; this module runs the
# tokenizer from worker threads (see launch_thread_safe_queue).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Inductor tuning knobs used when decode functions are torch.compile'd.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True

from torch.nn.attention import SDPBackend, sdpa_kernel

from fish_speech.models.text2semantic.llama import (
    BaseTransformer,
    DualARTransformer,
    NaiveTransformer,
)
  34. def multinomial_sample_one_no_sync(
  35. probs_sort,
  36. ): # Does multinomial sampling without a cuda synchronization
  37. q = torch.empty_like(probs_sort).exponential_(1)
  38. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  39. def logits_to_probs(
  40. logits,
  41. temperature: torch.Tensor,
  42. top_p: torch.Tensor,
  43. repetition_penalty: torch.Tensor,
  44. previous_tokens: Optional[torch.Tensor] = None,
  45. ) -> torch.Tensor:
  46. # Apply repetition penalty
  47. if previous_tokens is not None:
  48. previous_tokens = previous_tokens.long()
  49. score = torch.gather(logits, dim=-1, index=previous_tokens)
  50. score = torch.where(
  51. score < 0, score * repetition_penalty, score / repetition_penalty
  52. )
  53. logits.scatter_(dim=-1, index=previous_tokens, src=score)
  54. # Apply top-p sampling
  55. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  56. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  57. sorted_indices_to_remove = cum_probs > top_p
  58. sorted_indices_to_remove[0] = False # keep at least one option
  59. indices_to_remove = sorted_indices_to_remove.scatter(
  60. dim=-1, index=sorted_indices, src=sorted_indices_to_remove
  61. )
  62. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  63. logits = logits / torch.clip(temperature, min=1e-5)
  64. probs = torch.nn.functional.softmax(logits, dim=-1)
  65. return probs
  66. def sample(
  67. logits,
  68. temperature: torch.Tensor,
  69. top_p: torch.Tensor,
  70. repetition_penalty: torch.Tensor,
  71. previous_tokens: Optional[torch.Tensor] = None,
  72. ) -> Tuple[torch.Tensor, torch.Tensor]:
  73. probs = logits_to_probs(
  74. logits=logits[0, -1],
  75. temperature=temperature,
  76. top_p=top_p,
  77. repetition_penalty=repetition_penalty,
  78. previous_tokens=previous_tokens,
  79. )
  80. idx_next = multinomial_sample_one_no_sync(probs)
  81. return idx_next, probs
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    repetition_penalty: torch.Tensor,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Decode one dual-AR step: sample the semantic token with the slow
    transformer, then autoregressively sample the remaining codebooks with
    the fast transformer.

    Returns a stacked tensor of per-codebook tokens (transposed at the end,
    so callers reshape it to ``(1, num_codebooks + 1, -1)``).
    """
    # print(x, torch.count_nonzero(vq_masks))
    forward_result = model.forward_generate(
        x,
        input_pos,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
    )
    logits = forward_result.logits  # [:, -1:]
    hidden_states = forward_result.hidden_states  # [:, -1:]

    # Slow path: sample the main/semantic token.
    # NOTE(review): `previous_tokens[:, 0]` selects one *time step* across
    # all codebook rows, not the semantic row's history — confirm this is
    # intended rather than `previous_tokens[0]`.
    codebooks = [
        sample(
            logits,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            previous_tokens=(
                previous_tokens[:, 0] if previous_tokens is not None else None
            ),
        )[0]
    ]

    # Only clear cache for fast_layers, avoid clearing main model cache
    for layer in model.fast_layers:
        if hasattr(layer, "attention") and hasattr(layer.attention, "kv_cache"):
            layer.attention.kv_cache.k_cache.fill_(0)
            layer.attention.kv_cache.v_cache.fill_(0)

    # Prime the fast transformer at position 0; its logits are discarded —
    # the first codebook entry is derived from the semantic token instead.
    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
    model.forward_generate_fast(hidden_states, input_pos)
    # Map the semantic token id into codebook space; clamp non-semantic
    # (e.g. control) tokens to 0.
    a = codebooks[0] - model.tokenizer.semantic_begin_id
    a[a < 0] = 0
    hidden_states = model.fast_embeddings(a)
    codebooks.append(a)

    # Fast path: one position per remaining codebook.
    for codebook_idx in range(1, model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        # Restrict to the codebook vocabulary.
        # NOTE(review): 1024 is hard-coded — presumably the codebook size;
        # verify against model.config.
        short_logits = logits[:, :, :1024]

        # Convert logits to probs
        # NOTE(review): `previous_tokens[codebook_idx + 1]` skips row 1 and
        # reads rows 2..num_codebooks — confirm the off-by-one is intended.
        a = sample(
            short_logits,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=1)
    # Only delete references, let Python GC handle cleanup
    del logits, hidden_states, forward_result
    return codebooks.T
  148. def decode_n_tokens(
  149. model: DualARTransformer,
  150. cur_token: torch.Tensor,
  151. input_pos: torch.Tensor,
  152. num_new_tokens: int,
  153. temperature: torch.Tensor,
  154. top_p: torch.Tensor,
  155. repetition_penalty: torch.Tensor,
  156. audio_masks: torch.Tensor,
  157. audio_parts: torch.Tensor,
  158. decode_one_token=decode_one_token_ar,
  159. ):
  160. previous_tokens = torch.zeros(
  161. (model.config.num_codebooks + 1, model.config.max_seq_len),
  162. dtype=torch.int,
  163. device=cur_token.device,
  164. )
  165. for i in tqdm(range(num_new_tokens)):
  166. # We need to get windowed repeat penalty
  167. win_size = 16
  168. if i < win_size:
  169. window = previous_tokens[:, :win_size]
  170. else:
  171. window = previous_tokens[:, i - win_size : i]
  172. with sdpa_kernel(
  173. SDPBackend.MATH
  174. ): # Actually better for Inductor to codegen attention here
  175. next_token = decode_one_token(
  176. model=model,
  177. x=cur_token,
  178. input_pos=input_pos,
  179. previous_tokens=window,
  180. temperature=temperature,
  181. top_p=top_p,
  182. repetition_penalty=repetition_penalty,
  183. audio_masks=audio_masks,
  184. audio_parts=audio_parts,
  185. ).clone()
  186. input_pos += 1
  187. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  188. previous_tokens[:, i : i + 1] = next_token.view(
  189. model.config.num_codebooks + 1, -1
  190. )
  191. if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
  192. break
  193. # Only clean up the large tensor
  194. del cur_token
  195. return previous_tokens[:, : i + 1]
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: DualARTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    decode_one_token=decode_one_token_ar,
    num_samples: int = 1,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Returns ``seq`` of shape ``(num_codebooks + 1, T + 1 + n)``: the prompt
    followed by the prefill token and the decoded continuation. Sampling
    parameters come from ``sampling_kwargs`` (temperature / top_p /
    repetition_penalty) and are written into the model's shared fixed
    parameter tensors.
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    # NOTE(review): the prompt is tiled to num_samples, but everything below
    # (view(1, ...), max_batch_size=1) assumes batch size 1 — confirm that
    # num_samples > 1 is unsupported on this path.
    prompt = prompt[None].repeat(num_samples, 1, 1)

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    # Clamp the generation budget so T + max_new_tokens fits in the cache;
    # max_new_tokens == 0 means "fill to max_seq_len".
    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    # Critical fix: Only set up cache on first run or when necessary
    # (the flag is pre-set by init_model and checked here for safety).
    if not hasattr(model, "_cache_setup_done") or not model._cache_setup_done:
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,  # Fixed to 1, avoid dynamic changes
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        model._cache_setup_done = True

    codebook_dim = 1 + model.config.num_codebooks
    # Create new tensor each time, but try to reuse memory
    input_pos = torch.arange(0, T, device=device, dtype=torch.long)
    # Output buffer; inherits the prompt's (token id) dtype.
    empty = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
    )
    empty[:, :T] = prompt
    seq = empty

    # Use pre-created fixed parameter tensors
    # NOTE(review): the 0.8/1.1 getattr fallbacks differ from init_model's
    # 0.7/0.7/1.5 pre-created values; the fill_ calls below reconcile them
    # with sampling_kwargs, so the fallbacks only matter pre-init.
    temperature = getattr(
        model, "fixed_temperature", torch.tensor(0.8, device=device, dtype=torch.float)
    )
    top_p = getattr(
        model, "fixed_top_p", torch.tensor(0.8, device=device, dtype=torch.float)
    )
    repetition_penalty = getattr(
        model,
        "fixed_repetition_penalty",
        torch.tensor(1.1, device=device, dtype=torch.float),
    )

    # If different parameter values are needed, directly modify existing tensors
    # (in-place fill_ keeps the tensor identity stable for torch.compile).
    temp_val = sampling_kwargs.get("temperature", 0.7)
    top_p_val = sampling_kwargs.get("top_p", 0.7)
    rep_val = sampling_kwargs.get("repetition_penalty", 1.5)

    if abs(temperature.item() - temp_val) > 1e-6:
        temperature.fill_(temp_val)
    if abs(top_p.item() - top_p_val) > 1e-6:
        top_p.fill_(top_p_val)
    if abs(repetition_penalty.item() - rep_val) > 1e-6:
        repetition_penalty.fill_(rep_val)

    # Prefill always uses the eager (non-compiled) decode function.
    prefill_decode = decode_one_token_ar

    first_token = prefill_decode(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        temperature,
        top_p,
        repetition_penalty,
        audio_masks,
        audio_parts,
    )
    seq[:, T : T + 1] = first_token

    # Recreate input_pos
    # NOTE(review): dtype here is torch.int while the prefill used
    # torch.long — confirm downstream indexing accepts both.
    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        first_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
        decode_one_token=decode_one_token,
    )
    # Trim the buffer to what was actually generated and splice in the tokens.
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    # Clean up temporary variables
    del first_token, x, prompt, empty, input_pos

    return seq
  297. def init_model(checkpoint_path, device, precision, compile=False):
  298. model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
  299. model = model.to(device=device, dtype=precision)
  300. logger.info(f"Restored model from checkpoint")
  301. if isinstance(model, DualARTransformer):
  302. decode_one_token = decode_one_token_ar
  303. prefill_n_tokens = decode_one_token_ar
  304. logger.info("Using DualARTransformer")
  305. else:
  306. raise ValueError("Unsupported model type")
  307. # Pre-create fixed parameter tensors to avoid runtime creation
  308. model.fixed_temperature = torch.tensor(0.7, device=device, dtype=torch.float)
  309. model.fixed_top_p = torch.tensor(0.7, device=device, dtype=torch.float)
  310. model.fixed_repetition_penalty = torch.tensor(1.5, device=device, dtype=torch.float)
  311. # Mark whether cache has been initialized
  312. model._cache_setup_done = False
  313. if compile:
  314. logger.info("Compiling function...")
  315. decode_one_token = torch.compile(
  316. decode_one_token,
  317. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  318. mode="reduce-overhead" if torch.cuda.is_available() else None,
  319. fullgraph=True,
  320. )
  321. return model.eval(), decode_one_token
@dataclass
class GenerateResponse:
    """One streaming unit yielded by ``generate_long``.

    ``action == "sample"`` carries the generated ``codes`` and the source
    ``text``; ``action == "next"`` marks the boundary between samples.
    """

    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None
    text: Optional[str] = None
  327. def generate_long(
  328. *,
  329. model,
  330. device: Union[str, torch.device],
  331. decode_one_token: Callable,
  332. text: str,
  333. num_samples: int = 1,
  334. max_new_tokens: int = 0,
  335. top_p: float = 0.8,
  336. repetition_penalty: float = 1.1,
  337. temperature: float = 0.8,
  338. compile: bool = False,
  339. iterative_prompt: bool = True,
  340. chunk_length: int = 512,
  341. prompt_text: Optional[Union[str, list[str]]] = None,
  342. prompt_tokens: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
  343. ):
  344. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  345. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  346. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  347. use_prompt = prompt_text is not None and prompt_tokens is not None
  348. if use_prompt and isinstance(prompt_text, str):
  349. prompt_text = [prompt_text]
  350. prompt_tokens = [prompt_tokens]
  351. if use_prompt:
  352. assert len(prompt_text) == len(
  353. prompt_tokens
  354. ), "Prompt text and tokens must have the same length"
  355. if prompt_tokens:
  356. prompt_tokens = [i.cpu() for i in prompt_tokens]
  357. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  358. tokenizer = model.tokenizer
  359. base_content_sequence = ContentSequence(modality="interleave")
  360. max_length = model.config.max_seq_len
  361. if use_prompt:
  362. for t, c in zip(prompt_text, prompt_tokens):
  363. base_content_sequence.append(
  364. [
  365. TextPart(text=t),
  366. VQPart(codes=c),
  367. ],
  368. add_end=True,
  369. speaker=0,
  370. )
  371. base_content_sequence.append(
  372. [
  373. TextPart(text=text),
  374. ],
  375. add_end=False,
  376. speaker=0,
  377. )
  378. encoded, audio_masks, audio_parts = base_content_sequence.encode_for_inference(
  379. tokenizer, num_codebooks=model.config.num_codebooks
  380. )
  381. if encoded.size(1) > max_length - 2048:
  382. raise ValueError(f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}")
  383. encoded = encoded.to(device=device)
  384. logger.info(f"Encoded text: {text}")
  385. for sample_idx in range(num_samples):
  386. if torch.cuda.is_available():
  387. torch.cuda.synchronize()
  388. global_encoded = []
  389. seg_idx = 0
  390. prompt_length = encoded.size(1)
  391. t0 = time.perf_counter()
  392. y = generate(
  393. model=model,
  394. prompt=encoded,
  395. max_new_tokens=max_new_tokens,
  396. audio_masks=audio_masks,
  397. audio_parts=audio_parts,
  398. decode_one_token=decode_one_token,
  399. temperature=temperature,
  400. top_p=top_p,
  401. repetition_penalty=repetition_penalty,
  402. )
  403. if sample_idx == 0 and seg_idx == 0 and compile:
  404. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  405. if torch.cuda.is_available():
  406. torch.cuda.synchronize()
  407. t = time.perf_counter() - t0
  408. tokens_generated = y.size(1) - prompt_length
  409. tokens_sec = tokens_generated / t
  410. logger.info(
  411. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  412. )
  413. logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
  414. if torch.cuda.is_available():
  415. logger.info(
  416. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  417. )
  418. # Put the generated tokens
  419. codes = y[1:, prompt_length:-1].clone()
  420. assert (codes >= 0).all(), f"Negative code found"
  421. decoded = y[:, prompt_length:].clone()
  422. global_encoded.append(decoded.cpu())
  423. assert (codes >= 0).all(), f"Negative code found: {codes}"
  424. yield GenerateResponse(action="sample", codes=codes, text=text)
  425. seg_idx += 1
  426. # Force GPU memory cleanup
  427. del y, decoded, codes
  428. yield GenerateResponse(action="next")
@dataclass
class WrappedGenerateResponse:
    """Queue envelope for one worker result: a ``GenerateResponse`` on
    success, or the raised ``Exception`` on error."""

    status: Literal["success", "error"]
    response: Optional[Union[GenerateResponse, Exception]] = None
@dataclass
class GenerateRequest:
    """A unit of work for the queue worker: ``request`` holds keyword
    arguments for ``generate_long``; results arrive on ``response_queue``
    as ``WrappedGenerateResponse`` items."""

    request: dict
    response_queue: queue.Queue
def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Start a daemon worker thread that owns the model and serves requests.

    Blocks until the model is loaded, then returns the input queue. Put
    ``GenerateRequest`` items on it to generate; put ``None`` to stop the
    worker. Each response is delivered on the request's own response queue.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()

    def worker():
        # The model lives entirely on this thread.
        model, decode_one_token = init_model(
            checkpoint_path, device, precision, compile=compile
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        # Signal the launcher that initialization finished.
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:  # sentinel: shut the worker down
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
                # Only clear cache after complete request batch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                # Report the failure to the caller instead of killing the
                # worker thread; the full traceback goes to the log.
                logger.error(traceback.format_exc())
                response_queue.put(WrappedGenerateResponse(status="error", response=e))
                # Clear cache on error
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    threading.Thread(target=worker, daemon=True).start()
    # Wait for model load so callers can submit work immediately.
    init_event.wait()

    return input_queue
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.8)
@click.option("--repetition-penalty", type=float, default=1.1)
@click.option("--temperature", type=float, default=0.8)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/openaudio-s1-mini",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=300)
@click.option("--output-dir", type=Path, default="temp")
def main(
    text: str,
    prompt_text: Optional[tuple[str, ...]],
    prompt_tokens: Optional[tuple[Path, ...]],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,  # annotation corrected from `int`; the click option is float
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
    output_dir: Path,
) -> None:
    """CLI entry point: load the model, generate semantic codes for ``text``,
    and save each finished sample as ``codes_{idx}.npy`` under ``output_dir``."""
    os.makedirs(output_dir, exist_ok=True)
    precision = torch.half if half else torch.bfloat16

    # With multiple=True click yields tuples, so mismatched lengths are the
    # realistic error case here.
    if (
        prompt_text is not None
        and prompt_tokens is not None
        and len(prompt_text) != len(prompt_tokens)
    ):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = init_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    # Prompt codes are stored as .npy arrays on disk.
    prompt_tokens_list = None
    if prompt_tokens is not None:
        prompt_tokens_list = [torch.from_numpy(np.load(p)) for p in prompt_tokens]

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=list(prompt_text) if prompt_text else None,
        prompt_tokens=prompt_tokens_list,
    )

    # Collect codes per sample; flush to disk at each "next" boundary.
    idx = 0
    codes = []

    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
                np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
                logger.info(f"Saved codes to {codes_npy_path}")
            logger.info(f"Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")
# Script entry point.
if __name__ == "__main__":
    main()