generate.py 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import hydra
  11. import numpy as np
  12. import torch
  13. import torch._dynamo.config
  14. import torch._inductor.config
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  18. from fish_speech.text import clean_text, split_text
# Turn off the HuggingFace `tokenizers` library's internal parallelism.
# Presumably this avoids its fork/thread warnings or deadlocks when generation
# runs in a background worker thread (see launch_thread_safe_queue) — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# TorchInductor tuning knobs. These presumably only take effect when
# load_model(..., compile=True) compiles the decode step with the "inductor"
# backend (used when CUDA is available).
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
# Probe with hasattr: older torch builds do not expose this option.
if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
  25. from fish_speech.models.text2semantic.llama import (
  26. BaseTransformer,
  27. DualARTransformer,
  28. NaiveTransformer,
  29. )
  30. def multinomial_sample_one_no_sync(
  31. probs_sort,
  32. ): # Does multinomial sampling without a cuda synchronization
  33. q = torch.empty_like(probs_sort).exponential_(1)
  34. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  35. def logits_to_probs(
  36. logits,
  37. previous_tokens: Optional[torch.Tensor] = None,
  38. temperature: torch.Tensor = 1.0,
  39. top_p: torch.Tensor = 1.0,
  40. repetition_penalty: torch.Tensor = 1.0,
  41. ) -> torch.Tensor:
  42. # Apply repetition penalty
  43. if previous_tokens is not None:
  44. previous_tokens = previous_tokens.long()
  45. score = torch.gather(logits, dim=0, index=previous_tokens)
  46. score = torch.where(
  47. score < 0, score * repetition_penalty, score / repetition_penalty
  48. )
  49. logits.scatter_(dim=0, index=previous_tokens, src=score)
  50. # Apply top-p sampling
  51. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  52. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  53. sorted_indices_to_remove = cum_probs > top_p
  54. sorted_indices_to_remove[0] = False # keep at least one option
  55. indices_to_remove = sorted_indices_to_remove.scatter(
  56. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  57. )
  58. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  59. logits = logits / max(temperature, 1e-5)
  60. probs = torch.nn.functional.softmax(logits, dim=-1)
  61. return probs
  62. def sample(
  63. logits,
  64. previous_tokens: Optional[torch.Tensor] = None,
  65. **sampling_kwargs,
  66. ) -> Tuple[torch.Tensor, torch.Tensor]:
  67. probs = logits_to_probs(
  68. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  69. )
  70. idx_next = multinomial_sample_one_no_sync(probs)
  71. return idx_next, probs
  72. def decode_one_token_ar(
  73. model: DualARTransformer,
  74. x: torch.Tensor,
  75. input_pos: torch.Tensor,
  76. previous_tokens: torch.Tensor = None,
  77. **sampling_kwargs,
  78. ) -> torch.Tensor:
  79. x = model.forward_generate(x, input_pos)
  80. sampling_kwargs_main = sampling_kwargs.copy()
  81. sampling_kwargs_main["temperature"] = 0.1
  82. sampling_kwargs_main["top_p"] = 0.1
  83. sampling_kwargs_main["repetition_penalty"] = 1.0
  84. codebooks = [
  85. sample(
  86. x.logits,
  87. previous_tokens=None, # Disable repetition penalty for the token codebook
  88. **sampling_kwargs_main,
  89. )[0]
  90. ]
  91. x = x.hidden_states
  92. # Cleanup the cache
  93. for layer in model.fast_layers:
  94. layer.attention.kv_cache.k_cache.fill_(0)
  95. layer.attention.kv_cache.v_cache.fill_(0)
  96. for codebook_idx in range(model.config.num_codebooks):
  97. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  98. logits = model.forward_generate_fast(x, input_pos)
  99. a = sample(
  100. logits,
  101. previous_tokens=(
  102. previous_tokens[codebook_idx + 1]
  103. if previous_tokens is not None
  104. else None
  105. ),
  106. **sampling_kwargs,
  107. )[0]
  108. x = model.fast_embeddings(a)
  109. codebooks.append(a)
  110. return torch.stack(codebooks, dim=0)
  111. def decode_one_token_naive(
  112. model: NaiveTransformer,
  113. x: torch.Tensor,
  114. input_pos: torch.Tensor,
  115. previous_tokens: torch.Tensor = None,
  116. **sampling_kwargs,
  117. ) -> torch.Tensor:
  118. x = model.forward_generate(x, input_pos)
  119. sampling_kwargs_main = sampling_kwargs.copy()
  120. sampling_kwargs_main["temperature"] = 0.1
  121. sampling_kwargs_main["top_p"] = 0.1
  122. sampling_kwargs_main["repetition_penalty"] = 1.0
  123. codebooks = [
  124. sample(
  125. x.logits,
  126. previous_tokens=None, # Disable repetition penalty for the token codebook
  127. **sampling_kwargs_main,
  128. )[0]
  129. ]
  130. for i in range(model.config.num_codebooks):
  131. codebooks.append(
  132. sample(
  133. x.codebook_logits[:, :, i],
  134. previous_tokens=(
  135. previous_tokens[i + 1] if previous_tokens is not None else None
  136. ),
  137. **sampling_kwargs,
  138. )[0]
  139. )
  140. return torch.stack(codebooks, dim=0)
  141. def decode_n_tokens(
  142. model: NaiveTransformer,
  143. cur_token: torch.Tensor,
  144. input_pos: torch.Tensor,
  145. num_new_tokens: int,
  146. im_end_id: int = 4,
  147. decode_one_token=decode_one_token_naive,
  148. **sampling_kwargs,
  149. ):
  150. previous_tokens = torch.zeros(
  151. (model.config.num_codebooks + 1, model.config.max_seq_len),
  152. dtype=torch.int,
  153. device=cur_token.device,
  154. )
  155. for i in tqdm(range(num_new_tokens)):
  156. # We need to get windowed repeat penalty
  157. win_size = 16
  158. if i < win_size:
  159. window = previous_tokens[:, :win_size]
  160. else:
  161. window = previous_tokens[:, i - win_size : i]
  162. with (
  163. torch.backends.cuda.sdp_kernel(
  164. enable_flash=False, enable_mem_efficient=False, enable_math=True
  165. )
  166. if torch.cuda.is_available()
  167. else nullcontext()
  168. ): # Actually better for Inductor to codegen attention here
  169. next_token = decode_one_token(
  170. model=model,
  171. x=cur_token,
  172. input_pos=input_pos,
  173. previous_tokens=window,
  174. **sampling_kwargs,
  175. )
  176. input_pos += 1
  177. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  178. previous_tokens[:, i : i + 1] = next_token.view(
  179. model.config.num_codebooks + 1, -1
  180. )
  181. if cur_token[0, 0, -1] == im_end_id:
  182. break
  183. return previous_tokens[:, : i + 1]
  184. @torch.no_grad()
  185. @torch.inference_mode()
  186. def generate(
  187. *,
  188. model: NaiveTransformer,
  189. prompt: torch.Tensor,
  190. max_new_tokens: int,
  191. im_end_id: int = 4,
  192. decode_one_token=decode_one_token_naive,
  193. **sampling_kwargs,
  194. ) -> torch.Tensor:
  195. """
  196. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  197. """
  198. # create an empty tensor of the expected final shape and fill in the current tokens
  199. T = prompt.size(1)
  200. device, dtype = prompt.device, prompt.dtype
  201. codebook_dim = 1 + model.config.num_codebooks
  202. # create an empty tensor of the expected final shape and fill in the current tokens
  203. empty = torch.empty((codebook_dim, max_new_tokens), dtype=dtype, device=device)
  204. empty[:, :T] = prompt
  205. seq = empty
  206. input_pos = torch.arange(0, T, device=device)
  207. # Use non-accelerated version for now, to avoid compilation overhead
  208. prefill_decode = (
  209. decode_one_token_naive
  210. if isinstance(model, NaiveTransformer)
  211. else decode_one_token_ar
  212. )
  213. next_token = prefill_decode(
  214. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  215. )
  216. seq[:, T : T + 1] = next_token
  217. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  218. x = decode_n_tokens(
  219. model,
  220. next_token.view(1, codebook_dim, -1),
  221. input_pos,
  222. max_new_tokens - 1,
  223. im_end_id=im_end_id,
  224. decode_one_token=decode_one_token,
  225. **sampling_kwargs,
  226. )
  227. # x = torch.cat(generated_tokens, dim=1)
  228. seq = seq[:, : T + 1 + x.size(1)]
  229. seq[:, T + 1 :] = x
  230. return seq
  231. def encode_tokens(
  232. tokenizer,
  233. string,
  234. device="cuda",
  235. prompt_tokens=None,
  236. num_codebooks=4,
  237. ):
  238. string = clean_text(string)
  239. string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
  240. new_tokens = tokenizer.encode(
  241. string,
  242. add_special_tokens=False,
  243. max_length=10**6,
  244. truncation=False,
  245. )
  246. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  247. # Codebooks
  248. zeros = (
  249. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  250. * CODEBOOK_PAD_TOKEN_ID
  251. )
  252. prompt = torch.cat((tokens, zeros), dim=0)
  253. if prompt_tokens is None:
  254. return prompt
  255. # Get prompt tokens
  256. if prompt_tokens.ndim == 3:
  257. assert (
  258. prompt_tokens.shape[0] == 1
  259. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  260. prompt_tokens = prompt_tokens[0]
  261. assert prompt_tokens.ndim == 2
  262. data = prompt_tokens + 1
  263. if prompt_tokens.shape[0] > num_codebooks:
  264. logger.warning(
  265. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  266. )
  267. data = data[:num_codebooks]
  268. # Add pad token for each codebook
  269. data = torch.cat(
  270. (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
  271. dim=1,
  272. )
  273. # Since 1.0, we use <|semantic|>
  274. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  275. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  276. main_token_ids = (
  277. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  278. )
  279. main_token_ids[0, -1] = end_token_id
  280. data = torch.cat((main_token_ids, data), dim=0)
  281. prompt = torch.cat((prompt, data), dim=1)
  282. return prompt
  283. def load_model(checkpoint_path, device, precision, compile=False):
  284. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  285. checkpoint_path, load_weights=True
  286. )
  287. model = model.to(device=device, dtype=precision)
  288. logger.info(f"Restored model from checkpoint")
  289. if isinstance(model, DualARTransformer):
  290. decode_one_token = decode_one_token_ar
  291. logger.info("Using DualARTransformer")
  292. else:
  293. decode_one_token = decode_one_token_naive
  294. logger.info("Using NaiveTransformer")
  295. if compile:
  296. logger.info("Compiling function...")
  297. decode_one_token = torch.compile(
  298. decode_one_token,
  299. fullgraph=True,
  300. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  301. mode="reduce-overhead" if torch.cuda.is_available() else None,
  302. )
  303. return model.eval(), decode_one_token
  304. @dataclass
  305. class GenerateResponse:
  306. action: Literal["sample", "next"]
  307. codes: Optional[torch.Tensor] = None
  308. text: Optional[str] = None
  309. def generate_long(
  310. *,
  311. model,
  312. device: str | torch.device,
  313. decode_one_token: callable,
  314. text: str,
  315. num_samples: int = 1,
  316. max_new_tokens: int = 0,
  317. top_p: int = 0.7,
  318. repetition_penalty: float = 1.5,
  319. temperature: float = 0.7,
  320. compile: bool = False,
  321. iterative_prompt: bool = True,
  322. max_length: int = 2048,
  323. chunk_length: int = 150,
  324. prompt_text: Optional[str | list[str]] = None,
  325. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  326. ):
  327. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  328. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  329. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  330. use_prompt = prompt_text is not None and prompt_tokens is not None
  331. if use_prompt and isinstance(prompt_text, str):
  332. prompt_text = [prompt_text]
  333. prompt_tokens = [prompt_tokens]
  334. assert use_prompt is False or len(prompt_text) == len(
  335. prompt_tokens
  336. ), "Prompt text and tokens must have the same length"
  337. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  338. tokenizer = model.tokenizer
  339. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  340. encoded = []
  341. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  342. encoded_prompts = []
  343. if use_prompt:
  344. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  345. encoded_prompts.append(
  346. encode_tokens(
  347. tokenizer,
  348. string=t,
  349. device=device,
  350. prompt_tokens=c,
  351. num_codebooks=model.config.num_codebooks,
  352. )
  353. )
  354. for idx, text in enumerate(texts):
  355. encoded.append(
  356. encode_tokens(
  357. tokenizer,
  358. string=text,
  359. device=device,
  360. num_codebooks=model.config.num_codebooks,
  361. )
  362. )
  363. logger.info(f"Encoded text: {text}")
  364. # Move temperature, top_p, repetition_penalty to device
  365. # This is important so that changing params doesn't trigger recompile
  366. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  367. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  368. repetition_penalty = torch.tensor(
  369. repetition_penalty, device=device, dtype=torch.float
  370. )
  371. for sample_idx in range(num_samples):
  372. if torch.cuda.is_available():
  373. torch.cuda.synchronize()
  374. global_encoded = []
  375. seg_idx = 0
  376. while seg_idx < len(encoded):
  377. logger.info(
  378. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  379. )
  380. seg = encoded[seg_idx]
  381. global_encoded.append(seg)
  382. lengths = reversed([seg.size(1) for seg in global_encoded])
  383. # Pick last 2000 tokens
  384. count = 0
  385. for i, length in enumerate(lengths):
  386. count += length
  387. if count + length > max_length - 1024 - sum(
  388. t.shape[1] for t in encoded_prompts
  389. ):
  390. break
  391. if i != 0 and i % 2 == 0:
  392. i -= 1
  393. # Rotate the list, always make sure first segment is included to avoid drift
  394. if i < len(global_encoded) - 2:
  395. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  396. else:
  397. partial_encoded = global_encoded
  398. if use_prompt:
  399. partial_encoded = encoded_prompts + partial_encoded
  400. cat_encoded = torch.cat(partial_encoded, dim=1)
  401. prompt_length = cat_encoded.size(1)
  402. t0 = time.perf_counter()
  403. y = generate(
  404. model=model,
  405. prompt=cat_encoded,
  406. max_new_tokens=max_new_tokens,
  407. im_end_id=im_end_id,
  408. decode_one_token=decode_one_token,
  409. temperature=temperature,
  410. top_p=top_p,
  411. repetition_penalty=repetition_penalty,
  412. )
  413. if sample_idx == 0 and seg_idx == 0 and compile:
  414. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  415. if torch.cuda.is_available():
  416. torch.cuda.synchronize()
  417. t = time.perf_counter() - t0
  418. tokens_generated = y.size(1) - prompt_length
  419. tokens_sec = tokens_generated / t
  420. logger.info(
  421. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  422. )
  423. logger.info(
  424. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  425. )
  426. if torch.cuda.is_available():
  427. logger.info(
  428. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  429. )
  430. # Put the generated tokens
  431. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  432. codes = y[1:, prompt_length:-1].clone()
  433. codes = codes - 1
  434. assert (codes >= 0).all(), f"Negative code found"
  435. decoded = y[:, prompt_length:-1].clone()
  436. # But for global encoding, we should keep the <im_end> token
  437. global_encoded.append(decoded)
  438. assert (codes >= 0).all(), f"Negative code found: {codes}"
  439. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  440. seg_idx += 1
  441. # This indicates the end of the current sample
  442. yield GenerateResponse(action="next")
  443. @dataclass
  444. class WrappedGenerateResponse:
  445. status: Literal["success", "error"]
  446. response: Optional[GenerateResponse | Exception] = None
  447. @dataclass
  448. class GenerateRequest:
  449. request: dict
  450. response_queue: queue.Queue
  451. def launch_thread_safe_queue(
  452. checkpoint_path,
  453. device,
  454. precision,
  455. compile: bool = False,
  456. ):
  457. input_queue = queue.Queue()
  458. init_event = threading.Event()
  459. def worker():
  460. model, decode_one_token = load_model(
  461. checkpoint_path, device, precision, compile=compile
  462. )
  463. with torch.device(device):
  464. model.setup_caches(
  465. max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
  466. )
  467. init_event.set()
  468. while True:
  469. item: GenerateRequest | None = input_queue.get()
  470. if item is None:
  471. break
  472. kwargs = item.request
  473. response_queue = item.response_queue
  474. try:
  475. for chunk in generate_long(
  476. model=model, decode_one_token=decode_one_token, **kwargs
  477. ):
  478. response_queue.put(
  479. WrappedGenerateResponse(status="success", response=chunk)
  480. )
  481. except Exception as e:
  482. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  483. threading.Thread(target=worker, daemon=True).start()
  484. init_event.wait()
  485. return input_queue
  486. @click.command()
  487. @click.option(
  488. "--text",
  489. type=str,
  490. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  491. )
  492. @click.option("--prompt-text", type=str, default=None, multiple=True)
  493. @click.option(
  494. "--prompt-tokens",
  495. type=click.Path(path_type=Path, exists=True),
  496. default=None,
  497. multiple=True,
  498. )
  499. @click.option("--num-samples", type=int, default=1)
  500. @click.option("--max-new-tokens", type=int, default=1024)
  501. @click.option("--top-p", type=float, default=0.7)
  502. @click.option("--repetition-penalty", type=float, default=1.2)
  503. @click.option("--temperature", type=float, default=0.7)
  504. @click.option(
  505. "--checkpoint-path",
  506. type=click.Path(path_type=Path, exists=True),
  507. default="checkpoints/fish-speech-1.4",
  508. )
  509. @click.option("--device", type=str, default="cuda")
  510. @click.option("--compile/--no-compile", default=False)
  511. @click.option("--seed", type=int, default=42)
  512. @click.option("--half/--no-half", default=False)
  513. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  514. @click.option("--chunk-length", type=int, default=100)
  515. def main(
  516. text: str,
  517. prompt_text: Optional[list[str]],
  518. prompt_tokens: Optional[list[Path]],
  519. num_samples: int,
  520. max_new_tokens: int,
  521. top_p: int,
  522. repetition_penalty: float,
  523. temperature: float,
  524. checkpoint_path: Path,
  525. device: str,
  526. compile: bool,
  527. seed: int,
  528. half: bool,
  529. iterative_prompt: bool,
  530. chunk_length: int,
  531. ) -> None:
  532. precision = torch.half if half else torch.bfloat16
  533. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  534. raise ValueError(
  535. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  536. )
  537. logger.info("Loading model ...")
  538. t0 = time.time()
  539. model, decode_one_token = load_model(
  540. checkpoint_path, device, precision, compile=compile
  541. )
  542. with torch.device(device):
  543. model.setup_caches(
  544. max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
  545. )
  546. if torch.cuda.is_available():
  547. torch.cuda.synchronize()
  548. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  549. if prompt_tokens is not None:
  550. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  551. torch.manual_seed(seed)
  552. if torch.cuda.is_available():
  553. torch.cuda.manual_seed(seed)
  554. generator = generate_long(
  555. model=model,
  556. device=device,
  557. decode_one_token=decode_one_token,
  558. text=text,
  559. num_samples=num_samples,
  560. max_new_tokens=max_new_tokens,
  561. top_p=top_p,
  562. repetition_penalty=repetition_penalty,
  563. temperature=temperature,
  564. compile=compile,
  565. iterative_prompt=iterative_prompt,
  566. chunk_length=chunk_length,
  567. prompt_text=prompt_text,
  568. prompt_tokens=prompt_tokens,
  569. )
  570. idx = 0
  571. codes = []
  572. for response in generator:
  573. if response.action == "sample":
  574. codes.append(response.codes)
  575. logger.info(f"Sampled text: {response.text}")
  576. elif response.action == "next":
  577. if codes:
  578. np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
  579. logger.info(f"Saved codes to codes_{idx}.npy")
  580. logger.info(f"Next sample")
  581. codes = []
  582. idx += 1
  583. else:
  584. logger.error(f"Error: {response}")
  585. if __name__ == "__main__":
  586. main()