generate.py

import os
import queue
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

import click
import hydra
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from loguru import logger
from tqdm import tqdm

from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.text import clean_text, split_text

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True

from fish_speech.models.text2semantic.llama import (
    BaseTransformer,
    DualARTransformer,
    NaiveTransformer,
)

def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    # Apply repetition penalty
    if previous_tokens is not None:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=0, index=previous_tokens, src=score)

    # Apply top-p sampling
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(
        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
    )
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[0] = False  # keep at least one option
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=0, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
    logits = logits / max(temperature, 1e-5)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
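
# Illustrative sketch (not part of the original script): how the two helpers above
# compose on dummy data. The vocabulary size, window length, and sampling values
# below are assumptions for demonstration only; real call sites pass model logits
# and a window of previously generated tokens.
def _example_top_p_sampling():
    vocab_size = 32  # hypothetical vocabulary size
    logits = torch.randn(vocab_size)
    window = torch.randint(0, vocab_size, (8,))  # hypothetical window of prior tokens
    probs = logits_to_probs(
        logits.clone(),  # clone: the repetition penalty writes into `logits` in place
        previous_tokens=window,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
    )
    # The result is a proper distribution over the vocabulary.
    assert torch.isclose(probs.sum(), torch.tensor(1.0), atol=1e-5)
    return multinomial_sample_one_no_sync(probs)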

def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)
    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]
    x = x.hidden_states

    # Cleanup the cache
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
        logits = model.forward_generate_fast(x, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        x = model.fast_embeddings(a)
        codebooks.append(a)

    return torch.stack(codebooks, dim=0)
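
# Note on decode_one_token_ar (added commentary, based on a reading of the code
# above): the slow transformer predicts the next main token, then the fast layers
# autoregressively predict one value per codebook, conditioned on the slow hidden
# state. The stacked result has shape (num_codebooks + 1, 1): row 0 is the main
# token, rows 1..num_codebooks are the codebook values for this step.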

def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)

    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)

def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        if cur_token[0, 0, -1] == im_end_id:
            break

    return previous_tokens[:, : i + 1]
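
# Note on decode_n_tokens (added commentary): the repetition penalty is applied over
# a sliding window of the last `win_size` (16) generated tokens rather than the full
# history, and decoding stops early as soon as the main codebook emits `im_end_id`.
# The returned tensor therefore has shape (num_codebooks + 1, i + 1), where i + 1 is
# the number of steps actually decoded.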

@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
        )

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )

    next_token = prefill_decode(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
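
# Note on generate (added commentary): the prompt is prefilled with the uncompiled
# decoder so a variable-length prefix does not trigger compilation overhead; after
# the first generated step, decode_n_tokens takes over with the (possibly
# torch.compile'd) single-step decoder for the remaining max_new_tokens - 1 steps.
# The returned `seq` contains the prompt followed by everything that was generated.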

def encode_tokens(
    tokenizer,
    string,
    device="cuda",
    prompt_tokens=None,
    num_codebooks=4,
):
    string = clean_text(string)
    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"

    new_tokens = tokenizer.encode(
        string,
        add_special_tokens=False,
        max_length=10**6,
        truncation=False,
    )
    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

    # Codebooks
    zeros = (
        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
        * CODEBOOK_PAD_TOKEN_ID
    )
    prompt = torch.cat((tokens, zeros), dim=0)

    if prompt_tokens is None:
        return prompt

    # Get prompt tokens
    if prompt_tokens.ndim == 3:
        assert (
            prompt_tokens.shape[0] == 1
        ), "3-dim prompt tokens should have shape (1, num_codebooks, seq_len)"
        prompt_tokens = prompt_tokens[0]

    assert prompt_tokens.ndim == 2
    data = prompt_tokens + 1

    if prompt_tokens.shape[0] > num_codebooks:
        logger.warning(
            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
        )
        data = data[:num_codebooks]

    # Add pad token for each codebook
    data = torch.cat(
        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
        dim=1,
    )

    # Since 1.0, we use <|semantic|>
    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    main_token_ids = (
        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
    )
    main_token_ids[0, -1] = end_token_id

    data = torch.cat((main_token_ids, data), dim=0)
    prompt = torch.cat((prompt, data), dim=1)

    return prompt
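
# Illustrative sketch (not part of the original script): the layout produced by
# encode_tokens. Row 0 carries text-tokenizer ids (with <|semantic|> placeholders
# over the voice-prompt region), rows 1..num_codebooks carry codebook ids (shifted
# by +1, CODEBOOK_PAD_TOKEN_ID over the text region). The input text, device, and
# codebook count below are assumptions for demonstration.
def _example_encode_tokens(tokenizer):
    # `tokenizer` is assumed to be the model's Hugging Face tokenizer.
    prompt = encode_tokens(tokenizer, "Hello world", device="cpu", num_codebooks=4)
    assert prompt.ndim == 2 and prompt.size(0) == 5  # 1 text row + 4 codebook rows
    return prompt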

def load_model(checkpoint_path, device, precision, compile=False):
    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
        checkpoint_path, load_weights=True
    )

    if "int8" in str(checkpoint_path):
        logger.info("Using int8 weight-only quantization!")
        from .quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        logger.info("Using int4 quantization!")
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from .quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
        decode_one_token = decode_one_token_ar
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = decode_one_token_naive
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )

    return model.eval(), decode_one_token
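
# Minimal usage sketch (assumptions: a local "checkpoints/fish-speech-1.2" directory,
# matching the CLI default below, and a CUDA device; both are illustrative):
#
#     model, decode_one_token = load_model(
#         Path("checkpoints/fish-speech-1.2"), "cuda", torch.bfloat16, compile=False
#     )
#
# `decode_one_token` is already bound to the right variant (naive vs. dual-AR) and,
# when compile=True, wrapped by torch.compile.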

@dataclass
class GenerateResponse:
    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None
    text: Optional[str] = None

def generate_long(
    *,
    model,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 150,
    prompt_text: Optional[str | list[str]] = None,
    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
):
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    use_prompt = prompt_text is not None and prompt_tokens is not None
    if use_prompt and isinstance(prompt_text, str):
        prompt_text = [prompt_text]
        prompt_tokens = [prompt_tokens]

    assert use_prompt is False or len(prompt_text) == len(
        prompt_tokens
    ), "Prompt text and tokens must have the same length"

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tokenizer = model.tokenizer
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]
    encoded_prompts = []

    if use_prompt:
        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
            encoded_prompts.append(
                encode_tokens(
                    tokenizer,
                    string=t,
                    device=device,
                    prompt_tokens=c,
                    num_codebooks=model.config.num_codebooks,
                )
            )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                device=device,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    # Move temperature, top_p, repetition_penalty to device
    # This is important so that changing params doesn't trigger recompile
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        global_encoded = []
        seg_idx = 0

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Pick last 2000 tokens
            count = 0
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024 - sum(
                    t.shape[1] for t in encoded_prompts
                ):
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list, always make sure first segment is included to avoid drift
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            if use_prompt:
                partial_encoded = encoded_prompts + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                im_end_id=im_end_id,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Put the generated tokens
            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
            codes = y[1:, prompt_length:-1].clone()
            codes = codes - 1
            assert (codes >= 0).all(), "Negative code found"

            decoded = y[:, prompt_length:-1].clone()
            # But for global encoding, we should keep the <im_end> token

            global_encoded.append(decoded)
            assert (codes >= 0).all(), f"Negative code found: {codes}"
            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
            seg_idx += 1

        # This indicates the end of the current sample
        yield GenerateResponse(action="next")
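
# Note on generate_long (added commentary): long inputs are split into chunks with
# split_text; each chunk is generated against a rolling context that always keeps
# the first two segments plus the most recent ones (bounded by max_length), so the
# voice stays anchored to the reference prompt while the context stays bounded.
# The function is a generator: it yields one "sample" response per chunk and a
# "next" response at the end of each sample.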

@dataclass
class WrappedGenerateResponse:
    status: Literal["success", "error"]
    response: Optional[GenerateResponse | Exception] = None


@dataclass
class GenerateRequest:
    request: dict
    response_queue: queue.Queue

def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    input_queue = queue.Queue()
    init_event = threading.Event()

    def worker():
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile
        )
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
            except Exception as e:
                response_queue.put(WrappedGenerateResponse(status="error", response=e))

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue
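
# Minimal usage sketch for the worker queue (a hypothetical caller, not part of the
# original script): submit a GenerateRequest whose kwargs match generate_long's
# signature, then drain the per-request response queue until the "next" action
# arrives. Error handling here is illustrative only.
def _example_queue_client(input_queue: queue.Queue, request_kwargs: dict):
    response_queue: queue.Queue = queue.Queue()
    input_queue.put(
        GenerateRequest(request=request_kwargs, response_queue=response_queue)
    )

    while True:
        wrapped: WrappedGenerateResponse = response_queue.get()
        if wrapped.status == "error":
            raise wrapped.response
        if wrapped.response.action == "next":
            break
        yield wrapped.response.codes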

@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.2)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.2",
)
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=100)
def main(
    text: str,
    prompt_text: Optional[list[str]],
    prompt_tokens: Optional[list[Path]],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
) -> None:
    device = "cuda"
    precision = torch.half if half else torch.bfloat16

    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt texts ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        checkpoint_path, device, precision, compile=compile
    )

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    if prompt_tokens is not None:
        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]

    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    idx = 0
    codes = []

    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
                logger.info(f"Saved codes to codes_{idx}.npy")
            logger.info("Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")


if __name__ == "__main__":
    main()