generate.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import hydra
  11. import numpy as np
  12. import torch
  13. import torch._dynamo.config
  14. import torch._inductor.config
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  18. from fish_speech.text import clean_text, split_text
  19. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  20. torch._inductor.config.coordinate_descent_tuning = True
  21. torch._inductor.config.triton.unique_kernel_names = True
  22. if hasattr(torch._inductor.config, "fx_graph_cache"):
  23. # Experimental feature to reduce compilation times, will be on by default in future
  24. torch._inductor.config.fx_graph_cache = True
  25. from fish_speech.models.text2semantic.llama import (
  26. BaseTransformer,
  27. DualARTransformer,
  28. NaiveTransformer,
  29. )
  30. def multinomial_sample_one_no_sync(
  31. probs_sort,
  32. ): # Does multinomial sampling without a cuda synchronization
  33. q = torch.empty_like(probs_sort).exponential_(1)
  34. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  35. def logits_to_probs(
  36. logits,
  37. previous_tokens: Optional[torch.Tensor] = None,
  38. temperature: torch.Tensor = 1.0,
  39. top_p: torch.Tensor = 1.0,
  40. repetition_penalty: torch.Tensor = 1.0,
  41. ) -> torch.Tensor:
  42. # Apply repetition penalty
  43. if previous_tokens is not None:
  44. previous_tokens = previous_tokens.long()
  45. score = torch.gather(logits, dim=0, index=previous_tokens)
  46. score = torch.where(
  47. score < 0, score * repetition_penalty, score / repetition_penalty
  48. )
  49. logits.scatter_(dim=0, index=previous_tokens, src=score)
  50. # Apply top-p sampling
  51. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  52. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  53. sorted_indices_to_remove = cum_probs > top_p
  54. sorted_indices_to_remove[0] = False # keep at least one option
  55. indices_to_remove = sorted_indices_to_remove.scatter(
  56. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  57. )
  58. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  59. logits = logits / max(temperature, 1e-5)
  60. probs = torch.nn.functional.softmax(logits, dim=-1)
  61. return probs
  62. def sample(
  63. logits,
  64. previous_tokens: Optional[torch.Tensor] = None,
  65. **sampling_kwargs,
  66. ) -> Tuple[torch.Tensor, torch.Tensor]:
  67. probs = logits_to_probs(
  68. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  69. )
  70. idx_next = multinomial_sample_one_no_sync(probs)
  71. return idx_next, probs
  72. def decode_one_token_ar(
  73. model: DualARTransformer,
  74. x: torch.Tensor,
  75. input_pos: torch.Tensor,
  76. previous_tokens: torch.Tensor = None,
  77. **sampling_kwargs,
  78. ) -> torch.Tensor:
  79. x = model.forward_generate(x, input_pos)
  80. codebooks = [
  81. sample(
  82. x.logits,
  83. previous_tokens=(
  84. previous_tokens[0] if previous_tokens is not None else None
  85. ), # Disable repetition penalty for the token codebook
  86. **sampling_kwargs,
  87. )[0]
  88. ]
  89. x = x.hidden_states
  90. # Cleanup the cache
  91. for layer in model.fast_layers:
  92. layer.attention.kv_cache.k_cache.fill_(0)
  93. layer.attention.kv_cache.v_cache.fill_(0)
  94. for codebook_idx in range(model.config.num_codebooks):
  95. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  96. logits = model.forward_generate_fast(x, input_pos)
  97. a = sample(
  98. logits,
  99. previous_tokens=(
  100. previous_tokens[codebook_idx + 1]
  101. if previous_tokens is not None
  102. else None
  103. ),
  104. **sampling_kwargs,
  105. )[0]
  106. x = model.fast_embeddings(a)
  107. codebooks.append(a)
  108. return torch.stack(codebooks, dim=0)
  109. def decode_one_token_naive(
  110. model: NaiveTransformer,
  111. x: torch.Tensor,
  112. input_pos: torch.Tensor,
  113. previous_tokens: torch.Tensor = None,
  114. **sampling_kwargs,
  115. ) -> torch.Tensor:
  116. x = model.forward_generate(x, input_pos)
  117. codebooks = [
  118. sample(
  119. x.token_logits,
  120. previous_tokens=None, # Disable repetition penalty for the token codebook
  121. **sampling_kwargs,
  122. )[0]
  123. ]
  124. for i in range(model.config.num_codebooks):
  125. codebooks.append(
  126. sample(
  127. x.codebook_logits[:, :, i],
  128. previous_tokens=(
  129. previous_tokens[i + 1] if previous_tokens is not None else None
  130. ),
  131. **sampling_kwargs,
  132. )[0]
  133. )
  134. return torch.stack(codebooks, dim=0)
  135. def decode_n_tokens(
  136. model: NaiveTransformer,
  137. cur_token: torch.Tensor,
  138. input_pos: torch.Tensor,
  139. num_new_tokens: int,
  140. im_end_id: int = 4,
  141. decode_one_token=decode_one_token_naive,
  142. **sampling_kwargs,
  143. ):
  144. previous_tokens = torch.zeros(
  145. (model.config.num_codebooks + 1, model.config.max_seq_len),
  146. dtype=torch.int,
  147. device=cur_token.device,
  148. )
  149. for i in tqdm(range(num_new_tokens)):
  150. # We need to get windowed repeat penalty
  151. win_size = 16
  152. if i < win_size:
  153. window = previous_tokens[:, :win_size]
  154. else:
  155. window = previous_tokens[:, i - win_size : i]
  156. with (
  157. torch.backends.cuda.sdp_kernel(
  158. enable_flash=False, enable_mem_efficient=False, enable_math=True
  159. )
  160. if torch.cuda.is_available()
  161. else nullcontext()
  162. ): # Actually better for Inductor to codegen attention here
  163. next_token = decode_one_token(
  164. model=model,
  165. x=cur_token,
  166. input_pos=input_pos,
  167. previous_tokens=window,
  168. **sampling_kwargs,
  169. )
  170. input_pos += 1
  171. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  172. previous_tokens[:, i : i + 1] = next_token.view(
  173. model.config.num_codebooks + 1, -1
  174. )
  175. if cur_token[0, 0, -1] == im_end_id:
  176. break
  177. return previous_tokens[:, : i + 1]
  178. @torch.no_grad()
  179. @torch.inference_mode()
  180. def generate(
  181. *,
  182. model: NaiveTransformer,
  183. prompt: torch.Tensor,
  184. max_new_tokens: int,
  185. im_end_id: int = 4,
  186. decode_one_token=decode_one_token_naive,
  187. **sampling_kwargs,
  188. ) -> torch.Tensor:
  189. """
  190. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  191. """
  192. # create an empty tensor of the expected final shape and fill in the current tokens
  193. T = prompt.size(1)
  194. if max_new_tokens:
  195. if T + max_new_tokens > model.config.max_seq_len:
  196. max_new_tokens = model.config.max_seq_len - T
  197. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  198. T_new = T + max_new_tokens
  199. else:
  200. T_new = model.config.max_seq_len
  201. max_new_tokens = T_new - T
  202. device, dtype = prompt.device, prompt.dtype
  203. with torch.device(device):
  204. model.setup_caches(
  205. max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
  206. )
  207. codebook_dim = 1 + model.config.num_codebooks
  208. # create an empty tensor of the expected final shape and fill in the current tokens
  209. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  210. empty[:, :T] = prompt
  211. seq = empty
  212. input_pos = torch.arange(0, T, device=device)
  213. # Use non-accelerated version for now, to avoid compilation overhead
  214. prefill_decode = (
  215. decode_one_token_naive
  216. if isinstance(model, NaiveTransformer)
  217. else decode_one_token_ar
  218. )
  219. next_token = prefill_decode(
  220. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  221. )
  222. seq[:, T : T + 1] = next_token
  223. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  224. x = decode_n_tokens(
  225. model,
  226. next_token.view(1, codebook_dim, -1),
  227. input_pos,
  228. max_new_tokens - 1,
  229. im_end_id=im_end_id,
  230. decode_one_token=decode_one_token,
  231. **sampling_kwargs,
  232. )
  233. # x = torch.cat(generated_tokens, dim=1)
  234. seq = seq[:, : T + 1 + x.size(1)]
  235. seq[:, T + 1 :] = x
  236. return seq
  237. def encode_tokens(
  238. tokenizer,
  239. string,
  240. device="cuda",
  241. prompt_tokens=None,
  242. num_codebooks=4,
  243. ):
  244. string = clean_text(string)
  245. string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
  246. new_tokens = tokenizer.encode(
  247. string,
  248. add_special_tokens=False,
  249. max_length=10**6,
  250. truncation=False,
  251. )
  252. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  253. # Codebooks
  254. zeros = (
  255. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  256. * CODEBOOK_PAD_TOKEN_ID
  257. )
  258. prompt = torch.cat((tokens, zeros), dim=0)
  259. if prompt_tokens is None:
  260. return prompt
  261. # Get prompt tokens
  262. if prompt_tokens.ndim == 3:
  263. assert (
  264. prompt_tokens.shape[0] == 1
  265. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  266. prompt_tokens = prompt_tokens[0]
  267. assert prompt_tokens.ndim == 2
  268. data = prompt_tokens + 1
  269. if prompt_tokens.shape[0] > num_codebooks:
  270. logger.warning(
  271. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  272. )
  273. data = data[:num_codebooks]
  274. # Add pad token for each codebook
  275. data = torch.cat(
  276. (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
  277. dim=1,
  278. )
  279. # Since 1.0, we use <|semantic|>
  280. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  281. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  282. main_token_ids = (
  283. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  284. )
  285. main_token_ids[0, -1] = end_token_id
  286. data = torch.cat((main_token_ids, data), dim=0)
  287. prompt = torch.cat((prompt, data), dim=1)
  288. return prompt
  289. def load_model(checkpoint_path, device, precision, compile=False):
  290. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  291. checkpoint_path, load_weights=True
  292. )
  293. model = model.to(device=device, dtype=precision)
  294. logger.info(f"Restored model from checkpoint")
  295. if isinstance(model, DualARTransformer):
  296. decode_one_token = decode_one_token_ar
  297. logger.info("Using DualARTransformer")
  298. else:
  299. decode_one_token = decode_one_token_naive
  300. logger.info("Using NaiveTransformer")
  301. if compile:
  302. logger.info("Compiling function...")
  303. decode_one_token = torch.compile(
  304. decode_one_token,
  305. fullgraph=True,
  306. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  307. mode="reduce-overhead" if torch.cuda.is_available() else None,
  308. )
  309. return model.eval(), decode_one_token
  310. @dataclass
  311. class GenerateResponse:
  312. action: Literal["sample", "next"]
  313. codes: Optional[torch.Tensor] = None
  314. text: Optional[str] = None
  315. def generate_long(
  316. *,
  317. model,
  318. device: str | torch.device,
  319. decode_one_token: callable,
  320. text: str,
  321. num_samples: int = 1,
  322. max_new_tokens: int = 0,
  323. top_p: int = 0.7,
  324. repetition_penalty: float = 1.5,
  325. temperature: float = 0.7,
  326. compile: bool = False,
  327. iterative_prompt: bool = True,
  328. max_length: int = 2048,
  329. chunk_length: int = 150,
  330. prompt_text: Optional[str | list[str]] = None,
  331. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  332. ):
  333. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  334. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  335. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  336. use_prompt = prompt_text is not None and prompt_tokens is not None
  337. if use_prompt and isinstance(prompt_text, str):
  338. prompt_text = [prompt_text]
  339. prompt_tokens = [prompt_tokens]
  340. assert use_prompt is False or len(prompt_text) == len(
  341. prompt_tokens
  342. ), "Prompt text and tokens must have the same length"
  343. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  344. tokenizer = model.tokenizer
  345. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  346. encoded = []
  347. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  348. encoded_prompts = []
  349. if use_prompt:
  350. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  351. encoded_prompts.append(
  352. encode_tokens(
  353. tokenizer,
  354. string=t,
  355. device=device,
  356. prompt_tokens=c,
  357. num_codebooks=model.config.num_codebooks,
  358. )
  359. )
  360. for idx, text in enumerate(texts):
  361. encoded.append(
  362. encode_tokens(
  363. tokenizer,
  364. string=text,
  365. device=device,
  366. num_codebooks=model.config.num_codebooks,
  367. )
  368. )
  369. logger.info(f"Encoded text: {text}")
  370. # Move temperature, top_p, repetition_penalty to device
  371. # This is important so that changing params doesn't trigger recompile
  372. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  373. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  374. repetition_penalty = torch.tensor(
  375. repetition_penalty, device=device, dtype=torch.float
  376. )
  377. for sample_idx in range(num_samples):
  378. if torch.cuda.is_available():
  379. torch.cuda.synchronize()
  380. global_encoded = []
  381. seg_idx = 0
  382. while seg_idx < len(encoded):
  383. logger.info(
  384. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  385. )
  386. seg = encoded[seg_idx]
  387. global_encoded.append(seg)
  388. lengths = reversed([seg.size(1) for seg in global_encoded])
  389. # Pick last 2000 tokens
  390. count = 0
  391. for i, length in enumerate(lengths):
  392. count += length
  393. if count + length > max_length - 1024 - sum(
  394. t.shape[1] for t in encoded_prompts
  395. ):
  396. break
  397. if i != 0 and i % 2 == 0:
  398. i -= 1
  399. # Rotate the list, always make sure first segment is included to avoid drift
  400. if i < len(global_encoded) - 2:
  401. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  402. else:
  403. partial_encoded = global_encoded
  404. if use_prompt:
  405. partial_encoded = encoded_prompts + partial_encoded
  406. cat_encoded = torch.cat(partial_encoded, dim=1)
  407. prompt_length = cat_encoded.size(1)
  408. t0 = time.perf_counter()
  409. y = generate(
  410. model=model,
  411. prompt=cat_encoded,
  412. max_new_tokens=max_new_tokens,
  413. im_end_id=im_end_id,
  414. decode_one_token=decode_one_token,
  415. temperature=temperature,
  416. top_p=top_p,
  417. repetition_penalty=repetition_penalty,
  418. )
  419. if sample_idx == 0 and seg_idx == 0 and compile:
  420. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  421. if torch.cuda.is_available():
  422. torch.cuda.synchronize()
  423. t = time.perf_counter() - t0
  424. tokens_generated = y.size(1) - prompt_length
  425. tokens_sec = tokens_generated / t
  426. logger.info(
  427. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  428. )
  429. logger.info(
  430. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  431. )
  432. if torch.cuda.is_available():
  433. logger.info(
  434. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  435. )
  436. # Put the generated tokens
  437. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  438. codes = y[1:, prompt_length:-1].clone()
  439. codes = codes - 1
  440. assert (codes >= 0).all(), f"Negative code found"
  441. decoded = y[:, prompt_length:-1].clone()
  442. # But for global encoding, we should keep the <im_end> token
  443. global_encoded.append(decoded)
  444. assert (codes >= 0).all(), f"Negative code found: {codes}"
  445. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  446. seg_idx += 1
  447. # This indicates the end of the current sample
  448. yield GenerateResponse(action="next")
  449. @dataclass
  450. class WrappedGenerateResponse:
  451. status: Literal["success", "error"]
  452. response: Optional[GenerateResponse | Exception] = None
  453. @dataclass
  454. class GenerateRequest:
  455. request: dict
  456. response_queue: queue.Queue
  457. def launch_thread_safe_queue(
  458. checkpoint_path,
  459. device,
  460. precision,
  461. compile: bool = False,
  462. ):
  463. input_queue = queue.Queue()
  464. init_event = threading.Event()
  465. def worker():
  466. model, decode_one_token = load_model(
  467. checkpoint_path, device, precision, compile=compile
  468. )
  469. init_event.set()
  470. while True:
  471. item: GenerateRequest | None = input_queue.get()
  472. if item is None:
  473. break
  474. kwargs = item.request
  475. response_queue = item.response_queue
  476. try:
  477. for chunk in generate_long(
  478. model=model, decode_one_token=decode_one_token, **kwargs
  479. ):
  480. response_queue.put(
  481. WrappedGenerateResponse(status="success", response=chunk)
  482. )
  483. except Exception as e:
  484. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  485. threading.Thread(target=worker, daemon=True).start()
  486. init_event.wait()
  487. return input_queue
  488. @click.command()
  489. @click.option(
  490. "--text",
  491. type=str,
  492. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  493. )
  494. @click.option("--prompt-text", type=str, default=None, multiple=True)
  495. @click.option(
  496. "--prompt-tokens",
  497. type=click.Path(path_type=Path, exists=True),
  498. default=None,
  499. multiple=True,
  500. )
  501. @click.option("--num-samples", type=int, default=1)
  502. @click.option("--max-new-tokens", type=int, default=0)
  503. @click.option("--top-p", type=float, default=0.7)
  504. @click.option("--repetition-penalty", type=float, default=1.2)
  505. @click.option("--temperature", type=float, default=0.7)
  506. @click.option(
  507. "--checkpoint-path",
  508. type=click.Path(path_type=Path, exists=True),
  509. default="checkpoints/fish-speech-1.2-sft",
  510. )
  511. @click.option("--device", type=str, default="cuda")
  512. @click.option("--compile/--no-compile", default=False)
  513. @click.option("--seed", type=int, default=42)
  514. @click.option("--half/--no-half", default=False)
  515. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  516. @click.option("--chunk-length", type=int, default=100)
  517. def main(
  518. text: str,
  519. prompt_text: Optional[list[str]],
  520. prompt_tokens: Optional[list[Path]],
  521. num_samples: int,
  522. max_new_tokens: int,
  523. top_p: int,
  524. repetition_penalty: float,
  525. temperature: float,
  526. checkpoint_path: Path,
  527. device: str,
  528. compile: bool,
  529. seed: int,
  530. half: bool,
  531. iterative_prompt: bool,
  532. chunk_length: int,
  533. ) -> None:
  534. precision = torch.half if half else torch.bfloat16
  535. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  536. raise ValueError(
  537. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  538. )
  539. logger.info("Loading model ...")
  540. t0 = time.time()
  541. model, decode_one_token = load_model(
  542. checkpoint_path, device, precision, compile=compile
  543. )
  544. if torch.cuda.is_available():
  545. torch.cuda.synchronize()
  546. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  547. if prompt_tokens is not None:
  548. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  549. torch.manual_seed(seed)
  550. if torch.cuda.is_available():
  551. torch.cuda.manual_seed(seed)
  552. generator = generate_long(
  553. model=model,
  554. device=device,
  555. decode_one_token=decode_one_token,
  556. text=text,
  557. num_samples=num_samples,
  558. max_new_tokens=max_new_tokens,
  559. top_p=top_p,
  560. repetition_penalty=repetition_penalty,
  561. temperature=temperature,
  562. compile=compile,
  563. iterative_prompt=iterative_prompt,
  564. chunk_length=chunk_length,
  565. prompt_text=prompt_text,
  566. prompt_tokens=prompt_tokens,
  567. )
  568. idx = 0
  569. codes = []
  570. for response in generator:
  571. if response.action == "sample":
  572. codes.append(response.codes)
  573. logger.info(f"Sampled text: {response.text}")
  574. elif response.action == "next":
  575. if codes:
  576. np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
  577. logger.info(f"Saved codes to codes_{idx}.npy")
  578. logger.info(f"Next sample")
  579. codes = []
  580. idx += 1
  581. else:
  582. logger.error(f"Error: {response}")
  583. if __name__ == "__main__":
  584. main()