generate.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import hydra
  11. import numpy as np
  12. import torch
  13. import torch._dynamo.config
  14. import torch._inductor.config
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  18. from fish_speech.text import clean_text, split_text
  19. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  20. torch._inductor.config.coordinate_descent_tuning = True
  21. torch._inductor.config.triton.unique_kernel_names = True
  22. if hasattr(torch._inductor.config, "fx_graph_cache"):
  23. # Experimental feature to reduce compilation times, will be on by default in future
  24. torch._inductor.config.fx_graph_cache = True
  25. from fish_speech.models.text2semantic.llama import (
  26. BaseTransformer,
  27. DualARTransformer,
  28. NaiveTransformer,
  29. )
  30. def multinomial_sample_one_no_sync(
  31. probs_sort,
  32. ): # Does multinomial sampling without a cuda synchronization
  33. q = torch.empty_like(probs_sort).exponential_(1)
  34. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  35. def logits_to_probs(
  36. logits,
  37. previous_tokens: Optional[torch.Tensor] = None,
  38. temperature: torch.Tensor = 1.0,
  39. top_p: torch.Tensor = 1.0,
  40. repetition_penalty: torch.Tensor = 1.0,
  41. ) -> torch.Tensor:
  42. # Apply repetition penalty
  43. if previous_tokens is not None:
  44. previous_tokens = previous_tokens.long()
  45. score = torch.gather(logits, dim=0, index=previous_tokens)
  46. score = torch.where(
  47. score < 0, score * repetition_penalty, score / repetition_penalty
  48. )
  49. logits.scatter_(dim=0, index=previous_tokens, src=score)
  50. # Apply top-p sampling
  51. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  52. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  53. sorted_indices_to_remove = cum_probs > top_p
  54. sorted_indices_to_remove[0] = False # keep at least one option
  55. indices_to_remove = sorted_indices_to_remove.scatter(
  56. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  57. )
  58. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  59. logits = logits / max(temperature, 1e-5)
  60. probs = torch.nn.functional.softmax(logits, dim=-1)
  61. return probs
  62. def sample(
  63. logits,
  64. previous_tokens: Optional[torch.Tensor] = None,
  65. **sampling_kwargs,
  66. ) -> Tuple[torch.Tensor, torch.Tensor]:
  67. probs = logits_to_probs(
  68. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  69. )
  70. idx_next = multinomial_sample_one_no_sync(probs)
  71. return idx_next, probs
  72. def decode_one_token_ar(
  73. model: DualARTransformer,
  74. x: torch.Tensor,
  75. input_pos: torch.Tensor,
  76. previous_tokens: torch.Tensor = None,
  77. **sampling_kwargs,
  78. ) -> torch.Tensor:
  79. x = model.forward_generate(x, input_pos)
  80. sampling_kwargs_main = sampling_kwargs.copy()
  81. sampling_kwargs_main["temperature"] = 0.1
  82. sampling_kwargs_main["top_p"] = 0.1
  83. sampling_kwargs_main["repetition_penalty"] = 1.0
  84. codebooks = [
  85. sample(
  86. x.logits,
  87. previous_tokens=None, # Disable repetition penalty for the token codebook
  88. **sampling_kwargs_main,
  89. )[0]
  90. ]
  91. x = x.hidden_states
  92. # Cleanup the cache
  93. for layer in model.fast_layers:
  94. layer.attention.kv_cache.k_cache.fill_(0)
  95. layer.attention.kv_cache.v_cache.fill_(0)
  96. for codebook_idx in range(model.config.num_codebooks):
  97. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  98. logits = model.forward_generate_fast(x, input_pos)
  99. a = sample(
  100. logits,
  101. previous_tokens=(
  102. previous_tokens[codebook_idx + 1]
  103. if previous_tokens is not None
  104. else None
  105. ),
  106. **sampling_kwargs,
  107. )[0]
  108. x = model.fast_embeddings(a)
  109. codebooks.append(a)
  110. return torch.stack(codebooks, dim=0)
  111. def decode_one_token_naive(
  112. model: NaiveTransformer,
  113. x: torch.Tensor,
  114. input_pos: torch.Tensor,
  115. previous_tokens: torch.Tensor = None,
  116. **sampling_kwargs,
  117. ) -> torch.Tensor:
  118. x = model.forward_generate(x, input_pos)
  119. sampling_kwargs_main = sampling_kwargs.copy()
  120. sampling_kwargs_main["temperature"] = 0.1
  121. sampling_kwargs_main["top_p"] = 0.1
  122. sampling_kwargs_main["repetition_penalty"] = 1.0
  123. codebooks = [
  124. sample(
  125. x.logits,
  126. previous_tokens=None, # Disable repetition penalty for the token codebook
  127. **sampling_kwargs_main,
  128. )[0]
  129. ]
  130. for i in range(model.config.num_codebooks):
  131. codebooks.append(
  132. sample(
  133. x.codebook_logits[:, :, i],
  134. previous_tokens=(
  135. previous_tokens[i + 1] if previous_tokens is not None else None
  136. ),
  137. **sampling_kwargs,
  138. )[0]
  139. )
  140. return torch.stack(codebooks, dim=0)
  141. def decode_n_tokens(
  142. model: NaiveTransformer,
  143. cur_token: torch.Tensor,
  144. input_pos: torch.Tensor,
  145. num_new_tokens: int,
  146. im_end_id: int = 4,
  147. decode_one_token=decode_one_token_naive,
  148. **sampling_kwargs,
  149. ):
  150. previous_tokens = torch.zeros(
  151. (model.config.num_codebooks + 1, model.config.max_seq_len),
  152. dtype=torch.int,
  153. device=cur_token.device,
  154. )
  155. for i in tqdm(range(num_new_tokens)):
  156. # We need to get windowed repeat penalty
  157. win_size = 16
  158. if i < win_size:
  159. window = previous_tokens[:, :win_size]
  160. else:
  161. window = previous_tokens[:, i - win_size : i]
  162. with (
  163. torch.backends.cuda.sdp_kernel(
  164. enable_flash=False, enable_mem_efficient=False, enable_math=True
  165. )
  166. if torch.cuda.is_available()
  167. else nullcontext()
  168. ): # Actually better for Inductor to codegen attention here
  169. next_token = decode_one_token(
  170. model=model,
  171. x=cur_token,
  172. input_pos=input_pos,
  173. previous_tokens=window,
  174. **sampling_kwargs,
  175. )
  176. input_pos += 1
  177. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  178. previous_tokens[:, i : i + 1] = next_token.view(
  179. model.config.num_codebooks + 1, -1
  180. )
  181. if cur_token[0, 0, -1] == im_end_id:
  182. break
  183. return previous_tokens[:, : i + 1]
  184. @torch.no_grad()
  185. @torch.inference_mode()
  186. def generate(
  187. *,
  188. model: NaiveTransformer,
  189. prompt: torch.Tensor,
  190. max_new_tokens: int,
  191. im_end_id: int = 4,
  192. decode_one_token=decode_one_token_naive,
  193. **sampling_kwargs,
  194. ) -> torch.Tensor:
  195. """
  196. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  197. """
  198. # create an empty tensor of the expected final shape and fill in the current tokens
  199. T = prompt.size(1)
  200. device, dtype = prompt.device, prompt.dtype
  201. codebook_dim = 1 + model.config.num_codebooks
  202. # create an empty tensor of the expected final shape and fill in the current tokens
  203. empty = torch.empty(
  204. (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
  205. )
  206. empty[:, :T] = prompt
  207. seq = empty
  208. input_pos = torch.arange(0, T, device=device)
  209. # Use non-accelerated version for now, to avoid compilation overhead
  210. prefill_decode = (
  211. decode_one_token_naive
  212. if isinstance(model, NaiveTransformer)
  213. else decode_one_token_ar
  214. )
  215. next_token = prefill_decode(
  216. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  217. )
  218. seq[:, T : T + 1] = next_token
  219. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  220. x = decode_n_tokens(
  221. model,
  222. next_token.view(1, codebook_dim, -1),
  223. input_pos,
  224. max_new_tokens - 1,
  225. im_end_id=im_end_id,
  226. decode_one_token=decode_one_token,
  227. **sampling_kwargs,
  228. )
  229. # x = torch.cat(generated_tokens, dim=1)
  230. seq = seq[:, : T + 1 + x.size(1)]
  231. seq[:, T + 1 :] = x
  232. return seq
  233. def encode_tokens(
  234. tokenizer,
  235. string,
  236. device="cuda",
  237. prompt_tokens=None,
  238. num_codebooks=4,
  239. ):
  240. string = clean_text(string)
  241. string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
  242. new_tokens = tokenizer.encode(
  243. string,
  244. add_special_tokens=False,
  245. max_length=10**6,
  246. truncation=False,
  247. )
  248. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  249. # Codebooks
  250. zeros = (
  251. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  252. * CODEBOOK_PAD_TOKEN_ID
  253. )
  254. prompt = torch.cat((tokens, zeros), dim=0)
  255. if prompt_tokens is None:
  256. return prompt
  257. # Get prompt tokens
  258. if prompt_tokens.ndim == 3:
  259. assert (
  260. prompt_tokens.shape[0] == 1
  261. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  262. prompt_tokens = prompt_tokens[0]
  263. assert prompt_tokens.ndim == 2
  264. data = prompt_tokens + 1
  265. if prompt_tokens.shape[0] > num_codebooks:
  266. logger.warning(
  267. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  268. )
  269. data = data[:num_codebooks]
  270. # Add pad token for each codebook
  271. data = torch.cat(
  272. (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
  273. dim=1,
  274. )
  275. # Since 1.0, we use <|semantic|>
  276. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  277. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  278. main_token_ids = (
  279. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  280. )
  281. main_token_ids[0, -1] = end_token_id
  282. data = torch.cat((main_token_ids, data), dim=0)
  283. prompt = torch.cat((prompt, data), dim=1)
  284. return prompt
  285. def load_model(checkpoint_path, device, precision, compile=False):
  286. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  287. checkpoint_path, load_weights=True
  288. )
  289. model = model.to(device=device, dtype=precision)
  290. logger.info(f"Restored model from checkpoint")
  291. if isinstance(model, DualARTransformer):
  292. decode_one_token = decode_one_token_ar
  293. logger.info("Using DualARTransformer")
  294. else:
  295. decode_one_token = decode_one_token_naive
  296. logger.info("Using NaiveTransformer")
  297. if compile:
  298. logger.info("Compiling function...")
  299. decode_one_token = torch.compile(
  300. decode_one_token,
  301. fullgraph=True,
  302. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  303. mode="reduce-overhead" if torch.cuda.is_available() else None,
  304. )
  305. return model.eval(), decode_one_token
  306. @dataclass
  307. class GenerateResponse:
  308. action: Literal["sample", "next"]
  309. codes: Optional[torch.Tensor] = None
  310. text: Optional[str] = None
  311. def generate_long(
  312. *,
  313. model,
  314. device: str | torch.device,
  315. decode_one_token: callable,
  316. text: str,
  317. num_samples: int = 1,
  318. max_new_tokens: int = 0,
  319. top_p: int = 0.7,
  320. repetition_penalty: float = 1.5,
  321. temperature: float = 0.7,
  322. compile: bool = False,
  323. iterative_prompt: bool = True,
  324. max_length: int = 2048,
  325. chunk_length: int = 150,
  326. prompt_text: Optional[str | list[str]] = None,
  327. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  328. ):
  329. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  330. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  331. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  332. use_prompt = prompt_text is not None and prompt_tokens is not None
  333. if use_prompt and isinstance(prompt_text, str):
  334. prompt_text = [prompt_text]
  335. prompt_tokens = [prompt_tokens]
  336. assert use_prompt is False or len(prompt_text) == len(
  337. prompt_tokens
  338. ), "Prompt text and tokens must have the same length"
  339. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  340. tokenizer = model.tokenizer
  341. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  342. encoded = []
  343. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  344. encoded_prompts = []
  345. if use_prompt:
  346. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  347. encoded_prompts.append(
  348. encode_tokens(
  349. tokenizer,
  350. string=t,
  351. device=device,
  352. prompt_tokens=c,
  353. num_codebooks=model.config.num_codebooks,
  354. )
  355. )
  356. for idx, text in enumerate(texts):
  357. encoded.append(
  358. encode_tokens(
  359. tokenizer,
  360. string=text,
  361. device=device,
  362. num_codebooks=model.config.num_codebooks,
  363. )
  364. )
  365. logger.info(f"Encoded text: {text}")
  366. # Move temperature, top_p, repetition_penalty to device
  367. # This is important so that changing params doesn't trigger recompile
  368. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  369. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  370. repetition_penalty = torch.tensor(
  371. repetition_penalty, device=device, dtype=torch.float
  372. )
  373. for sample_idx in range(num_samples):
  374. if torch.cuda.is_available():
  375. torch.cuda.synchronize()
  376. global_encoded = []
  377. seg_idx = 0
  378. while seg_idx < len(encoded):
  379. logger.info(
  380. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  381. )
  382. seg = encoded[seg_idx]
  383. global_encoded.append(seg)
  384. lengths = reversed([seg.size(1) for seg in global_encoded])
  385. # Pick last 2000 tokens
  386. count = 0
  387. for i, length in enumerate(lengths):
  388. count += length
  389. if count + length > max_length - 1024 - sum(
  390. t.shape[1] for t in encoded_prompts
  391. ):
  392. break
  393. if i != 0 and i % 2 == 0:
  394. i -= 1
  395. # Rotate the list, always make sure first segment is included to avoid drift
  396. if i < len(global_encoded) - 2:
  397. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  398. else:
  399. partial_encoded = global_encoded
  400. if use_prompt:
  401. partial_encoded = encoded_prompts + partial_encoded
  402. cat_encoded = torch.cat(partial_encoded, dim=1)
  403. prompt_length = cat_encoded.size(1)
  404. t0 = time.perf_counter()
  405. y = generate(
  406. model=model,
  407. prompt=cat_encoded,
  408. max_new_tokens=max_new_tokens,
  409. im_end_id=im_end_id,
  410. decode_one_token=decode_one_token,
  411. temperature=temperature,
  412. top_p=top_p,
  413. repetition_penalty=repetition_penalty,
  414. )
  415. if sample_idx == 0 and seg_idx == 0 and compile:
  416. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  417. if torch.cuda.is_available():
  418. torch.cuda.synchronize()
  419. t = time.perf_counter() - t0
  420. tokens_generated = y.size(1) - prompt_length
  421. tokens_sec = tokens_generated / t
  422. logger.info(
  423. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  424. )
  425. logger.info(
  426. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  427. )
  428. if torch.cuda.is_available():
  429. logger.info(
  430. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  431. )
  432. # Put the generated tokens
  433. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  434. codes = y[1:, prompt_length:-1].clone()
  435. codes = codes - 1
  436. assert (codes >= 0).all(), f"Negative code found"
  437. decoded = y[:, prompt_length:-1].clone()
  438. # But for global encoding, we should keep the <im_end> token
  439. global_encoded.append(decoded)
  440. assert (codes >= 0).all(), f"Negative code found: {codes}"
  441. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  442. seg_idx += 1
  443. # This indicates the end of the current sample
  444. yield GenerateResponse(action="next")
  445. @dataclass
  446. class WrappedGenerateResponse:
  447. status: Literal["success", "error"]
  448. response: Optional[GenerateResponse | Exception] = None
  449. @dataclass
  450. class GenerateRequest:
  451. request: dict
  452. response_queue: queue.Queue
  453. def launch_thread_safe_queue(
  454. checkpoint_path,
  455. device,
  456. precision,
  457. compile: bool = False,
  458. ):
  459. input_queue = queue.Queue()
  460. init_event = threading.Event()
  461. def worker():
  462. model, decode_one_token = load_model(
  463. checkpoint_path, device, precision, compile=compile
  464. )
  465. with torch.device(device):
  466. model.setup_caches(
  467. max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
  468. )
  469. init_event.set()
  470. while True:
  471. item: GenerateRequest | None = input_queue.get()
  472. if item is None:
  473. break
  474. kwargs = item.request
  475. response_queue = item.response_queue
  476. try:
  477. for chunk in generate_long(
  478. model=model, decode_one_token=decode_one_token, **kwargs
  479. ):
  480. response_queue.put(
  481. WrappedGenerateResponse(status="success", response=chunk)
  482. )
  483. except Exception as e:
  484. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  485. threading.Thread(target=worker, daemon=True).start()
  486. init_event.wait()
  487. return input_queue
  488. @click.command()
  489. @click.option(
  490. "--text",
  491. type=str,
  492. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  493. )
  494. @click.option("--prompt-text", type=str, default=None, multiple=True)
  495. @click.option(
  496. "--prompt-tokens",
  497. type=click.Path(path_type=Path, exists=True),
  498. default=None,
  499. multiple=True,
  500. )
  501. @click.option("--num-samples", type=int, default=1)
  502. @click.option("--max-new-tokens", type=int, default=1024)
  503. @click.option("--top-p", type=float, default=0.7)
  504. @click.option("--repetition-penalty", type=float, default=1.2)
  505. @click.option("--temperature", type=float, default=0.7)
  506. @click.option(
  507. "--checkpoint-path",
  508. type=click.Path(path_type=Path, exists=True),
  509. default="checkpoints/fish-speech-1.4",
  510. )
  511. @click.option("--device", type=str, default="cuda")
  512. @click.option("--compile/--no-compile", default=False)
  513. @click.option("--seed", type=int, default=42)
  514. @click.option("--half/--no-half", default=False)
  515. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  516. @click.option("--chunk-length", type=int, default=100)
  517. def main(
  518. text: str,
  519. prompt_text: Optional[list[str]],
  520. prompt_tokens: Optional[list[Path]],
  521. num_samples: int,
  522. max_new_tokens: int,
  523. top_p: int,
  524. repetition_penalty: float,
  525. temperature: float,
  526. checkpoint_path: Path,
  527. device: str,
  528. compile: bool,
  529. seed: int,
  530. half: bool,
  531. iterative_prompt: bool,
  532. chunk_length: int,
  533. ) -> None:
  534. precision = torch.half if half else torch.bfloat16
  535. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  536. raise ValueError(
  537. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  538. )
  539. logger.info("Loading model ...")
  540. t0 = time.time()
  541. model, decode_one_token = load_model(
  542. checkpoint_path, device, precision, compile=compile
  543. )
  544. with torch.device(device):
  545. model.setup_caches(
  546. max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
  547. )
  548. if torch.cuda.is_available():
  549. torch.cuda.synchronize()
  550. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  551. if prompt_tokens is not None:
  552. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  553. torch.manual_seed(seed)
  554. if torch.cuda.is_available():
  555. torch.cuda.manual_seed(seed)
  556. generator = generate_long(
  557. model=model,
  558. device=device,
  559. decode_one_token=decode_one_token,
  560. text=text,
  561. num_samples=num_samples,
  562. max_new_tokens=max_new_tokens,
  563. top_p=top_p,
  564. repetition_penalty=repetition_penalty,
  565. temperature=temperature,
  566. compile=compile,
  567. iterative_prompt=iterative_prompt,
  568. chunk_length=chunk_length,
  569. prompt_text=prompt_text,
  570. prompt_tokens=prompt_tokens,
  571. )
  572. idx = 0
  573. codes = []
  574. for response in generator:
  575. if response.action == "sample":
  576. codes.append(response.codes)
  577. logger.info(f"Sampled text: {response.text}")
  578. elif response.action == "next":
  579. if codes:
  580. np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
  581. logger.info(f"Saved codes to codes_{idx}.npy")
  582. logger.info(f"Next sample")
  583. codes = []
  584. idx += 1
  585. else:
  586. logger.error(f"Error: {response}")
  587. if __name__ == "__main__":
  588. main()