generate.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import hydra
  11. import numpy as np
  12. import torch
  13. import torch._dynamo.config
  14. import torch._inductor.config
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  18. from fish_speech.text import clean_text, split_text
  19. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  20. torch._inductor.config.coordinate_descent_tuning = True
  21. torch._inductor.config.triton.unique_kernel_names = True
  22. if hasattr(torch._inductor.config, "fx_graph_cache"):
  23. # Experimental feature to reduce compilation times, will be on by default in future
  24. torch._inductor.config.fx_graph_cache = True
  25. from fish_speech.models.text2semantic.llama import (
  26. BaseTransformer,
  27. DualARTransformer,
  28. NaiveTransformer,
  29. )
  30. def multinomial_sample_one_no_sync(
  31. probs_sort,
  32. ): # Does multinomial sampling without a cuda synchronization
  33. q = torch.empty_like(probs_sort).exponential_(1)
  34. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  35. def logits_to_probs(
  36. logits,
  37. previous_tokens: Optional[torch.Tensor] = None,
  38. temperature: torch.Tensor = 1.0,
  39. top_p: torch.Tensor = 1.0,
  40. repetition_penalty: torch.Tensor = 1.0,
  41. ) -> torch.Tensor:
  42. # Apply repetition penalty
  43. if previous_tokens is not None:
  44. previous_tokens = previous_tokens.long()
  45. score = torch.gather(logits, dim=0, index=previous_tokens)
  46. score = torch.where(
  47. score < 0, score * repetition_penalty, score / repetition_penalty
  48. )
  49. logits.scatter_(dim=0, index=previous_tokens, src=score)
  50. # Apply top-p sampling
  51. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  52. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  53. sorted_indices_to_remove = cum_probs > top_p
  54. sorted_indices_to_remove[0] = False # keep at least one option
  55. indices_to_remove = sorted_indices_to_remove.scatter(
  56. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  57. )
  58. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  59. logits = logits / max(temperature, 1e-5)
  60. probs = torch.nn.functional.softmax(logits, dim=-1)
  61. return probs
  62. def sample(
  63. logits,
  64. previous_tokens: Optional[torch.Tensor] = None,
  65. **sampling_kwargs,
  66. ) -> Tuple[torch.Tensor, torch.Tensor]:
  67. probs = logits_to_probs(
  68. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  69. )
  70. idx_next = multinomial_sample_one_no_sync(probs)
  71. return idx_next, probs
  72. def decode_one_token_ar(
  73. model: DualARTransformer,
  74. x: torch.Tensor,
  75. input_pos: torch.Tensor,
  76. previous_tokens: torch.Tensor = None,
  77. **sampling_kwargs,
  78. ) -> torch.Tensor:
  79. x = model.forward_generate(x, input_pos)
  80. sampling_kwargs_main = sampling_kwargs.copy()
  81. sampling_kwargs_main["temperature"] = 0.1
  82. sampling_kwargs_main["top_p"] = 0.1
  83. sampling_kwargs_main["repetition_penalty"] = 1.0
  84. codebooks = [
  85. sample(
  86. logits,
  87. previous_tokens=None, # Disable repetition penalty for the token codebook
  88. **sampling_kwargs_main,
  89. )[0]
  90. ]
  91. x = x.hidden_states
  92. # Cleanup the cache
  93. for layer in model.fast_layers:
  94. layer.attention.kv_cache.k_cache.fill_(0)
  95. layer.attention.kv_cache.v_cache.fill_(0)
  96. for codebook_idx in range(model.config.num_codebooks):
  97. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  98. logits = model.forward_generate_fast(x, input_pos)
  99. a = sample(
  100. logits,
  101. previous_tokens=(
  102. previous_tokens[codebook_idx + 1]
  103. if previous_tokens is not None
  104. else None
  105. ),
  106. **sampling_kwargs,
  107. )[0]
  108. x = model.fast_embeddings(a)
  109. codebooks.append(a)
  110. return torch.stack(codebooks, dim=0)
  111. def decode_one_token_naive(
  112. model: NaiveTransformer,
  113. x: torch.Tensor,
  114. input_pos: torch.Tensor,
  115. previous_tokens: torch.Tensor = None,
  116. **sampling_kwargs,
  117. ) -> torch.Tensor:
  118. x = model.forward_generate(x, input_pos)
  119. logits = x.token_logits
  120. sampling_kwargs_main = sampling_kwargs.copy()
  121. sampling_kwargs_main["temperature"] = 0.1
  122. sampling_kwargs_main["top_p"] = 0.1
  123. sampling_kwargs_main["repetition_penalty"] = 1.0
  124. codebooks = [
  125. sample(
  126. logits,
  127. previous_tokens=None, # Disable repetition penalty for the token codebook
  128. **sampling_kwargs_main,
  129. )[0]
  130. ]
  131. for i in range(model.config.num_codebooks):
  132. codebooks.append(
  133. sample(
  134. x.codebook_logits[:, :, i],
  135. previous_tokens=(
  136. previous_tokens[i + 1] if previous_tokens is not None else None
  137. ),
  138. **sampling_kwargs,
  139. )[0]
  140. )
  141. return torch.stack(codebooks, dim=0)
  142. def decode_n_tokens(
  143. model: NaiveTransformer,
  144. cur_token: torch.Tensor,
  145. input_pos: torch.Tensor,
  146. num_new_tokens: int,
  147. im_end_id: int = 4,
  148. decode_one_token=decode_one_token_naive,
  149. **sampling_kwargs,
  150. ):
  151. previous_tokens = torch.zeros(
  152. (model.config.num_codebooks + 1, model.config.max_seq_len),
  153. dtype=torch.int,
  154. device=cur_token.device,
  155. )
  156. for i in tqdm(range(num_new_tokens)):
  157. # We need to get windowed repeat penalty
  158. win_size = 16
  159. if i < win_size:
  160. window = previous_tokens[:, :win_size]
  161. else:
  162. window = previous_tokens[:, i - win_size : i]
  163. with (
  164. torch.backends.cuda.sdp_kernel(
  165. enable_flash=False, enable_mem_efficient=False, enable_math=True
  166. )
  167. if torch.cuda.is_available()
  168. else nullcontext()
  169. ): # Actually better for Inductor to codegen attention here
  170. next_token = decode_one_token(
  171. model=model,
  172. x=cur_token,
  173. input_pos=input_pos,
  174. previous_tokens=window,
  175. **sampling_kwargs,
  176. )
  177. input_pos += 1
  178. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  179. previous_tokens[:, i : i + 1] = next_token.view(
  180. model.config.num_codebooks + 1, -1
  181. )
  182. if cur_token[0, 0, -1] == im_end_id:
  183. break
  184. return previous_tokens[:, : i + 1]
  185. @torch.no_grad()
  186. @torch.inference_mode()
  187. def generate(
  188. *,
  189. model: NaiveTransformer,
  190. prompt: torch.Tensor,
  191. max_new_tokens: int,
  192. im_end_id: int = 4,
  193. decode_one_token=decode_one_token_naive,
  194. **sampling_kwargs,
  195. ) -> torch.Tensor:
  196. """
  197. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  198. """
  199. # create an empty tensor of the expected final shape and fill in the current tokens
  200. T = prompt.size(1)
  201. if max_new_tokens:
  202. if T + max_new_tokens > model.config.max_seq_len:
  203. max_new_tokens = model.config.max_seq_len - T
  204. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  205. T_new = T + max_new_tokens
  206. else:
  207. T_new = model.config.max_seq_len
  208. max_new_tokens = T_new - T
  209. device, dtype = prompt.device, prompt.dtype
  210. with torch.device(device):
  211. model.setup_caches(
  212. max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
  213. )
  214. codebook_dim = 1 + model.config.num_codebooks
  215. # create an empty tensor of the expected final shape and fill in the current tokens
  216. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  217. empty[:, :T] = prompt
  218. seq = empty
  219. input_pos = torch.arange(0, T, device=device)
  220. # Use non-accelerated version for now, to avoid compilation overhead
  221. prefill_decode = (
  222. decode_one_token_naive
  223. if isinstance(model, NaiveTransformer)
  224. else decode_one_token_ar
  225. )
  226. next_token = prefill_decode(
  227. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  228. )
  229. seq[:, T : T + 1] = next_token
  230. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  231. x = decode_n_tokens(
  232. model,
  233. next_token.view(1, codebook_dim, -1),
  234. input_pos,
  235. max_new_tokens - 1,
  236. im_end_id=im_end_id,
  237. decode_one_token=decode_one_token,
  238. **sampling_kwargs,
  239. )
  240. # x = torch.cat(generated_tokens, dim=1)
  241. seq = seq[:, : T + 1 + x.size(1)]
  242. seq[:, T + 1 :] = x
  243. return seq
  244. def encode_tokens(
  245. tokenizer,
  246. string,
  247. device="cuda",
  248. prompt_tokens=None,
  249. num_codebooks=4,
  250. ):
  251. string = clean_text(string)
  252. string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
  253. new_tokens = tokenizer.encode(
  254. string,
  255. add_special_tokens=False,
  256. max_length=10**6,
  257. truncation=False,
  258. )
  259. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  260. # Codebooks
  261. zeros = (
  262. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  263. * CODEBOOK_PAD_TOKEN_ID
  264. )
  265. prompt = torch.cat((tokens, zeros), dim=0)
  266. if prompt_tokens is None:
  267. return prompt
  268. # Get prompt tokens
  269. if prompt_tokens.ndim == 3:
  270. assert (
  271. prompt_tokens.shape[0] == 1
  272. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  273. prompt_tokens = prompt_tokens[0]
  274. assert prompt_tokens.ndim == 2
  275. data = prompt_tokens + 1
  276. if prompt_tokens.shape[0] > num_codebooks:
  277. logger.warning(
  278. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  279. )
  280. data = data[:num_codebooks]
  281. # Add pad token for each codebook
  282. data = torch.cat(
  283. (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
  284. dim=1,
  285. )
  286. # Since 1.0, we use <|semantic|>
  287. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  288. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  289. main_token_ids = (
  290. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  291. )
  292. main_token_ids[0, -1] = end_token_id
  293. data = torch.cat((main_token_ids, data), dim=0)
  294. prompt = torch.cat((prompt, data), dim=1)
  295. return prompt
  296. def load_model(checkpoint_path, device, precision, compile=False):
  297. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  298. checkpoint_path, load_weights=True
  299. )
  300. model = model.to(device=device, dtype=precision)
  301. logger.info(f"Restored model from checkpoint")
  302. if isinstance(model, DualARTransformer):
  303. decode_one_token = decode_one_token_ar
  304. logger.info("Using DualARTransformer")
  305. else:
  306. decode_one_token = decode_one_token_naive
  307. logger.info("Using NaiveTransformer")
  308. if compile:
  309. logger.info("Compiling function...")
  310. decode_one_token = torch.compile(
  311. decode_one_token,
  312. fullgraph=True,
  313. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  314. mode="reduce-overhead" if torch.cuda.is_available() else None,
  315. )
  316. return model.eval(), decode_one_token
  317. @dataclass
  318. class GenerateResponse:
  319. action: Literal["sample", "next"]
  320. codes: Optional[torch.Tensor] = None
  321. text: Optional[str] = None
  322. def generate_long(
  323. *,
  324. model,
  325. device: str | torch.device,
  326. decode_one_token: callable,
  327. text: str,
  328. num_samples: int = 1,
  329. max_new_tokens: int = 0,
  330. top_p: int = 0.7,
  331. repetition_penalty: float = 1.5,
  332. temperature: float = 0.7,
  333. compile: bool = False,
  334. iterative_prompt: bool = True,
  335. max_length: int = 2048,
  336. chunk_length: int = 150,
  337. prompt_text: Optional[str | list[str]] = None,
  338. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  339. ):
  340. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  341. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  342. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  343. use_prompt = prompt_text is not None and prompt_tokens is not None
  344. if use_prompt and isinstance(prompt_text, str):
  345. prompt_text = [prompt_text]
  346. prompt_tokens = [prompt_tokens]
  347. assert use_prompt is False or len(prompt_text) == len(
  348. prompt_tokens
  349. ), "Prompt text and tokens must have the same length"
  350. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  351. tokenizer = model.tokenizer
  352. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  353. encoded = []
  354. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  355. encoded_prompts = []
  356. if use_prompt:
  357. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  358. encoded_prompts.append(
  359. encode_tokens(
  360. tokenizer,
  361. string=t,
  362. device=device,
  363. prompt_tokens=c,
  364. num_codebooks=model.config.num_codebooks,
  365. )
  366. )
  367. for idx, text in enumerate(texts):
  368. encoded.append(
  369. encode_tokens(
  370. tokenizer,
  371. string=text,
  372. device=device,
  373. num_codebooks=model.config.num_codebooks,
  374. )
  375. )
  376. logger.info(f"Encoded text: {text}")
  377. # Move temperature, top_p, repetition_penalty to device
  378. # This is important so that changing params doesn't trigger recompile
  379. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  380. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  381. repetition_penalty = torch.tensor(
  382. repetition_penalty, device=device, dtype=torch.float
  383. )
  384. for sample_idx in range(num_samples):
  385. if torch.cuda.is_available():
  386. torch.cuda.synchronize()
  387. global_encoded = []
  388. seg_idx = 0
  389. while seg_idx < len(encoded):
  390. logger.info(
  391. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  392. )
  393. seg = encoded[seg_idx]
  394. global_encoded.append(seg)
  395. lengths = reversed([seg.size(1) for seg in global_encoded])
  396. # Pick last 2000 tokens
  397. count = 0
  398. for i, length in enumerate(lengths):
  399. count += length
  400. if count + length > max_length - 1024 - sum(
  401. t.shape[1] for t in encoded_prompts
  402. ):
  403. break
  404. if i != 0 and i % 2 == 0:
  405. i -= 1
  406. # Rotate the list, always make sure first segment is included to avoid drift
  407. if i < len(global_encoded) - 2:
  408. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  409. else:
  410. partial_encoded = global_encoded
  411. if use_prompt:
  412. partial_encoded = encoded_prompts + partial_encoded
  413. cat_encoded = torch.cat(partial_encoded, dim=1)
  414. prompt_length = cat_encoded.size(1)
  415. t0 = time.perf_counter()
  416. y = generate(
  417. model=model,
  418. prompt=cat_encoded,
  419. max_new_tokens=max_new_tokens,
  420. im_end_id=im_end_id,
  421. decode_one_token=decode_one_token,
  422. temperature=temperature,
  423. top_p=top_p,
  424. repetition_penalty=repetition_penalty,
  425. )
  426. if sample_idx == 0 and seg_idx == 0 and compile:
  427. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  428. if torch.cuda.is_available():
  429. torch.cuda.synchronize()
  430. t = time.perf_counter() - t0
  431. tokens_generated = y.size(1) - prompt_length
  432. tokens_sec = tokens_generated / t
  433. logger.info(
  434. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  435. )
  436. logger.info(
  437. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  438. )
  439. if torch.cuda.is_available():
  440. logger.info(
  441. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  442. )
  443. # Put the generated tokens
  444. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  445. codes = y[1:, prompt_length:-1].clone()
  446. codes = codes - 1
  447. assert (codes >= 0).all(), f"Negative code found"
  448. decoded = y[:, prompt_length:-1].clone()
  449. # But for global encoding, we should keep the <im_end> token
  450. global_encoded.append(decoded)
  451. assert (codes >= 0).all(), f"Negative code found: {codes}"
  452. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  453. seg_idx += 1
  454. # This indicates the end of the current sample
  455. yield GenerateResponse(action="next")
  456. @dataclass
  457. class WrappedGenerateResponse:
  458. status: Literal["success", "error"]
  459. response: Optional[GenerateResponse | Exception] = None
  460. @dataclass
  461. class GenerateRequest:
  462. request: dict
  463. response_queue: queue.Queue
  464. def launch_thread_safe_queue(
  465. checkpoint_path,
  466. device,
  467. precision,
  468. compile: bool = False,
  469. ):
  470. input_queue = queue.Queue()
  471. init_event = threading.Event()
  472. def worker():
  473. model, decode_one_token = load_model(
  474. checkpoint_path, device, precision, compile=compile
  475. )
  476. init_event.set()
  477. while True:
  478. item: GenerateRequest | None = input_queue.get()
  479. if item is None:
  480. break
  481. kwargs = item.request
  482. response_queue = item.response_queue
  483. try:
  484. for chunk in generate_long(
  485. model=model, decode_one_token=decode_one_token, **kwargs
  486. ):
  487. response_queue.put(
  488. WrappedGenerateResponse(status="success", response=chunk)
  489. )
  490. except Exception as e:
  491. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  492. threading.Thread(target=worker, daemon=True).start()
  493. init_event.wait()
  494. return input_queue
  495. @click.command()
  496. @click.option(
  497. "--text",
  498. type=str,
  499. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  500. )
  501. @click.option("--prompt-text", type=str, default=None, multiple=True)
  502. @click.option(
  503. "--prompt-tokens",
  504. type=click.Path(path_type=Path, exists=True),
  505. default=None,
  506. multiple=True,
  507. )
  508. @click.option("--num-samples", type=int, default=1)
  509. @click.option("--max-new-tokens", type=int, default=0)
  510. @click.option("--top-p", type=float, default=0.7)
  511. @click.option("--repetition-penalty", type=float, default=1.2)
  512. @click.option("--temperature", type=float, default=0.7)
  513. @click.option(
  514. "--checkpoint-path",
  515. type=click.Path(path_type=Path, exists=True),
  516. default="checkpoints/fish-speech-1.4",
  517. )
  518. @click.option("--device", type=str, default="cuda")
  519. @click.option("--compile/--no-compile", default=False)
  520. @click.option("--seed", type=int, default=42)
  521. @click.option("--half/--no-half", default=False)
  522. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  523. @click.option("--chunk-length", type=int, default=100)
  524. def main(
  525. text: str,
  526. prompt_text: Optional[list[str]],
  527. prompt_tokens: Optional[list[Path]],
  528. num_samples: int,
  529. max_new_tokens: int,
  530. top_p: int,
  531. repetition_penalty: float,
  532. temperature: float,
  533. checkpoint_path: Path,
  534. device: str,
  535. compile: bool,
  536. seed: int,
  537. half: bool,
  538. iterative_prompt: bool,
  539. chunk_length: int,
  540. ) -> None:
  541. precision = torch.half if half else torch.bfloat16
  542. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  543. raise ValueError(
  544. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  545. )
  546. logger.info("Loading model ...")
  547. t0 = time.time()
  548. model, decode_one_token = load_model(
  549. checkpoint_path, device, precision, compile=compile
  550. )
  551. if torch.cuda.is_available():
  552. torch.cuda.synchronize()
  553. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  554. if prompt_tokens is not None:
  555. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  556. torch.manual_seed(seed)
  557. if torch.cuda.is_available():
  558. torch.cuda.manual_seed(seed)
  559. generator = generate_long(
  560. model=model,
  561. device=device,
  562. decode_one_token=decode_one_token,
  563. text=text,
  564. num_samples=num_samples,
  565. max_new_tokens=max_new_tokens,
  566. top_p=top_p,
  567. repetition_penalty=repetition_penalty,
  568. temperature=temperature,
  569. compile=compile,
  570. iterative_prompt=iterative_prompt,
  571. chunk_length=chunk_length,
  572. prompt_text=prompt_text,
  573. prompt_tokens=prompt_tokens,
  574. )
  575. idx = 0
  576. codes = []
  577. for response in generator:
  578. if response.action == "sample":
  579. codes.append(response.codes)
  580. logger.info(f"Sampled text: {response.text}")
  581. elif response.action == "next":
  582. if codes:
  583. np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
  584. logger.info(f"Saved codes to codes_{idx}.npy")
  585. logger.info(f"Next sample")
  586. codes = []
  587. idx += 1
  588. else:
  589. logger.error(f"Error: {response}")
  590. if __name__ == "__main__":
  591. main()