generate.py 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import hydra
  11. import numpy as np
  12. import torch
  13. import torch._dynamo.config
  14. import torch._inductor.config
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  18. from fish_speech.text import clean_text, split_text
# Disable HuggingFace tokenizers' internal parallelism for this process.
# NOTE(review): presumably set to avoid the tokenizers fork/parallelism warning
# when the model is driven from the worker thread below — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Inductor options; they only matter when `load_model(..., compile=True)` wraps
# the decoder in `torch.compile`.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
# Only set when this torch build exposes the option.
if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
  25. from fish_speech.models.text2semantic.llama import (
  26. BaseTransformer,
  27. DualARTransformer,
  28. NaiveTransformer,
  29. )
  30. def multinomial_sample_one_no_sync(
  31. probs_sort,
  32. ): # Does multinomial sampling without a cuda synchronization
  33. q = torch.empty_like(probs_sort).exponential_(1)
  34. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  35. def logits_to_probs(
  36. logits,
  37. previous_tokens: Optional[torch.Tensor] = None,
  38. temperature: torch.Tensor = 1.0,
  39. top_p: torch.Tensor = 1.0,
  40. repetition_penalty: torch.Tensor = 1.0,
  41. ) -> torch.Tensor:
  42. # Apply repetition penalty
  43. if previous_tokens is not None:
  44. previous_tokens = previous_tokens.long()
  45. score = torch.gather(logits, dim=0, index=previous_tokens)
  46. score = torch.where(
  47. score < 0, score * repetition_penalty, score / repetition_penalty
  48. )
  49. logits.scatter_(dim=0, index=previous_tokens, src=score)
  50. # Apply top-p sampling
  51. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  52. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  53. sorted_indices_to_remove = cum_probs > top_p
  54. sorted_indices_to_remove[0] = False # keep at least one option
  55. indices_to_remove = sorted_indices_to_remove.scatter(
  56. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  57. )
  58. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  59. logits = logits / max(temperature, 1e-5)
  60. probs = torch.nn.functional.softmax(logits, dim=-1)
  61. return probs
  62. def sample(
  63. logits,
  64. previous_tokens: Optional[torch.Tensor] = None,
  65. **sampling_kwargs,
  66. ) -> Tuple[torch.Tensor, torch.Tensor]:
  67. probs = logits_to_probs(
  68. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  69. )
  70. idx_next = multinomial_sample_one_no_sync(probs)
  71. return idx_next, probs
  72. def decode_one_token_ar(
  73. model: DualARTransformer,
  74. x: torch.Tensor,
  75. input_pos: torch.Tensor,
  76. previous_tokens: torch.Tensor = None,
  77. **sampling_kwargs,
  78. ) -> torch.Tensor:
  79. x = model.forward_generate(x, input_pos)
  80. logits = x.token_logits
  81. sampling_kwargs_main = sampling_kwargs.copy()
  82. sampling_kwargs_main["temperature"] = 0.1
  83. sampling_kwargs_main["top_p"] = 0.1
  84. sampling_kwargs_main["repetition_penalty"] = 1.0
  85. codebooks = [
  86. sample(
  87. logits,
  88. previous_tokens=None, # Disable repetition penalty for the token codebook
  89. **sampling_kwargs_main,
  90. )[0]
  91. ]
  92. x = x.hidden_states
  93. # Cleanup the cache
  94. for layer in model.fast_layers:
  95. layer.attention.kv_cache.k_cache.fill_(0)
  96. layer.attention.kv_cache.v_cache.fill_(0)
  97. for codebook_idx in range(model.config.num_codebooks):
  98. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  99. logits = model.forward_generate_fast(x, input_pos)
  100. a = sample(
  101. logits,
  102. previous_tokens=(
  103. previous_tokens[codebook_idx + 1]
  104. if previous_tokens is not None
  105. else None
  106. ),
  107. **sampling_kwargs,
  108. )[0]
  109. x = model.fast_embeddings(a)
  110. codebooks.append(a)
  111. return torch.stack(codebooks, dim=0)
  112. def decode_one_token_naive(
  113. model: NaiveTransformer,
  114. x: torch.Tensor,
  115. input_pos: torch.Tensor,
  116. previous_tokens: torch.Tensor = None,
  117. **sampling_kwargs,
  118. ) -> torch.Tensor:
  119. x = model.forward_generate(x, input_pos)
  120. logits = x.token_logits
  121. sampling_kwargs_main = sampling_kwargs.copy()
  122. sampling_kwargs_main["temperature"] = 0.1
  123. sampling_kwargs_main["top_p"] = 0.1
  124. sampling_kwargs_main["repetition_penalty"] = 1.0
  125. codebooks = [
  126. sample(
  127. logits,
  128. previous_tokens=None, # Disable repetition penalty for the token codebook
  129. **sampling_kwargs_main,
  130. )[0]
  131. ]
  132. for i in range(model.config.num_codebooks):
  133. codebooks.append(
  134. sample(
  135. x.codebook_logits[:, :, i],
  136. previous_tokens=(
  137. previous_tokens[i + 1] if previous_tokens is not None else None
  138. ),
  139. **sampling_kwargs,
  140. )[0]
  141. )
  142. return torch.stack(codebooks, dim=0)
  143. def decode_n_tokens(
  144. model: NaiveTransformer,
  145. cur_token: torch.Tensor,
  146. input_pos: torch.Tensor,
  147. num_new_tokens: int,
  148. im_end_id: int = 4,
  149. decode_one_token=decode_one_token_naive,
  150. **sampling_kwargs,
  151. ):
  152. previous_tokens = torch.zeros(
  153. (model.config.num_codebooks + 1, model.config.max_seq_len),
  154. dtype=torch.int,
  155. device=cur_token.device,
  156. )
  157. for i in tqdm(range(num_new_tokens)):
  158. # We need to get windowed repeat penalty
  159. win_size = 16
  160. if i < win_size:
  161. window = previous_tokens[:, :win_size]
  162. else:
  163. window = previous_tokens[:, i - win_size : i]
  164. with (
  165. torch.backends.cuda.sdp_kernel(
  166. enable_flash=False, enable_mem_efficient=False, enable_math=True
  167. )
  168. if torch.cuda.is_available()
  169. else nullcontext()
  170. ): # Actually better for Inductor to codegen attention here
  171. next_token = decode_one_token(
  172. model=model,
  173. x=cur_token,
  174. input_pos=input_pos,
  175. previous_tokens=window,
  176. **sampling_kwargs,
  177. )
  178. input_pos += 1
  179. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  180. previous_tokens[:, i : i + 1] = next_token.view(
  181. model.config.num_codebooks + 1, -1
  182. )
  183. if cur_token[0, 0, -1] == im_end_id:
  184. break
  185. return previous_tokens[:, : i + 1]
  186. @torch.no_grad()
  187. @torch.inference_mode()
  188. def generate(
  189. *,
  190. model: NaiveTransformer,
  191. prompt: torch.Tensor,
  192. max_new_tokens: int,
  193. im_end_id: int = 4,
  194. decode_one_token=decode_one_token_naive,
  195. **sampling_kwargs,
  196. ) -> torch.Tensor:
  197. """
  198. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  199. """
  200. # create an empty tensor of the expected final shape and fill in the current tokens
  201. T = prompt.size(1)
  202. if max_new_tokens:
  203. if T + max_new_tokens > model.config.max_seq_len:
  204. max_new_tokens = model.config.max_seq_len - T
  205. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  206. T_new = T + max_new_tokens
  207. else:
  208. T_new = model.config.max_seq_len
  209. max_new_tokens = T_new - T
  210. device, dtype = prompt.device, prompt.dtype
  211. with torch.device(device):
  212. model.setup_caches(
  213. max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
  214. )
  215. codebook_dim = 1 + model.config.num_codebooks
  216. # create an empty tensor of the expected final shape and fill in the current tokens
  217. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  218. empty[:, :T] = prompt
  219. seq = empty
  220. input_pos = torch.arange(0, T, device=device)
  221. # Use non-accelerated version for now, to avoid compilation overhead
  222. prefill_decode = (
  223. decode_one_token_naive
  224. if isinstance(model, NaiveTransformer)
  225. else decode_one_token_ar
  226. )
  227. next_token = prefill_decode(
  228. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  229. )
  230. seq[:, T : T + 1] = next_token
  231. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  232. x = decode_n_tokens(
  233. model,
  234. next_token.view(1, codebook_dim, -1),
  235. input_pos,
  236. max_new_tokens - 1,
  237. im_end_id=im_end_id,
  238. decode_one_token=decode_one_token,
  239. **sampling_kwargs,
  240. )
  241. # x = torch.cat(generated_tokens, dim=1)
  242. seq = seq[:, : T + 1 + x.size(1)]
  243. seq[:, T + 1 :] = x
  244. return seq
  245. def encode_tokens(
  246. tokenizer,
  247. string,
  248. device="cuda",
  249. prompt_tokens=None,
  250. num_codebooks=4,
  251. ):
  252. string = clean_text(string)
  253. string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
  254. new_tokens = tokenizer.encode(
  255. string,
  256. add_special_tokens=False,
  257. max_length=10**6,
  258. truncation=False,
  259. )
  260. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  261. # Codebooks
  262. zeros = (
  263. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  264. * CODEBOOK_PAD_TOKEN_ID
  265. )
  266. prompt = torch.cat((tokens, zeros), dim=0)
  267. if prompt_tokens is None:
  268. return prompt
  269. # Get prompt tokens
  270. if prompt_tokens.ndim == 3:
  271. assert (
  272. prompt_tokens.shape[0] == 1
  273. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  274. prompt_tokens = prompt_tokens[0]
  275. assert prompt_tokens.ndim == 2
  276. data = prompt_tokens + 1
  277. if prompt_tokens.shape[0] > num_codebooks:
  278. logger.warning(
  279. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  280. )
  281. data = data[:num_codebooks]
  282. # Add pad token for each codebook
  283. data = torch.cat(
  284. (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
  285. dim=1,
  286. )
  287. # Since 1.0, we use <|semantic|>
  288. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  289. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  290. main_token_ids = (
  291. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  292. )
  293. main_token_ids[0, -1] = end_token_id
  294. data = torch.cat((main_token_ids, data), dim=0)
  295. prompt = torch.cat((prompt, data), dim=1)
  296. return prompt
  297. def load_model(checkpoint_path, device, precision, compile=False):
  298. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  299. checkpoint_path, load_weights=True
  300. )
  301. model = model.to(device=device, dtype=precision)
  302. logger.info(f"Restored model from checkpoint")
  303. if isinstance(model, DualARTransformer):
  304. decode_one_token = decode_one_token_ar
  305. logger.info("Using DualARTransformer")
  306. else:
  307. decode_one_token = decode_one_token_naive
  308. logger.info("Using NaiveTransformer")
  309. if compile:
  310. logger.info("Compiling function...")
  311. decode_one_token = torch.compile(
  312. decode_one_token,
  313. fullgraph=True,
  314. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  315. mode="reduce-overhead" if torch.cuda.is_available() else None,
  316. )
  317. return model.eval(), decode_one_token
  318. @dataclass
  319. class GenerateResponse:
  320. action: Literal["sample", "next"]
  321. codes: Optional[torch.Tensor] = None
  322. text: Optional[str] = None
  323. def generate_long(
  324. *,
  325. model,
  326. device: str | torch.device,
  327. decode_one_token: callable,
  328. text: str,
  329. num_samples: int = 1,
  330. max_new_tokens: int = 0,
  331. top_p: int = 0.7,
  332. repetition_penalty: float = 1.5,
  333. temperature: float = 0.7,
  334. compile: bool = False,
  335. iterative_prompt: bool = True,
  336. max_length: int = 2048,
  337. chunk_length: int = 150,
  338. prompt_text: Optional[str | list[str]] = None,
  339. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  340. ):
  341. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  342. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  343. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  344. use_prompt = prompt_text is not None and prompt_tokens is not None
  345. if use_prompt and isinstance(prompt_text, str):
  346. prompt_text = [prompt_text]
  347. prompt_tokens = [prompt_tokens]
  348. assert use_prompt is False or len(prompt_text) == len(
  349. prompt_tokens
  350. ), "Prompt text and tokens must have the same length"
  351. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  352. tokenizer = model.tokenizer
  353. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  354. encoded = []
  355. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  356. encoded_prompts = []
  357. if use_prompt:
  358. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  359. encoded_prompts.append(
  360. encode_tokens(
  361. tokenizer,
  362. string=t,
  363. device=device,
  364. prompt_tokens=c,
  365. num_codebooks=model.config.num_codebooks,
  366. )
  367. )
  368. for idx, text in enumerate(texts):
  369. encoded.append(
  370. encode_tokens(
  371. tokenizer,
  372. string=text,
  373. device=device,
  374. num_codebooks=model.config.num_codebooks,
  375. )
  376. )
  377. logger.info(f"Encoded text: {text}")
  378. # Move temperature, top_p, repetition_penalty to device
  379. # This is important so that changing params doesn't trigger recompile
  380. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  381. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  382. repetition_penalty = torch.tensor(
  383. repetition_penalty, device=device, dtype=torch.float
  384. )
  385. for sample_idx in range(num_samples):
  386. if torch.cuda.is_available():
  387. torch.cuda.synchronize()
  388. global_encoded = []
  389. seg_idx = 0
  390. while seg_idx < len(encoded):
  391. logger.info(
  392. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  393. )
  394. seg = encoded[seg_idx]
  395. global_encoded.append(seg)
  396. lengths = reversed([seg.size(1) for seg in global_encoded])
  397. # Pick last 2000 tokens
  398. count = 0
  399. for i, length in enumerate(lengths):
  400. count += length
  401. if count + length > max_length - 1024 - sum(
  402. t.shape[1] for t in encoded_prompts
  403. ):
  404. break
  405. if i != 0 and i % 2 == 0:
  406. i -= 1
  407. # Rotate the list, always make sure first segment is included to avoid drift
  408. if i < len(global_encoded) - 2:
  409. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  410. else:
  411. partial_encoded = global_encoded
  412. if use_prompt:
  413. partial_encoded = encoded_prompts + partial_encoded
  414. cat_encoded = torch.cat(partial_encoded, dim=1)
  415. prompt_length = cat_encoded.size(1)
  416. t0 = time.perf_counter()
  417. y = generate(
  418. model=model,
  419. prompt=cat_encoded,
  420. max_new_tokens=max_new_tokens,
  421. im_end_id=im_end_id,
  422. decode_one_token=decode_one_token,
  423. temperature=temperature,
  424. top_p=top_p,
  425. repetition_penalty=repetition_penalty,
  426. )
  427. if sample_idx == 0 and seg_idx == 0 and compile:
  428. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  429. if torch.cuda.is_available():
  430. torch.cuda.synchronize()
  431. t = time.perf_counter() - t0
  432. tokens_generated = y.size(1) - prompt_length
  433. tokens_sec = tokens_generated / t
  434. logger.info(
  435. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  436. )
  437. logger.info(
  438. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  439. )
  440. if torch.cuda.is_available():
  441. logger.info(
  442. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  443. )
  444. # Put the generated tokens
  445. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  446. codes = y[1:, prompt_length:-1].clone()
  447. codes = codes - 1
  448. assert (codes >= 0).all(), f"Negative code found"
  449. decoded = y[:, prompt_length:-1].clone()
  450. # But for global encoding, we should keep the <im_end> token
  451. global_encoded.append(decoded)
  452. assert (codes >= 0).all(), f"Negative code found: {codes}"
  453. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  454. seg_idx += 1
  455. # This indicates the end of the current sample
  456. yield GenerateResponse(action="next")
  457. @dataclass
  458. class WrappedGenerateResponse:
  459. status: Literal["success", "error"]
  460. response: Optional[GenerateResponse | Exception] = None
  461. @dataclass
  462. class GenerateRequest:
  463. request: dict
  464. response_queue: queue.Queue
  465. def launch_thread_safe_queue(
  466. checkpoint_path,
  467. device,
  468. precision,
  469. compile: bool = False,
  470. ):
  471. input_queue = queue.Queue()
  472. init_event = threading.Event()
  473. def worker():
  474. model, decode_one_token = load_model(
  475. checkpoint_path, device, precision, compile=compile
  476. )
  477. init_event.set()
  478. while True:
  479. item: GenerateRequest | None = input_queue.get()
  480. if item is None:
  481. break
  482. kwargs = item.request
  483. response_queue = item.response_queue
  484. try:
  485. for chunk in generate_long(
  486. model=model, decode_one_token=decode_one_token, **kwargs
  487. ):
  488. response_queue.put(
  489. WrappedGenerateResponse(status="success", response=chunk)
  490. )
  491. except Exception as e:
  492. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  493. threading.Thread(target=worker, daemon=True).start()
  494. init_event.wait()
  495. return input_queue
  496. @click.command()
  497. @click.option(
  498. "--text",
  499. type=str,
  500. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  501. )
  502. @click.option("--prompt-text", type=str, default=None, multiple=True)
  503. @click.option(
  504. "--prompt-tokens",
  505. type=click.Path(path_type=Path, exists=True),
  506. default=None,
  507. multiple=True,
  508. )
  509. @click.option("--num-samples", type=int, default=1)
  510. @click.option("--max-new-tokens", type=int, default=0)
  511. @click.option("--top-p", type=float, default=0.7)
  512. @click.option("--repetition-penalty", type=float, default=1.2)
  513. @click.option("--temperature", type=float, default=0.7)
  514. @click.option(
  515. "--checkpoint-path",
  516. type=click.Path(path_type=Path, exists=True),
  517. default="checkpoints/fish-speech-1.4",
  518. )
  519. @click.option("--device", type=str, default="cuda")
  520. @click.option("--compile/--no-compile", default=False)
  521. @click.option("--seed", type=int, default=42)
  522. @click.option("--half/--no-half", default=False)
  523. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  524. @click.option("--chunk-length", type=int, default=100)
  525. def main(
  526. text: str,
  527. prompt_text: Optional[list[str]],
  528. prompt_tokens: Optional[list[Path]],
  529. num_samples: int,
  530. max_new_tokens: int,
  531. top_p: int,
  532. repetition_penalty: float,
  533. temperature: float,
  534. checkpoint_path: Path,
  535. device: str,
  536. compile: bool,
  537. seed: int,
  538. half: bool,
  539. iterative_prompt: bool,
  540. chunk_length: int,
  541. ) -> None:
  542. precision = torch.half if half else torch.bfloat16
  543. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  544. raise ValueError(
  545. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  546. )
  547. logger.info("Loading model ...")
  548. t0 = time.time()
  549. model, decode_one_token = load_model(
  550. checkpoint_path, device, precision, compile=compile
  551. )
  552. if torch.cuda.is_available():
  553. torch.cuda.synchronize()
  554. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  555. if prompt_tokens is not None:
  556. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  557. torch.manual_seed(seed)
  558. if torch.cuda.is_available():
  559. torch.cuda.manual_seed(seed)
  560. generator = generate_long(
  561. model=model,
  562. device=device,
  563. decode_one_token=decode_one_token,
  564. text=text,
  565. num_samples=num_samples,
  566. max_new_tokens=max_new_tokens,
  567. top_p=top_p,
  568. repetition_penalty=repetition_penalty,
  569. temperature=temperature,
  570. compile=compile,
  571. iterative_prompt=iterative_prompt,
  572. chunk_length=chunk_length,
  573. prompt_text=prompt_text,
  574. prompt_tokens=prompt_tokens,
  575. )
  576. idx = 0
  577. codes = []
  578. for response in generator:
  579. if response.action == "sample":
  580. codes.append(response.codes)
  581. logger.info(f"Sampled text: {response.text}")
  582. elif response.action == "next":
  583. if codes:
  584. np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
  585. logger.info(f"Saved codes to codes_{idx}.npy")
  586. logger.info(f"Next sample")
  587. codes = []
  588. idx += 1
  589. else:
  590. logger.error(f"Error: {response}")
# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()