generate.py

import os
import queue
import threading
import time
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

import click
import hydra
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from loguru import logger
from tqdm import tqdm

from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.text import clean_text, split_text

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True

from fish_speech.models.text2semantic.llama import (
    BaseTransformer,
    DualARTransformer,
    NaiveTransformer,
)

def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
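# Note on the function above: for q_i ~ Exp(1), argmax_i(probs_i / q_i) selects index i
# with probability proportional to probs_i, so it is equivalent to multinomial sampling
# but avoids the host/device synchronization that torch.multinomial would trigger.
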
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    # Apply repetition penalty
    if previous_tokens is not None:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=0, index=previous_tokens, src=score)

    # Apply top-p sampling
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[0] = False  # keep at least one option
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=0, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
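# Illustrative example of the scoring above (values are assumptions, not defaults):
# with repetition_penalty=1.5, a previously generated token with logit 2.0 is rescaled
# to 2.0 / 1.5 ≈ 1.33, while one with logit -2.0 is pushed down to -3.0; top-p then
# masks the tail of the sorted distribution before temperature scaling and the softmax.
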
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)

    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    x = x.hidden_states

    # Cleanup the cache
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
        logits = model.forward_generate_fast(x, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        x = model.fast_embeddings(a)
        codebooks.append(a)

    return torch.stack(codebooks, dim=0)
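# Dual-AR decoding above: the main transformer predicts the text/semantic token (sampled
# near-greedily with the overridden kwargs), then the fast layers decode the
# model.config.num_codebooks codebook entries one by one, feeding each sampled code back
# through model.fast_embeddings; the fast-layer KV caches are cleared on every step.
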
def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)

    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)
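# Naive decoding above: a single forward pass yields logits for all codebooks
# (x.codebook_logits), so the codebooks are sampled independently rather than
# autoregressively as in decode_one_token_ar.
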
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # Apply the repetition penalty over a sliding window of recent tokens
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with (
            torch.backends.cuda.sdp_kernel(
                enable_flash=False, enable_mem_efficient=False, enable_math=True
            )
            if torch.cuda.is_available()
            else nullcontext()
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        if cur_token[0, 0, -1] == im_end_id:
            break

    return previous_tokens[:, : i + 1]
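# decode_n_tokens drives the per-token decoder for up to num_new_tokens steps, applies
# the repetition penalty only over the last `win_size` generated tokens, and stops early
# once the main codebook emits `im_end_id`; the result is a (num_codebooks + 1, n) tensor.
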
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
    )
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )

    next_token = prefill_decode(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
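# Minimal usage sketch for generate() (assumes caches were set up via model.setup_caches
# and `prompt` is an int tensor of shape (num_codebooks + 1, T) on the model's device):
#
#     seq = generate(
#         model=model,
#         prompt=prompt,
#         max_new_tokens=0,  # 0 means "fill up to model.config.max_seq_len"
#         im_end_id=tokenizer.convert_tokens_to_ids("<|im_end|>"),
#         decode_one_token=decode_one_token_ar,
#         temperature=torch.tensor(0.7, device=prompt.device, dtype=torch.float),
#         top_p=torch.tensor(0.7, device=prompt.device, dtype=torch.float),
#         repetition_penalty=torch.tensor(1.2, device=prompt.device, dtype=torch.float),
#     )
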
def encode_tokens(
    tokenizer,
    string,
    device="cuda",
    prompt_tokens=None,
    num_codebooks=4,
):
    string = clean_text(string)
    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"

    new_tokens = tokenizer.encode(
        string,
        add_special_tokens=False,
        max_length=10**6,
        truncation=False,
    )
    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

    # Codebooks
    zeros = (
        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
        * CODEBOOK_PAD_TOKEN_ID
    )
    prompt = torch.cat((tokens, zeros), dim=0)

    if prompt_tokens is None:
        return prompt

    # Get prompt tokens
    if prompt_tokens.ndim == 3:
        assert (
            prompt_tokens.shape[0] == 1
        ), "3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
        prompt_tokens = prompt_tokens[0]

    assert prompt_tokens.ndim == 2
    data = prompt_tokens + 1

    if prompt_tokens.shape[0] > num_codebooks:
        logger.warning(
            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
        )
        data = data[:num_codebooks]

    # Add pad token for each codebook
    data = torch.cat(
        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
        dim=1,
    )

    # Since 1.0, we use <|semantic|>
    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    main_token_ids = (
        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
    )
    main_token_ids[0, -1] = end_token_id

    data = torch.cat((main_token_ids, data), dim=0)
    prompt = torch.cat((prompt, data), dim=1)

    return prompt
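# Prompt layout produced above: row 0 carries the chat-templated text token ids, rows
# 1..num_codebooks carry CODEBOOK_PAD_TOKEN_ID for text positions; when a reference
# prompt is given, its codes are shifted by +1 and aligned with <|semantic|> tokens on
# row 0, with the final column set to <|im_end|>.
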
def load_model(checkpoint_path, device, precision, compile=False):
    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
        checkpoint_path, load_weights=True
    )

    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
        decode_one_token = decode_one_token_ar
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = decode_one_token_naive
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token,
            fullgraph=True,
            backend="inductor" if torch.cuda.is_available() else "aot_eager",
            mode="reduce-overhead" if torch.cuda.is_available() else None,
        )

    return model.eval(), decode_one_token
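# Note: torch.compile wraps only the per-token decode function, so the first generated
# segment pays the compilation cost (generate_long reports it as "Compilation time");
# without CUDA the backend falls back to aot_eager and no reduce-overhead mode.
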
@dataclass
class GenerateResponse:
    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None
    text: Optional[str] = None

def generate_long(
    *,
    model,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 150,
    prompt_text: Optional[str | list[str]] = None,
    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
):
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    use_prompt = prompt_text is not None and prompt_tokens is not None
    if use_prompt and isinstance(prompt_text, str):
        prompt_text = [prompt_text]
        prompt_tokens = [prompt_tokens]

    assert use_prompt is False or len(prompt_text) == len(
        prompt_tokens
    ), "Prompt text and tokens must have the same length"

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tokenizer = model.tokenizer
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]
    encoded_prompts = []

    if use_prompt:
        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
            encoded_prompts.append(
                encode_tokens(
                    tokenizer,
                    string=t,
                    device=device,
                    prompt_tokens=c,
                    num_codebooks=model.config.num_codebooks,
                )
            )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                device=device,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    # Move temperature, top_p, repetition_penalty to device
    # This is important so that changing params doesn't trigger recompile
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        global_encoded = []
        seg_idx = 0

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Keep only the most recent segments that still fit in the context budget
            count = 0
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024 - sum(
                    t.shape[1] for t in encoded_prompts
                ):
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list, always make sure first segment is included to avoid drift
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            if use_prompt:
                partial_encoded = encoded_prompts + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                im_end_id=im_end_id,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Collect the generated tokens;
            # since there are <im_end> and <eos> tokens, we remove the last 2 tokens
            codes = y[1:, prompt_length:-1].clone()
            codes = codes - 1
            assert (codes >= 0).all(), "Negative code found"

            decoded = y[:, prompt_length:-1].clone()
            # But for global encoding, we should keep the <im_end> token
            global_encoded.append(decoded)
            assert (codes >= 0).all(), f"Negative code found: {codes}"

            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
            seg_idx += 1

        # This indicates the end of the current sample
        yield GenerateResponse(action="next")
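# Minimal usage sketch for generate_long() (checkpoint path and text are assumptions):
#
#     model, decode_one_token = load_model(
#         "checkpoints/fish-speech-1.4", "cuda", torch.bfloat16, compile=False
#     )
#     with torch.device("cuda"):
#         model.setup_caches(
#             max_batch_size=1,
#             max_seq_len=model.config.max_seq_len,
#             dtype=next(model.parameters()).dtype,
#         )
#     for response in generate_long(
#         model=model, device="cuda", decode_one_token=decode_one_token, text="Hello world"
#     ):
#         if response.action == "sample":
#             ...  # response.codes is a (num_codebooks, T) tensor of semantic codes
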
@dataclass
class WrappedGenerateResponse:
    status: Literal["success", "error"]
    response: Optional[GenerateResponse | Exception] = None


@dataclass
class GenerateRequest:
    request: dict
    response_queue: queue.Queue

def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    input_queue = queue.Queue()
    init_event = threading.Event()

    def worker():
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
            except Exception as e:
                response_queue.put(WrappedGenerateResponse(status="error", response=e))

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue
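# Minimal usage sketch for the worker queue (request values are assumptions): each
# GenerateRequest carries the generate_long() kwargs plus a per-request response queue.
#
#     input_queue = launch_thread_safe_queue(
#         "checkpoints/fish-speech-1.4", "cuda", torch.bfloat16
#     )
#     response_queue = queue.Queue()
#     input_queue.put(
#         GenerateRequest(
#             request={"device": "cuda", "text": "Hello world"},
#             response_queue=response_queue,
#         )
#     )
#     wrapped = response_queue.get()
#     if wrapped.status == "success":
#         ...  # wrapped.response is a GenerateResponse
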
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.2)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.4",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=100)
def main(
    text: str,
    prompt_text: Optional[list[str]],
    prompt_tokens: Optional[list[Path]],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
) -> None:
    precision = torch.half if half else torch.bfloat16

    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    if prompt_tokens is not None:
        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]

    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    idx = 0
    codes = []

    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
                logger.info(f"Saved codes to codes_{idx}.npy")
            logger.info("Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")


if __name__ == "__main__":
    main()
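# Example invocation (paths and values are illustrative):
#     python generate.py --text "Hello world" \
#         --checkpoint-path checkpoints/fish-speech-1.4 \
#         --num-samples 2 --compile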