generate.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from dataclasses import dataclass
  6. from pathlib import Path
  7. from typing import Literal, Optional, Tuple, Union
  8. import click
  9. import hydra
  10. import numpy as np
  11. import torch
  12. import torch._dynamo.config
  13. import torch._inductor.config
  14. from loguru import logger
  15. from tqdm import tqdm
  16. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  17. from fish_speech.text import clean_text, split_text
# Turn off HuggingFace tokenizers' internal parallelism for this process
# (presumably to avoid fork/thread warnings since a worker thread is used
# below in launch_thread_safe_queue — TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# TorchInductor settings; they matter when load_model(compile=True) compiles
# decode_one_token with the "inductor" backend (CUDA available).
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
# Only set the FX graph cache flag on torch versions that expose it.
if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
  24. from fish_speech.models.text2semantic.llama import (
  25. BaseTransformer,
  26. DualARTransformer,
  27. NaiveTransformer,
  28. )
  29. def multinomial_sample_one_no_sync(
  30. probs_sort,
  31. ): # Does multinomial sampling without a cuda synchronization
  32. q = torch.empty_like(probs_sort).exponential_(1)
  33. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  34. def logits_to_probs(
  35. logits,
  36. previous_tokens: Optional[torch.Tensor] = None,
  37. temperature: torch.Tensor = 1.0,
  38. top_p: torch.Tensor = 1.0,
  39. repetition_penalty: torch.Tensor = 1.0,
  40. ) -> torch.Tensor:
  41. # Apply repetition penalty
  42. if previous_tokens is not None:
  43. previous_tokens = previous_tokens.long()
  44. score = torch.gather(logits, dim=0, index=previous_tokens)
  45. score = torch.where(
  46. score < 0, score * repetition_penalty, score / repetition_penalty
  47. )
  48. logits.scatter_(dim=0, index=previous_tokens, src=score)
  49. # Apply top-p sampling
  50. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  51. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  52. sorted_indices_to_remove = cum_probs > top_p
  53. sorted_indices_to_remove[0] = False # keep at least one option
  54. indices_to_remove = sorted_indices_to_remove.scatter(
  55. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  56. )
  57. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  58. logits = logits / max(temperature, 1e-5)
  59. probs = torch.nn.functional.softmax(logits, dim=-1)
  60. return probs
  61. def sample(
  62. logits,
  63. previous_tokens: Optional[torch.Tensor] = None,
  64. **sampling_kwargs,
  65. ) -> Tuple[torch.Tensor, torch.Tensor]:
  66. probs = logits_to_probs(
  67. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  68. )
  69. idx_next = multinomial_sample_one_no_sync(probs)
  70. return idx_next, probs
  71. def decode_one_token_ar(
  72. model: DualARTransformer,
  73. x: torch.Tensor,
  74. input_pos: torch.Tensor,
  75. previous_tokens: torch.Tensor = None,
  76. **sampling_kwargs,
  77. ) -> torch.Tensor:
  78. x = model.forward_generate(x, input_pos)
  79. codebooks = [
  80. sample(
  81. x.logits,
  82. previous_tokens=(
  83. previous_tokens[0] if previous_tokens is not None else None
  84. ), # Disable repetition penalty for the token codebook
  85. **sampling_kwargs,
  86. )[0]
  87. ]
  88. x = x.hidden_states
  89. # Cleanup the cache
  90. for layer in model.fast_layers:
  91. layer.attention.kv_cache.k_cache.fill_(0)
  92. layer.attention.kv_cache.v_cache.fill_(0)
  93. for codebook_idx in range(model.config.num_codebooks):
  94. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  95. logits = model.forward_generate_fast(x, input_pos)
  96. a = sample(
  97. logits,
  98. previous_tokens=(
  99. previous_tokens[codebook_idx + 1]
  100. if previous_tokens is not None
  101. else None
  102. ),
  103. **sampling_kwargs,
  104. )[0]
  105. x = model.fast_embeddings(a)
  106. codebooks.append(a)
  107. return torch.stack(codebooks, dim=0)
  108. def decode_one_token_naive(
  109. model: NaiveTransformer,
  110. x: torch.Tensor,
  111. input_pos: torch.Tensor,
  112. previous_tokens: torch.Tensor = None,
  113. **sampling_kwargs,
  114. ) -> torch.Tensor:
  115. x = model.forward_generate(x, input_pos)
  116. codebooks = [
  117. sample(
  118. x.token_logits,
  119. previous_tokens=None, # Disable repetition penalty for the token codebook
  120. **sampling_kwargs,
  121. )[0]
  122. ]
  123. for i in range(model.config.num_codebooks):
  124. codebooks.append(
  125. sample(
  126. x.codebook_logits[:, :, i],
  127. previous_tokens=(
  128. previous_tokens[i + 1] if previous_tokens is not None else None
  129. ),
  130. **sampling_kwargs,
  131. )[0]
  132. )
  133. return torch.stack(codebooks, dim=0)
  134. def decode_n_tokens(
  135. model: NaiveTransformer,
  136. cur_token: torch.Tensor,
  137. input_pos: torch.Tensor,
  138. num_new_tokens: int,
  139. im_end_id: int = 4,
  140. decode_one_token=decode_one_token_naive,
  141. **sampling_kwargs,
  142. ):
  143. previous_tokens = torch.zeros(
  144. (model.config.num_codebooks + 1, model.config.max_seq_len),
  145. dtype=torch.int,
  146. device=cur_token.device,
  147. )
  148. for i in tqdm(range(num_new_tokens)):
  149. # We need to get windowed repeat penalty
  150. win_size = 16
  151. if i < win_size:
  152. window = previous_tokens[:, :win_size]
  153. else:
  154. window = previous_tokens[:, i - win_size : i]
  155. with torch.backends.cuda.sdp_kernel(
  156. enable_flash=False, enable_mem_efficient=False, enable_math=True
  157. ): # Actually better for Inductor to codegen attention here
  158. next_token = decode_one_token(
  159. model=model,
  160. x=cur_token,
  161. input_pos=input_pos,
  162. previous_tokens=window,
  163. **sampling_kwargs,
  164. )
  165. input_pos += 1
  166. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  167. previous_tokens[:, i : i + 1] = next_token.view(
  168. model.config.num_codebooks + 1, -1
  169. )
  170. if cur_token[0, 0, -1] == im_end_id:
  171. break
  172. return previous_tokens[:, : i + 1]
  173. @torch.no_grad()
  174. @torch.inference_mode()
  175. def generate(
  176. *,
  177. model: NaiveTransformer,
  178. prompt: torch.Tensor,
  179. max_new_tokens: int,
  180. im_end_id: int = 4,
  181. decode_one_token=decode_one_token_naive,
  182. **sampling_kwargs,
  183. ) -> torch.Tensor:
  184. """
  185. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  186. """
  187. # create an empty tensor of the expected final shape and fill in the current tokens
  188. T = prompt.size(1)
  189. if max_new_tokens:
  190. if T + max_new_tokens > model.config.max_seq_len:
  191. max_new_tokens = model.config.max_seq_len - T
  192. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  193. T_new = T + max_new_tokens
  194. else:
  195. T_new = model.config.max_seq_len
  196. max_new_tokens = T_new - T
  197. device, dtype = prompt.device, prompt.dtype
  198. with torch.device(device):
  199. model.setup_caches(
  200. max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
  201. )
  202. codebook_dim = 1 + model.config.num_codebooks
  203. # create an empty tensor of the expected final shape and fill in the current tokens
  204. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  205. empty[:, :T] = prompt
  206. seq = empty
  207. input_pos = torch.arange(0, T, device=device)
  208. # Use non-accelerated version for now, to avoid compilation overhead
  209. prefill_decode = (
  210. decode_one_token_naive
  211. if isinstance(model, NaiveTransformer)
  212. else decode_one_token_ar
  213. )
  214. next_token = prefill_decode(
  215. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  216. )
  217. seq[:, T : T + 1] = next_token
  218. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  219. x = decode_n_tokens(
  220. model,
  221. next_token.view(1, codebook_dim, -1),
  222. input_pos,
  223. max_new_tokens - 1,
  224. im_end_id=im_end_id,
  225. decode_one_token=decode_one_token,
  226. **sampling_kwargs,
  227. )
  228. # x = torch.cat(generated_tokens, dim=1)
  229. seq = seq[:, : T + 1 + x.size(1)]
  230. seq[:, T + 1 :] = x
  231. return seq
  232. def encode_tokens(
  233. tokenizer,
  234. string,
  235. device="cuda",
  236. prompt_tokens=None,
  237. num_codebooks=4,
  238. ):
  239. string = clean_text(string)
  240. string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
  241. new_tokens = tokenizer.encode(
  242. string,
  243. add_special_tokens=False,
  244. max_length=10**6,
  245. truncation=False,
  246. )
  247. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  248. # Codebooks
  249. zeros = (
  250. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  251. * CODEBOOK_PAD_TOKEN_ID
  252. )
  253. prompt = torch.cat((tokens, zeros), dim=0)
  254. if prompt_tokens is None:
  255. return prompt
  256. # Get prompt tokens
  257. if prompt_tokens.ndim == 3:
  258. assert (
  259. prompt_tokens.shape[0] == 1
  260. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  261. prompt_tokens = prompt_tokens[0]
  262. assert prompt_tokens.ndim == 2
  263. data = prompt_tokens + 1
  264. if prompt_tokens.shape[0] > num_codebooks:
  265. logger.warning(
  266. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  267. )
  268. data = data[:num_codebooks]
  269. # Add pad token for each codebook
  270. data = torch.cat(
  271. (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
  272. dim=1,
  273. )
  274. # Since 1.0, we use <|semantic|>
  275. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  276. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  277. main_token_ids = (
  278. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  279. )
  280. main_token_ids[0, -1] = end_token_id
  281. data = torch.cat((main_token_ids, data), dim=0)
  282. prompt = torch.cat((prompt, data), dim=1)
  283. return prompt
  284. def load_model(checkpoint_path, device, precision, compile=False):
  285. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  286. checkpoint_path, load_weights=True
  287. )
  288. model = model.to(device=device, dtype=precision)
  289. logger.info(f"Restored model from checkpoint")
  290. if isinstance(model, DualARTransformer):
  291. decode_one_token = decode_one_token_ar
  292. logger.info("Using DualARTransformer")
  293. else:
  294. decode_one_token = decode_one_token_naive
  295. logger.info("Using NaiveTransformer")
  296. if compile:
  297. logger.info("Compiling function...")
  298. decode_one_token = torch.compile(
  299. decode_one_token,
  300. fullgraph=True,
  301. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  302. mode="reduce-overhead" if torch.cuda.is_available() else None,
  303. )
  304. return model.eval(), decode_one_token
  305. @dataclass
  306. class GenerateResponse:
  307. action: Literal["sample", "next"]
  308. codes: Optional[torch.Tensor] = None
  309. text: Optional[str] = None
  310. def generate_long(
  311. *,
  312. model,
  313. device: str | torch.device,
  314. decode_one_token: callable,
  315. text: str,
  316. num_samples: int = 1,
  317. max_new_tokens: int = 0,
  318. top_p: int = 0.7,
  319. repetition_penalty: float = 1.5,
  320. temperature: float = 0.7,
  321. compile: bool = False,
  322. iterative_prompt: bool = True,
  323. max_length: int = 2048,
  324. chunk_length: int = 150,
  325. prompt_text: Optional[str | list[str]] = None,
  326. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  327. ):
  328. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  329. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  330. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  331. use_prompt = prompt_text is not None and prompt_tokens is not None
  332. if use_prompt and isinstance(prompt_text, str):
  333. prompt_text = [prompt_text]
  334. prompt_tokens = [prompt_tokens]
  335. assert use_prompt is False or len(prompt_text) == len(
  336. prompt_tokens
  337. ), "Prompt text and tokens must have the same length"
  338. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  339. tokenizer = model.tokenizer
  340. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  341. encoded = []
  342. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  343. encoded_prompts = []
  344. if use_prompt:
  345. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  346. encoded_prompts.append(
  347. encode_tokens(
  348. tokenizer,
  349. string=t,
  350. device=device,
  351. prompt_tokens=c,
  352. num_codebooks=model.config.num_codebooks,
  353. )
  354. )
  355. for idx, text in enumerate(texts):
  356. encoded.append(
  357. encode_tokens(
  358. tokenizer,
  359. string=text,
  360. device=device,
  361. num_codebooks=model.config.num_codebooks,
  362. )
  363. )
  364. logger.info(f"Encoded text: {text}")
  365. # Move temperature, top_p, repetition_penalty to device
  366. # This is important so that changing params doesn't trigger recompile
  367. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  368. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  369. repetition_penalty = torch.tensor(
  370. repetition_penalty, device=device, dtype=torch.float
  371. )
  372. for sample_idx in range(num_samples):
  373. if torch.cuda.is_available():
  374. torch.cuda.synchronize()
  375. global_encoded = []
  376. seg_idx = 0
  377. while seg_idx < len(encoded):
  378. logger.info(
  379. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  380. )
  381. seg = encoded[seg_idx]
  382. global_encoded.append(seg)
  383. lengths = reversed([seg.size(1) for seg in global_encoded])
  384. # Pick last 2000 tokens
  385. count = 0
  386. for i, length in enumerate(lengths):
  387. count += length
  388. if count + length > max_length - 1024 - sum(
  389. t.shape[1] for t in encoded_prompts
  390. ):
  391. break
  392. if i != 0 and i % 2 == 0:
  393. i -= 1
  394. # Rotate the list, always make sure first segment is included to avoid drift
  395. if i < len(global_encoded) - 2:
  396. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  397. else:
  398. partial_encoded = global_encoded
  399. if use_prompt:
  400. partial_encoded = encoded_prompts + partial_encoded
  401. cat_encoded = torch.cat(partial_encoded, dim=1)
  402. prompt_length = cat_encoded.size(1)
  403. t0 = time.perf_counter()
  404. y = generate(
  405. model=model,
  406. prompt=cat_encoded,
  407. max_new_tokens=max_new_tokens,
  408. im_end_id=im_end_id,
  409. decode_one_token=decode_one_token,
  410. temperature=temperature,
  411. top_p=top_p,
  412. repetition_penalty=repetition_penalty,
  413. )
  414. if sample_idx == 0 and seg_idx == 0 and compile:
  415. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  416. if torch.cuda.is_available():
  417. torch.cuda.synchronize()
  418. t = time.perf_counter() - t0
  419. tokens_generated = y.size(1) - prompt_length
  420. tokens_sec = tokens_generated / t
  421. logger.info(
  422. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  423. )
  424. logger.info(
  425. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  426. )
  427. if torch.cuda.is_available():
  428. logger.info(
  429. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  430. )
  431. # Put the generated tokens
  432. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  433. codes = y[1:, prompt_length:-1].clone()
  434. codes = codes - 1
  435. assert (codes >= 0).all(), f"Negative code found"
  436. decoded = y[:, prompt_length:-1].clone()
  437. # But for global encoding, we should keep the <im_end> token
  438. global_encoded.append(decoded)
  439. assert (codes >= 0).all(), f"Negative code found: {codes}"
  440. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  441. seg_idx += 1
  442. # This indicates the end of the current sample
  443. yield GenerateResponse(action="next")
  444. @dataclass
  445. class WrappedGenerateResponse:
  446. status: Literal["success", "error"]
  447. response: Optional[GenerateResponse | Exception] = None
  448. @dataclass
  449. class GenerateRequest:
  450. request: dict
  451. response_queue: queue.Queue
  452. def launch_thread_safe_queue(
  453. checkpoint_path,
  454. device,
  455. precision,
  456. compile: bool = False,
  457. ):
  458. input_queue = queue.Queue()
  459. init_event = threading.Event()
  460. def worker():
  461. model, decode_one_token = load_model(
  462. checkpoint_path, device, precision, compile=compile
  463. )
  464. init_event.set()
  465. while True:
  466. item: GenerateRequest | None = input_queue.get()
  467. if item is None:
  468. break
  469. kwargs = item.request
  470. response_queue = item.response_queue
  471. try:
  472. for chunk in generate_long(
  473. model=model, decode_one_token=decode_one_token, **kwargs
  474. ):
  475. response_queue.put(
  476. WrappedGenerateResponse(status="success", response=chunk)
  477. )
  478. except Exception as e:
  479. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  480. threading.Thread(target=worker, daemon=True).start()
  481. init_event.wait()
  482. return input_queue
  483. @click.command()
  484. @click.option(
  485. "--text",
  486. type=str,
  487. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  488. )
  489. @click.option("--prompt-text", type=str, default=None, multiple=True)
  490. @click.option(
  491. "--prompt-tokens",
  492. type=click.Path(path_type=Path, exists=True),
  493. default=None,
  494. multiple=True,
  495. )
  496. @click.option("--num-samples", type=int, default=1)
  497. @click.option("--max-new-tokens", type=int, default=0)
  498. @click.option("--top-p", type=float, default=0.7)
  499. @click.option("--repetition-penalty", type=float, default=1.2)
  500. @click.option("--temperature", type=float, default=0.7)
  501. @click.option(
  502. "--checkpoint-path",
  503. type=click.Path(path_type=Path, exists=True),
  504. default="checkpoints/fish-speech-1.2-sft",
  505. )
  506. @click.option("--device", type=str, default="cuda")
  507. @click.option("--compile/--no-compile", default=False)
  508. @click.option("--seed", type=int, default=42)
  509. @click.option("--half/--no-half", default=False)
  510. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  511. @click.option("--chunk-length", type=int, default=100)
  512. def main(
  513. text: str,
  514. prompt_text: Optional[list[str]],
  515. prompt_tokens: Optional[list[Path]],
  516. num_samples: int,
  517. max_new_tokens: int,
  518. top_p: int,
  519. repetition_penalty: float,
  520. temperature: float,
  521. checkpoint_path: Path,
  522. device: str,
  523. compile: bool,
  524. seed: int,
  525. half: bool,
  526. iterative_prompt: bool,
  527. chunk_length: int,
  528. ) -> None:
  529. precision = torch.half if half else torch.bfloat16
  530. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  531. raise ValueError(
  532. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  533. )
  534. logger.info("Loading model ...")
  535. t0 = time.time()
  536. model, decode_one_token = load_model(
  537. checkpoint_path, device, precision, compile=compile
  538. )
  539. if torch.cuda.is_available():
  540. torch.cuda.synchronize()
  541. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  542. if prompt_tokens is not None:
  543. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  544. torch.manual_seed(seed)
  545. if torch.cuda.is_available():
  546. torch.cuda.manual_seed(seed)
  547. generator = generate_long(
  548. model=model,
  549. device=device,
  550. decode_one_token=decode_one_token,
  551. text=text,
  552. num_samples=num_samples,
  553. max_new_tokens=max_new_tokens,
  554. top_p=top_p,
  555. repetition_penalty=repetition_penalty,
  556. temperature=temperature,
  557. compile=compile,
  558. iterative_prompt=iterative_prompt,
  559. chunk_length=chunk_length,
  560. prompt_text=prompt_text,
  561. prompt_tokens=prompt_tokens,
  562. )
  563. idx = 0
  564. codes = []
  565. for response in generator:
  566. if response.action == "sample":
  567. codes.append(response.codes)
  568. logger.info(f"Sampled text: {response.text}")
  569. elif response.action == "next":
  570. if codes:
  571. np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
  572. logger.info(f"Saved codes to codes_{idx}.npy")
  573. logger.info(f"Next sample")
  574. codes = []
  575. idx += 1
  576. else:
  577. logger.error(f"Error: {response}")
# Script entry point: run the click CLI defined by main().
if __name__ == "__main__":
    main()