# generate.py
import os
import queue
import string
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

import click
import hydra
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer

from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.text import clean_text, split_text
  22. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  23. torch._inductor.config.coordinate_descent_tuning = True
  24. torch._inductor.config.triton.unique_kernel_names = True
  25. if hasattr(torch._inductor.config, "fx_graph_cache"):
  26. # Experimental feature to reduce compilation times, will be on by default in future
  27. torch._inductor.config.fx_graph_cache = True
  28. from fish_speech.models.text2semantic.llama import (
  29. BaseTransformer,
  30. DualARTransformer,
  31. NaiveTransformer,
  32. )
  33. def multinomial_sample_one_no_sync(
  34. probs_sort,
  35. ): # Does multinomial sampling without a cuda synchronization
  36. q = torch.empty_like(probs_sort).exponential_(1)
  37. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  38. def logits_to_probs(
  39. logits,
  40. previous_tokens: Optional[torch.Tensor] = None,
  41. temperature: torch.Tensor = 1.0,
  42. top_p: torch.Tensor = 1.0,
  43. repetition_penalty: torch.Tensor = 1.0,
  44. ) -> torch.Tensor:
  45. # Apply repetition penalty
  46. if previous_tokens is not None:
  47. previous_tokens = previous_tokens.long()
  48. score = torch.gather(logits, dim=0, index=previous_tokens)
  49. score = torch.where(
  50. score < 0, score * repetition_penalty, score / repetition_penalty
  51. )
  52. logits.scatter_(dim=0, index=previous_tokens, src=score)
  53. # Apply top-p sampling
  54. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  55. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  56. sorted_indices_to_remove = cum_probs > top_p
  57. sorted_indices_to_remove[0] = False # keep at least one option
  58. indices_to_remove = sorted_indices_to_remove.scatter(
  59. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  60. )
  61. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  62. logits = logits / max(temperature, 1e-5)
  63. probs = torch.nn.functional.softmax(logits, dim=-1)
  64. return probs
  65. def sample(
  66. logits,
  67. previous_tokens: Optional[torch.Tensor] = None,
  68. **sampling_kwargs,
  69. ) -> Tuple[torch.Tensor, torch.Tensor]:
  70. probs = logits_to_probs(
  71. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  72. )
  73. idx_next = multinomial_sample_one_no_sync(probs)
  74. return idx_next, probs
  75. def decode_one_token_ar(
  76. model: DualARTransformer,
  77. x: torch.Tensor,
  78. input_pos: torch.Tensor,
  79. previous_tokens: torch.Tensor = None,
  80. **sampling_kwargs,
  81. ) -> torch.Tensor:
  82. x = model.forward_generate(x, input_pos)
  83. codebooks = [
  84. sample(
  85. x.logits,
  86. previous_tokens=None, # Disable repetition penalty for the token codebook
  87. **sampling_kwargs,
  88. )[0]
  89. ]
  90. x = x.hidden_states
  91. # Cleanup the cache
  92. for layer in model.fast_layers:
  93. layer.attention.kv_cache.k_cache.fill_(0)
  94. layer.attention.kv_cache.v_cache.fill_(0)
  95. for codebook_idx in range(model.config.num_codebooks):
  96. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  97. logits = model.forward_generate_fast(x, input_pos)
  98. a = sample(
  99. logits,
  100. previous_tokens=(
  101. previous_tokens[codebook_idx + 1]
  102. if previous_tokens is not None
  103. else None
  104. ),
  105. **sampling_kwargs,
  106. )[0]
  107. x = model.fast_embeddings(a)
  108. codebooks.append(a)
  109. return torch.stack(codebooks, dim=0)
  110. def decode_one_token_naive(
  111. model: NaiveTransformer,
  112. x: torch.Tensor,
  113. input_pos: torch.Tensor,
  114. previous_tokens: torch.Tensor = None,
  115. **sampling_kwargs,
  116. ) -> torch.Tensor:
  117. x = model.forward_generate(x, input_pos)
  118. codebooks = [
  119. sample(
  120. x.token_logits,
  121. previous_tokens=None, # Disable repetition penalty for the token codebook
  122. **sampling_kwargs,
  123. )[0]
  124. ]
  125. for i in range(model.config.num_codebooks):
  126. codebooks.append(
  127. sample(
  128. x.codebook_logits[:, :, i],
  129. previous_tokens=(
  130. previous_tokens[i + 1] if previous_tokens is not None else None
  131. ),
  132. **sampling_kwargs,
  133. )[0]
  134. )
  135. return torch.stack(codebooks, dim=0)
  136. def decode_n_tokens(
  137. model: NaiveTransformer,
  138. cur_token: torch.Tensor,
  139. input_pos: torch.Tensor,
  140. num_new_tokens: int,
  141. im_end_id: int = 4,
  142. decode_one_token=decode_one_token_naive,
  143. **sampling_kwargs,
  144. ):
  145. previous_tokens = torch.zeros(
  146. (model.config.num_codebooks + 1, model.config.max_seq_len),
  147. dtype=torch.int,
  148. device=cur_token.device,
  149. )
  150. for i in tqdm(range(num_new_tokens)):
  151. # We need to get windowed repeat penalty
  152. win_size = 16
  153. if i < win_size:
  154. window = previous_tokens[:, :win_size]
  155. else:
  156. window = previous_tokens[:, i - win_size : i]
  157. with torch.backends.cuda.sdp_kernel(
  158. enable_flash=False, enable_mem_efficient=False, enable_math=True
  159. ): # Actually better for Inductor to codegen attention here
  160. next_token = decode_one_token(
  161. model=model,
  162. x=cur_token,
  163. input_pos=input_pos,
  164. previous_tokens=window,
  165. **sampling_kwargs,
  166. )
  167. input_pos += 1
  168. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  169. previous_tokens[:, i : i + 1] = next_token.view(
  170. model.config.num_codebooks + 1, -1
  171. )
  172. if cur_token[0, 0, -1] == im_end_id:
  173. break
  174. return previous_tokens[:, : i + 1]
  175. @torch.no_grad()
  176. @torch.inference_mode()
  177. def generate(
  178. *,
  179. model: NaiveTransformer,
  180. prompt: torch.Tensor,
  181. max_new_tokens: int,
  182. im_end_id: int = 4,
  183. decode_one_token=decode_one_token_naive,
  184. **sampling_kwargs,
  185. ) -> torch.Tensor:
  186. """
  187. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  188. """
  189. # create an empty tensor of the expected final shape and fill in the current tokens
  190. T = prompt.size(1)
  191. if max_new_tokens:
  192. if T + max_new_tokens > model.config.max_seq_len:
  193. max_new_tokens = model.config.max_seq_len - T
  194. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  195. T_new = T + max_new_tokens
  196. else:
  197. T_new = model.config.max_seq_len
  198. max_new_tokens = T_new - T
  199. device, dtype = prompt.device, prompt.dtype
  200. with torch.device(device):
  201. model.setup_caches(
  202. max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
  203. )
  204. codebook_dim = 1 + model.config.num_codebooks
  205. # create an empty tensor of the expected final shape and fill in the current tokens
  206. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  207. empty[:, :T] = prompt
  208. seq = empty
  209. input_pos = torch.arange(0, T, device=device)
  210. # Use non-accelerated version for now, to avoid compilation overhead
  211. prefill_decode = (
  212. decode_one_token_naive
  213. if isinstance(model, NaiveTransformer)
  214. else decode_one_token_ar
  215. )
  216. next_token = prefill_decode(
  217. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  218. )
  219. seq[:, T : T + 1] = next_token
  220. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  221. x = decode_n_tokens(
  222. model,
  223. next_token.view(1, codebook_dim, -1),
  224. input_pos,
  225. max_new_tokens - 1,
  226. im_end_id=im_end_id,
  227. decode_one_token=decode_one_token,
  228. **sampling_kwargs,
  229. )
  230. # x = torch.cat(generated_tokens, dim=1)
  231. seq = seq[:, : T + 1 + x.size(1)]
  232. seq[:, T + 1 :] = x
  233. return seq
  234. def encode_tokens(
  235. tokenizer,
  236. string,
  237. device="cuda",
  238. prompt_tokens=None,
  239. num_codebooks=4,
  240. ):
  241. string = clean_text(string)
  242. string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
  243. new_tokens = tokenizer.encode(
  244. string,
  245. add_special_tokens=False,
  246. max_length=10**6,
  247. truncation=False,
  248. )
  249. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  250. # Codebooks
  251. zeros = (
  252. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  253. * CODEBOOK_PAD_TOKEN_ID
  254. )
  255. prompt = torch.cat((tokens, zeros), dim=0)
  256. if prompt_tokens is None:
  257. return prompt
  258. # Get prompt tokens
  259. if prompt_tokens.ndim == 3:
  260. assert (
  261. prompt_tokens.shape[0] == 1
  262. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  263. prompt_tokens = prompt_tokens[0]
  264. assert prompt_tokens.ndim == 2
  265. data = prompt_tokens + 1
  266. if prompt_tokens.shape[0] > num_codebooks:
  267. logger.warning(
  268. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  269. )
  270. data = data[:num_codebooks]
  271. # Add pad token for each codebook
  272. data = torch.cat(
  273. (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
  274. dim=1,
  275. )
  276. # Since 1.0, we use <|semantic|>
  277. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  278. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  279. main_token_ids = (
  280. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  281. )
  282. main_token_ids[0, -1] = end_token_id
  283. data = torch.cat((main_token_ids, data), dim=0)
  284. prompt = torch.cat((prompt, data), dim=1)
  285. return prompt
  286. def load_model(checkpoint_path, device, precision, compile=False):
  287. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  288. checkpoint_path, load_weights=True
  289. )
  290. if "int8" in str(checkpoint_path):
  291. logger.info("Using int8 weight-only quantization!")
  292. from .quantize import WeightOnlyInt8QuantHandler
  293. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  294. model = simple_quantizer.convert_for_runtime()
  295. if "int4" in str(checkpoint_path):
  296. logger.info("Using int4 quantization!")
  297. path_comps = checkpoint_path.name.split(".")
  298. assert path_comps[-2].startswith("g")
  299. groupsize = int(path_comps[-2][1:])
  300. from .quantize import WeightOnlyInt4QuantHandler
  301. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  302. model = simple_quantizer.convert_for_runtime()
  303. model = model.to(device=device, dtype=precision)
  304. logger.info(f"Restored model from checkpoint")
  305. if isinstance(model, DualARTransformer):
  306. decode_one_token = decode_one_token_ar
  307. logger.info("Using DualARTransformer")
  308. else:
  309. decode_one_token = decode_one_token_naive
  310. logger.info("Using NaiveTransformer")
  311. if compile:
  312. logger.info("Compiling function...")
  313. decode_one_token = torch.compile(
  314. decode_one_token, mode="reduce-overhead", fullgraph=True
  315. )
  316. return model.eval(), decode_one_token
  317. @dataclass
  318. class GenerateResponse:
  319. action: Literal["sample", "next"]
  320. codes: Optional[torch.Tensor] = None
  321. text: Optional[str] = None
  322. def generate_long(
  323. *,
  324. model,
  325. device: str | torch.device,
  326. decode_one_token: callable,
  327. text: str,
  328. num_samples: int = 1,
  329. max_new_tokens: int = 0,
  330. top_p: int = 0.7,
  331. repetition_penalty: float = 1.5,
  332. temperature: float = 0.7,
  333. compile: bool = False,
  334. iterative_prompt: bool = True,
  335. max_length: int = 2048,
  336. chunk_length: int = 150,
  337. prompt_text: Optional[str | list[str]] = None,
  338. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  339. ):
  340. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  341. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  342. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  343. use_prompt = prompt_text is not None and prompt_tokens is not None
  344. if use_prompt and isinstance(prompt_text, str):
  345. prompt_text = [prompt_text]
  346. prompt_tokens = [prompt_tokens]
  347. assert use_prompt is False or len(prompt_text) == len(
  348. prompt_tokens
  349. ), "Prompt text and tokens must have the same length"
  350. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  351. tokenizer = model.tokenizer
  352. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  353. encoded = []
  354. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  355. encoded_prompts = []
  356. if use_prompt:
  357. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  358. encoded_prompts.append(
  359. encode_tokens(
  360. tokenizer,
  361. string=t,
  362. device=device,
  363. prompt_tokens=c,
  364. num_codebooks=model.config.num_codebooks,
  365. )
  366. )
  367. for idx, text in enumerate(texts):
  368. encoded.append(
  369. encode_tokens(
  370. tokenizer,
  371. string=text,
  372. device=device,
  373. num_codebooks=model.config.num_codebooks,
  374. )
  375. )
  376. logger.info(f"Encoded text: {text}")
  377. # Move temperature, top_p, repetition_penalty to device
  378. # This is important so that changing params doesn't trigger recompile
  379. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  380. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  381. repetition_penalty = torch.tensor(
  382. repetition_penalty, device=device, dtype=torch.float
  383. )
  384. for sample_idx in range(num_samples):
  385. if torch.cuda.is_available():
  386. torch.cuda.synchronize()
  387. global_encoded = []
  388. seg_idx = 0
  389. while seg_idx < len(encoded):
  390. logger.info(
  391. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  392. )
  393. seg = encoded[seg_idx]
  394. global_encoded.append(seg)
  395. lengths = reversed([seg.size(1) for seg in global_encoded])
  396. # Pick last 2000 tokens
  397. count = 0
  398. for i, length in enumerate(lengths):
  399. count += length
  400. if count + length > max_length - 1024 - sum(
  401. t.shape[1] for t in encoded_prompts
  402. ):
  403. break
  404. if i != 0 and i % 2 == 0:
  405. i -= 1
  406. # Rotate the list, always make sure first segment is included to avoid drift
  407. if i < len(global_encoded) - 2:
  408. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  409. else:
  410. partial_encoded = global_encoded
  411. if use_prompt:
  412. partial_encoded = encoded_prompts + partial_encoded
  413. cat_encoded = torch.cat(partial_encoded, dim=1)
  414. prompt_length = cat_encoded.size(1)
  415. t0 = time.perf_counter()
  416. y = generate(
  417. model=model,
  418. prompt=cat_encoded,
  419. max_new_tokens=max_new_tokens,
  420. im_end_id=im_end_id,
  421. decode_one_token=decode_one_token,
  422. temperature=temperature,
  423. top_p=top_p,
  424. repetition_penalty=repetition_penalty,
  425. )
  426. if sample_idx == 0 and seg_idx == 0 and compile:
  427. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  428. if torch.cuda.is_available():
  429. torch.cuda.synchronize()
  430. t = time.perf_counter() - t0
  431. tokens_generated = y.size(1) - prompt_length
  432. tokens_sec = tokens_generated / t
  433. logger.info(
  434. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  435. )
  436. logger.info(
  437. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  438. )
  439. if torch.cuda.is_available():
  440. logger.info(
  441. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  442. )
  443. # Put the generated tokens
  444. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  445. codes = y[1:, prompt_length:-1].clone()
  446. codes = codes - 1
  447. assert (codes >= 0).all(), f"Negative code found"
  448. decoded = y[:, prompt_length:-1].clone()
  449. # But for global encoding, we should keep the <im_end> token
  450. global_encoded.append(decoded)
  451. assert (codes >= 0).all(), f"Negative code found: {codes}"
  452. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  453. seg_idx += 1
  454. # This indicates the end of the current sample
  455. yield GenerateResponse(action="next")
  456. @dataclass
  457. class WrappedGenerateResponse:
  458. status: Literal["success", "error"]
  459. response: Optional[GenerateResponse | Exception] = None
  460. @dataclass
  461. class GenerateRequest:
  462. request: dict
  463. response_queue: queue.Queue
  464. def launch_thread_safe_queue(
  465. checkpoint_path,
  466. device,
  467. precision,
  468. compile: bool = False,
  469. ):
  470. input_queue = queue.Queue()
  471. init_event = threading.Event()
  472. def worker():
  473. model, decode_one_token = load_model(
  474. checkpoint_path, device, precision, compile=compile
  475. )
  476. init_event.set()
  477. while True:
  478. item: GenerateRequest | None = input_queue.get()
  479. if item is None:
  480. break
  481. kwargs = item.request
  482. response_queue = item.response_queue
  483. try:
  484. for chunk in generate_long(
  485. model=model, decode_one_token=decode_one_token, **kwargs
  486. ):
  487. response_queue.put(
  488. WrappedGenerateResponse(status="success", response=chunk)
  489. )
  490. except Exception as e:
  491. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  492. threading.Thread(target=worker, daemon=True).start()
  493. init_event.wait()
  494. return input_queue
  495. @click.command()
  496. @click.option(
  497. "--text",
  498. type=str,
  499. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  500. )
  501. @click.option("--prompt-text", type=str, default=None, multiple=True)
  502. @click.option(
  503. "--prompt-tokens",
  504. type=click.Path(path_type=Path, exists=True),
  505. default=None,
  506. multiple=True,
  507. )
  508. @click.option("--num-samples", type=int, default=1)
  509. @click.option("--max-new-tokens", type=int, default=0)
  510. @click.option("--top-p", type=float, default=0.7)
  511. @click.option("--repetition-penalty", type=float, default=1.5)
  512. @click.option("--temperature", type=float, default=0.7)
  513. @click.option(
  514. "--checkpoint-path",
  515. type=click.Path(path_type=Path, exists=True),
  516. default="checkpoints/fish-speech-1.2",
  517. )
  518. @click.option("--compile/--no-compile", default=False)
  519. @click.option("--seed", type=int, default=42)
  520. @click.option("--half/--no-half", default=False)
  521. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  522. @click.option("--chunk-length", type=int, default=150)
  523. def main(
  524. text: str,
  525. prompt_text: Optional[list[str]],
  526. prompt_tokens: Optional[list[Path]],
  527. num_samples: int,
  528. max_new_tokens: int,
  529. top_p: int,
  530. repetition_penalty: float,
  531. temperature: float,
  532. checkpoint_path: Path,
  533. compile: bool,
  534. seed: int,
  535. half: bool,
  536. iterative_prompt: bool,
  537. chunk_length: int,
  538. ) -> None:
  539. device = "cuda"
  540. precision = torch.half if half else torch.bfloat16
  541. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  542. raise ValueError(
  543. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  544. )
  545. logger.info("Loading model ...")
  546. t0 = time.time()
  547. model, decode_one_token = load_model(
  548. checkpoint_path, device, precision, compile=compile
  549. )
  550. if torch.cuda.is_available():
  551. torch.cuda.synchronize()
  552. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  553. if prompt_tokens is not None:
  554. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  555. torch.manual_seed(seed)
  556. if torch.cuda.is_available():
  557. torch.cuda.manual_seed(seed)
  558. generator = generate_long(
  559. model=model,
  560. device=device,
  561. decode_one_token=decode_one_token,
  562. text=text,
  563. num_samples=num_samples,
  564. max_new_tokens=max_new_tokens,
  565. top_p=top_p,
  566. repetition_penalty=repetition_penalty,
  567. temperature=temperature,
  568. compile=compile,
  569. iterative_prompt=iterative_prompt,
  570. chunk_length=chunk_length,
  571. prompt_text=prompt_text,
  572. prompt_tokens=prompt_tokens,
  573. )
  574. idx = 0
  575. codes = []
  576. for response in generator:
  577. if response.action == "sample":
  578. codes.append(response.codes)
  579. logger.info(f"Sampled text: {response.text}")
  580. elif response.action == "next":
  581. if codes:
  582. np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
  583. logger.info(f"Saved codes to codes_{idx}.npy")
  584. logger.info(f"Next sample")
  585. codes = []
  586. idx += 1
  587. else:
  588. logger.error(f"Error: {response}")
  589. if __name__ == "__main__":
  590. main()