generate.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. import os
  2. import time
  3. from pathlib import Path
  4. from typing import Optional, Tuple, Union
  5. import click
  6. import numpy as np
  7. import torch
  8. import torch._dynamo.config
  9. import torch._inductor.config
  10. from hydra import compose, initialize
  11. from hydra.utils import instantiate
  12. from loguru import logger
  13. from tqdm import tqdm
  14. from transformers import AutoTokenizer
  15. from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
  16. from fish_speech.text.clean import clean_text
  17. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  18. torch._inductor.config.coordinate_descent_tuning = True
  19. torch._inductor.config.triton.unique_kernel_names = True
  20. if hasattr(torch._inductor.config, "fx_graph_cache"):
  21. # Experimental feature to reduce compilation times, will be on by default in future
  22. torch._inductor.config.fx_graph_cache = True
  23. from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
  24. def multinomial_sample_one_no_sync(
  25. probs_sort,
  26. ): # Does multinomial sampling without a cuda synchronization
  27. q = torch.empty_like(probs_sort).exponential_(1)
  28. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  29. def logits_to_probs(
  30. logits,
  31. previous_tokens: Optional[torch.Tensor] = None,
  32. temperature: float = 1.0,
  33. top_k: Optional[int] = None,
  34. top_p: Optional[int] = None,
  35. repetition_penalty: float = 1.0,
  36. ):
  37. if previous_tokens is not None and repetition_penalty != 1.0:
  38. previous_tokens = previous_tokens.long()
  39. score = torch.gather(logits, dim=0, index=previous_tokens)
  40. score = torch.where(
  41. score < 0, score * repetition_penalty, score / repetition_penalty
  42. )
  43. logits.scatter_(dim=0, index=previous_tokens, src=score)
  44. if top_p is not None and top_p < 1.0:
  45. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  46. cum_probs = torch.cumsum(
  47. torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
  48. )
  49. sorted_indices_to_remove = cum_probs > top_p
  50. sorted_indices_to_remove[0] = False # keep at least one option
  51. indices_to_remove = sorted_indices_to_remove.scatter(
  52. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  53. )
  54. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  55. logits = logits / max(temperature, 1e-5)
  56. if top_k is not None:
  57. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  58. pivot = v.select(-1, -1).unsqueeze(-1)
  59. logits = torch.where(logits < pivot, -float("Inf"), logits)
  60. probs = torch.nn.functional.softmax(logits, dim=-1)
  61. return probs
  62. def sample(
  63. logits,
  64. previous_tokens: Optional[torch.Tensor] = None,
  65. **sampling_kwargs,
  66. ) -> Tuple[torch.Tensor, torch.Tensor]:
  67. probs = logits_to_probs(
  68. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  69. )
  70. idx_next = multinomial_sample_one_no_sync(probs)
  71. return idx_next, probs
  72. def decode_one_token_ar(
  73. model: DualARTransformer,
  74. x: torch.Tensor,
  75. input_pos: torch.Tensor,
  76. previous_tokens: torch.Tensor = None,
  77. **sampling_kwargs,
  78. ) -> torch.Tensor:
  79. x = model.forward_generate(x, input_pos)
  80. codebooks = [
  81. sample(
  82. x.logits,
  83. previous_tokens=None, # Disable repetition penalty for the token codebook
  84. **sampling_kwargs,
  85. )[0]
  86. ]
  87. x = x.hidden_states
  88. # Cleanup the cache
  89. for layer in model.fast_layers:
  90. layer.attention.kv_cache.k_cache.fill_(0)
  91. layer.attention.kv_cache.v_cache.fill_(0)
  92. for codebook_idx in range(model.config.num_codebooks):
  93. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  94. logits = model.forward_generate_fast(x, input_pos)
  95. a = sample(
  96. logits,
  97. previous_tokens=(
  98. previous_tokens[codebook_idx + 1]
  99. if previous_tokens is not None
  100. else None
  101. ),
  102. **sampling_kwargs,
  103. )[0]
  104. x = model.fast_embeddings(a)
  105. codebooks.append(a)
  106. return torch.stack(codebooks, dim=0)
  107. def decode_one_token_naive(
  108. model: NaiveTransformer,
  109. x: torch.Tensor,
  110. input_pos: torch.Tensor,
  111. previous_tokens: torch.Tensor = None,
  112. **sampling_kwargs,
  113. ) -> torch.Tensor:
  114. x = model.forward_generate(x, input_pos)
  115. codebooks = [
  116. sample(
  117. x.token_logits,
  118. previous_tokens=None, # Disable repetition penalty for the token codebook
  119. **sampling_kwargs,
  120. )[0]
  121. ]
  122. for i in range(model.config.num_codebooks):
  123. codebooks.append(
  124. sample(
  125. x.codebook_logits[:, :, i],
  126. previous_tokens=(
  127. previous_tokens[i + 1] if previous_tokens is not None else None
  128. ),
  129. **sampling_kwargs,
  130. )[0]
  131. )
  132. return torch.stack(codebooks, dim=0)
  133. def decode_n_tokens(
  134. model: NaiveTransformer,
  135. cur_token: torch.Tensor,
  136. input_pos: torch.Tensor,
  137. num_new_tokens: int,
  138. eos_token_id: int = 2,
  139. im_end_id: int = 4,
  140. decode_one_token=decode_one_token_naive,
  141. **sampling_kwargs,
  142. ):
  143. previous_tokens = torch.zeros(
  144. (model.config.num_codebooks + 1, model.config.max_seq_len),
  145. dtype=torch.int,
  146. device=cur_token.device,
  147. )
  148. for i in tqdm(range(num_new_tokens)):
  149. # We need to get windowed repeat penalty
  150. win_size = 16
  151. if i < win_size:
  152. window = previous_tokens[:, :win_size]
  153. else:
  154. window = previous_tokens[:, i - win_size : i]
  155. with torch.backends.cuda.sdp_kernel(
  156. enable_flash=False, enable_mem_efficient=False, enable_math=True
  157. ): # Actually better for Inductor to codegen attention here
  158. next_token = decode_one_token(
  159. model=model,
  160. x=cur_token,
  161. input_pos=input_pos,
  162. previous_tokens=window,
  163. **sampling_kwargs,
  164. )
  165. input_pos += 1
  166. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  167. previous_tokens[:, i : i + 1] = next_token.view(
  168. model.config.num_codebooks + 1, -1
  169. )
  170. if (
  171. cur_token[0, 0, -1] == eos_token_id
  172. or cur_token[0, 0, -1] == im_end_id
  173. or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
  174. ):
  175. break
  176. return previous_tokens[:, : i + 1]
  177. @torch.no_grad()
  178. @torch.inference_mode()
  179. def generate(
  180. *,
  181. model: NaiveTransformer,
  182. prompt: torch.Tensor,
  183. max_new_tokens: int,
  184. eos_token_id: int = 2,
  185. im_end_id: int = 4,
  186. decode_one_token=decode_one_token_naive,
  187. **sampling_kwargs,
  188. ) -> torch.Tensor:
  189. """
  190. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  191. """
  192. # create an empty tensor of the expected final shape and fill in the current tokens
  193. T = prompt.size(1)
  194. if max_new_tokens:
  195. if T + max_new_tokens > model.config.max_seq_len:
  196. max_new_tokens = model.config.max_seq_len - T
  197. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  198. T_new = T + max_new_tokens
  199. else:
  200. T_new = model.config.max_seq_len
  201. max_new_tokens = T_new - T
  202. device, dtype = prompt.device, prompt.dtype
  203. with torch.device(device):
  204. model.setup_caches(
  205. max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
  206. )
  207. codebook_dim = 1 + model.config.num_codebooks
  208. # create an empty tensor of the expected final shape and fill in the current tokens
  209. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  210. empty[:, :T] = prompt
  211. seq = empty
  212. input_pos = torch.arange(0, T, device=device)
  213. # Use non-accelerated version for now, to avoid compilation overhead
  214. prefill_decode = (
  215. decode_one_token_naive
  216. if isinstance(model, NaiveTransformer)
  217. else decode_one_token_ar
  218. )
  219. next_token = prefill_decode(
  220. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  221. )
  222. seq[:, T : T + 1] = next_token
  223. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  224. x = decode_n_tokens(
  225. model,
  226. next_token.view(1, codebook_dim, -1),
  227. input_pos,
  228. max_new_tokens - 1,
  229. eos_token_id=eos_token_id,
  230. im_end_id=im_end_id,
  231. decode_one_token=decode_one_token,
  232. **sampling_kwargs,
  233. )
  234. # x = torch.cat(generated_tokens, dim=1)
  235. seq = seq[:, : T + 1 + x.size(1)]
  236. seq[:, T + 1 :] = x
  237. return seq
  238. def encode_tokens(
  239. tokenizer,
  240. string,
  241. bos=True,
  242. device="cuda",
  243. prompt_tokens=None,
  244. speaker=None,
  245. num_codebooks=4,
  246. ):
  247. string = clean_text(string)
  248. if speaker is not None:
  249. string = f"[SPK: {speaker}] {string}"
  250. string = (
  251. f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>assistant<|im_sep|>"
  252. )
  253. if bos:
  254. string = f"<|begin_of_sequence|>{string}"
  255. new_tokens = tokenizer.encode(
  256. string,
  257. add_special_tokens=False,
  258. max_length=10**6,
  259. truncation=False,
  260. )
  261. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  262. # Codebooks
  263. zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  264. prompt = torch.cat((tokens, zeros), dim=0)
  265. if prompt_tokens is None:
  266. return prompt
  267. # Get prompt tokens
  268. if prompt_tokens.ndim == 3:
  269. assert (
  270. prompt_tokens.shape[0] == 1
  271. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  272. prompt_tokens = prompt_tokens[0]
  273. assert prompt_tokens.ndim == 2
  274. data = prompt_tokens + 2
  275. if prompt_tokens.shape[0] > num_codebooks:
  276. logger.warning(
  277. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  278. )
  279. data = data[:num_codebooks]
  280. # Since 1.0, we use <|semantic|>
  281. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  282. main_token_ids = torch.tensor(
  283. [[s0_token_id] * data.size(1)],
  284. dtype=torch.int,
  285. device=device,
  286. )
  287. data = torch.cat((main_token_ids, data), dim=0)
  288. prompt = torch.cat((prompt, data), dim=1)
  289. return prompt
  290. def load_model(
  291. config_name, checkpoint_path, device, precision, max_length, compile=False
  292. ):
  293. with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
  294. cfg = compose(
  295. config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
  296. )
  297. model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)
  298. if "int8" in str(checkpoint_path):
  299. logger.info("Using int8 weight-only quantization!")
  300. from quantize import WeightOnlyInt8QuantHandler
  301. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  302. model = simple_quantizer.convert_for_runtime()
  303. if "int4" in str(checkpoint_path):
  304. logger.info("Using int4 quantization!")
  305. path_comps = checkpoint_path.name.split(".")
  306. assert path_comps[-2].startswith("g")
  307. groupsize = int(path_comps[-2][1:])
  308. from quantize import WeightOnlyInt4QuantHandler
  309. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  310. model = simple_quantizer.convert_for_runtime()
  311. checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
  312. if "state_dict" in checkpoint:
  313. checkpoint = checkpoint["state_dict"]
  314. if any(k.startswith("model.") for k in checkpoint):
  315. checkpoint = {
  316. k.replace("model.", ""): v
  317. for k, v in checkpoint.items()
  318. if k.startswith("model.")
  319. }
  320. model.load_state_dict(checkpoint, assign=True)
  321. model = model.to(device=device, dtype=precision)
  322. logger.info("Restored model from checkpoint")
  323. if isinstance(model, DualARTransformer):
  324. decode_one_token = decode_one_token_ar
  325. logger.info("Using DualARTransformer")
  326. else:
  327. decode_one_token = decode_one_token_naive
  328. logger.info("Using NaiveTransformer")
  329. if compile:
  330. logger.info("Compiling function...")
  331. decode_one_token = torch.compile(
  332. decode_one_token, mode="reduce-overhead", fullgraph=True
  333. )
  334. return model.eval(), decode_one_token
  335. def split_text(text, min_length):
  336. text = clean_text(text)
  337. segments = []
  338. curr = ""
  339. for char in text:
  340. curr += char
  341. if char not in [".", ",", "!", "?"]:
  342. continue
  343. if len(curr) >= min_length:
  344. segments.append(curr)
  345. curr = ""
  346. if curr:
  347. segments.append(curr)
  348. return segments
  349. def generate_long(
  350. *,
  351. model,
  352. tokenizer: callable,
  353. device: str | torch.device,
  354. decode_one_token: callable,
  355. text: str,
  356. num_samples: int = 1,
  357. max_new_tokens: int = 0,
  358. top_k: int = None,
  359. top_p: int = 0.7,
  360. repetition_penalty: float = 1.5,
  361. temperature: float = 0.7,
  362. compile: bool = False,
  363. iterative_prompt: bool = True,
  364. max_length: int = 2048,
  365. chunk_length: int = 30,
  366. speaker: Optional[str] = None,
  367. prompt_text: Optional[str] = None,
  368. prompt_tokens: Optional[torch.Tensor] = None,
  369. ):
  370. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  371. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  372. use_prompt = prompt_text is not None and prompt_tokens is not None
  373. encoded = []
  374. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  375. for idx, text in enumerate(texts):
  376. encoded.append(
  377. encode_tokens(
  378. tokenizer,
  379. string=text,
  380. bos=idx == 0 and not use_prompt,
  381. device=device,
  382. speaker=None,
  383. num_codebooks=model.config.num_codebooks,
  384. )
  385. )
  386. logger.info(f"Encoded text: {text}")
  387. if use_prompt:
  388. encoded_prompt = encode_tokens(
  389. tokenizer,
  390. prompt_text,
  391. prompt_tokens=prompt_tokens,
  392. bos=True,
  393. device=device,
  394. speaker=speaker,
  395. num_codebooks=model.config.num_codebooks,
  396. )
  397. encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
  398. for sample_idx in range(num_samples):
  399. torch.cuda.synchronize()
  400. global_encoded = []
  401. all_codes = []
  402. seg_idx = 0
  403. while seg_idx < len(encoded):
  404. logger.info(
  405. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  406. )
  407. seg = encoded[seg_idx]
  408. global_encoded.append(seg)
  409. lengths = reversed([seg.size(1) for seg in global_encoded])
  410. # Pick last 2000 tokens
  411. count = 0
  412. for i, length in enumerate(lengths):
  413. count += length
  414. if count + length > max_length - 1024:
  415. break
  416. if i != 0 and i % 2 == 0:
  417. i -= 1
  418. # Rotate the list, always make sure first segment is included to avoid drift
  419. if i < len(global_encoded) - 2:
  420. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  421. else:
  422. partial_encoded = global_encoded
  423. cat_encoded = torch.cat(partial_encoded, dim=1)
  424. prompt_length = cat_encoded.size(1)
  425. t0 = time.perf_counter()
  426. y = generate(
  427. model=model,
  428. prompt=cat_encoded,
  429. max_new_tokens=max_new_tokens,
  430. eos_token_id=tokenizer.eos_token_id,
  431. im_end_id=im_end_id,
  432. decode_one_token=decode_one_token,
  433. temperature=temperature,
  434. top_k=top_k,
  435. top_p=top_p,
  436. repetition_penalty=repetition_penalty,
  437. )
  438. if sample_idx == 0 and seg_idx == 0 and compile:
  439. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  440. torch.cuda.synchronize()
  441. t = time.perf_counter() - t0
  442. tokens_generated = y.size(1) - prompt_length
  443. tokens_sec = tokens_generated / t
  444. logger.info(
  445. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  446. )
  447. logger.info(
  448. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  449. )
  450. logger.info(
  451. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  452. )
  453. # Put the generated tokens
  454. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  455. codes = y[1:, prompt_length:-2].clone()
  456. codes = codes - 2
  457. if not (codes >= 0).all():
  458. global_encoded.pop()
  459. logger.warning(f"Negative code found: {codes}, retrying ...")
  460. continue
  461. decoded = y[:, prompt_length:-1].clone()
  462. if decoded[0, -1] != im_end_id: # <im_end>
  463. val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
  464. decoded = torch.cat(
  465. (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
  466. )
  467. # But for global encoding, we should keep the <im_end> token
  468. global_encoded.append(decoded)
  469. all_codes.append(codes)
  470. seg_idx += 1
  471. codes = torch.cat(all_codes, dim=1)
  472. assert (codes >= 0).all(), f"Negative code found: {codes}"
  473. yield codes
  474. @click.command()
  475. @click.option(
  476. "--text",
  477. type=str,
  478. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  479. )
  480. @click.option("--prompt-text", type=str, default=None)
  481. @click.option(
  482. "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
  483. )
  484. @click.option("--num-samples", type=int, default=1)
  485. @click.option("--max-new-tokens", type=int, default=0)
  486. @click.option("--top-k", type=int, default=None)
  487. @click.option("--top-p", type=float, default=0.7)
  488. @click.option("--repetition-penalty", type=float, default=1.5)
  489. @click.option("--temperature", type=float, default=0.7)
  490. @click.option(
  491. "--checkpoint-path",
  492. type=click.Path(path_type=Path, exists=True),
  493. default="results/text2semantic_400m_finetune/step_000002000.pth",
  494. )
  495. @click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
  496. @click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
  497. @click.option("--compile/--no-compile", default=False)
  498. @click.option("--seed", type=int, default=42)
  499. @click.option("--speaker", type=str, default=None)
  500. @click.option("--half/--no-half", default=False)
  501. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  502. @click.option("--max-length", type=int, default=2048)
  503. @click.option("--chunk-length", type=int, default=30)
  504. def main(
  505. text: str,
  506. prompt_text: Optional[str],
  507. prompt_tokens: Optional[Path],
  508. num_samples: int,
  509. max_new_tokens: int,
  510. top_k: int,
  511. top_p: int,
  512. repetition_penalty: float,
  513. temperature: float,
  514. checkpoint_path: Path,
  515. config_name: str,
  516. tokenizer: str,
  517. compile: bool,
  518. seed: int,
  519. speaker: Optional[str],
  520. half: bool,
  521. iterative_prompt: bool,
  522. max_length: int,
  523. chunk_length: int,
  524. ) -> None:
  525. device = "cuda"
  526. precision = torch.half if half else torch.bfloat16
  527. logger.info("Loading model ...")
  528. t0 = time.time()
  529. model, decode_one_token = load_model(
  530. config_name, checkpoint_path, device, precision, max_length, compile=compile
  531. )
  532. torch.cuda.synchronize()
  533. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  534. prompt_tokens = (
  535. torch.from_numpy(np.load(prompt_tokens)).to(device)
  536. if prompt_tokens is not None
  537. else None
  538. )
  539. tokenizer = AutoTokenizer.from_pretrained(tokenizer)
  540. torch.manual_seed(seed)
  541. torch.cuda.manual_seed(seed)
  542. generator = generate_long(
  543. model=model,
  544. device=device,
  545. decode_one_token=decode_one_token,
  546. text=text,
  547. num_samples=num_samples,
  548. max_new_tokens=max_new_tokens,
  549. top_k=top_k,
  550. top_p=top_p,
  551. repetition_penalty=repetition_penalty,
  552. temperature=temperature,
  553. tokenizer=tokenizer,
  554. compile=compile,
  555. speaker=speaker,
  556. iterative_prompt=iterative_prompt,
  557. max_length=max_length,
  558. chunk_length=chunk_length,
  559. prompt_text=prompt_text,
  560. prompt_tokens=prompt_tokens,
  561. )
  562. for idx, codes in enumerate(generator):
  563. np.save(f"codes_{idx}.npy", codes.cpu().numpy())
  564. logger.info(f"Saved codes to codes_{idx}.npy")
  565. if __name__ == "__main__":
  566. main()