generate.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import hydra
  11. import numpy as np
  12. import torch
  13. import torch._dynamo.config
  14. import torch._inductor.config
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  18. from fish_speech.text import clean_text, split_text
  19. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  20. torch._inductor.config.coordinate_descent_tuning = True
  21. torch._inductor.config.triton.unique_kernel_names = True
  22. if hasattr(torch._inductor.config, "fx_graph_cache"):
  23. # Experimental feature to reduce compilation times, will be on by default in future
  24. torch._inductor.config.fx_graph_cache = True
  25. from fish_speech.models.text2semantic.llama import (
  26. BaseTransformer,
  27. DualARTransformer,
  28. NaiveTransformer,
  29. )
  30. def multinomial_sample_one_no_sync(
  31. probs_sort,
  32. ): # Does multinomial sampling without a cuda synchronization
  33. q = torch.empty_like(probs_sort).exponential_(1)
  34. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  35. def logits_to_probs(
  36. logits,
  37. previous_tokens: Optional[torch.Tensor] = None,
  38. temperature: torch.Tensor = 1.0,
  39. top_p: torch.Tensor = 1.0,
  40. repetition_penalty: torch.Tensor = 1.0,
  41. ) -> torch.Tensor:
  42. # Apply repetition penalty
  43. if previous_tokens is not None:
  44. previous_tokens = previous_tokens.long()
  45. score = torch.gather(logits, dim=0, index=previous_tokens)
  46. score = torch.where(
  47. score < 0, score * repetition_penalty, score / repetition_penalty
  48. )
  49. logits.scatter_(dim=0, index=previous_tokens, src=score)
  50. # Apply top-p sampling
  51. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  52. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  53. sorted_indices_to_remove = cum_probs > top_p
  54. sorted_indices_to_remove[0] = False # keep at least one option
  55. indices_to_remove = sorted_indices_to_remove.scatter(
  56. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  57. )
  58. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  59. logits = logits / max(temperature, 1e-5)
  60. probs = torch.nn.functional.softmax(logits, dim=-1)
  61. return probs
  62. def sample(
  63. logits,
  64. previous_tokens: Optional[torch.Tensor] = None,
  65. **sampling_kwargs,
  66. ) -> Tuple[torch.Tensor, torch.Tensor]:
  67. probs = logits_to_probs(
  68. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  69. )
  70. idx_next = multinomial_sample_one_no_sync(probs)
  71. return idx_next, probs
  72. def decode_one_token_ar(
  73. model: DualARTransformer,
  74. x: torch.Tensor,
  75. input_pos: torch.Tensor,
  76. previous_tokens: torch.Tensor = None,
  77. semantic_id: int = 0,
  78. **sampling_kwargs,
  79. ) -> torch.Tensor:
  80. x = model.forward_generate(x, input_pos)
  81. sampling_kwargs_main = sampling_kwargs.copy()
  82. # sampling_kwargs_main["temperature"] = 0.1
  83. # sampling_kwargs_main["top_p"] = 0.1
  84. # sampling_kwargs_main["repetition_penalty"] = 1.0
  85. codebooks = [
  86. sample(
  87. x.logits,
  88. previous_tokens=None, # Disable repetition penalty for the token codebook
  89. **sampling_kwargs_main,
  90. )[0]
  91. ]
  92. x = x.hidden_states
  93. # Cleanup the cache
  94. for layer in model.fast_layers:
  95. layer.attention.kv_cache.k_cache.fill_(0)
  96. layer.attention.kv_cache.v_cache.fill_(0)
  97. for codebook_idx in range(model.config.num_codebooks):
  98. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  99. logits = model.forward_generate_fast(x, input_pos)
  100. a = sample(
  101. logits,
  102. previous_tokens=(
  103. previous_tokens[codebook_idx + 1]
  104. if previous_tokens is not None
  105. else None
  106. ),
  107. **sampling_kwargs,
  108. )[0]
  109. x = model.fast_embeddings(a)
  110. codebooks.append(a)
  111. codebooks = torch.stack(codebooks, dim=0)
  112. codebooks[1:, :] = torch.masked_fill(
  113. codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
  114. )
  115. return codebooks
  116. def decode_one_token_naive(
  117. model: NaiveTransformer,
  118. x: torch.Tensor,
  119. input_pos: torch.Tensor,
  120. previous_tokens: torch.Tensor = None,
  121. **sampling_kwargs,
  122. ) -> torch.Tensor:
  123. x = model.forward_generate(x, input_pos)
  124. sampling_kwargs_main = sampling_kwargs.copy()
  125. sampling_kwargs_main["temperature"] = 0.1
  126. sampling_kwargs_main["top_p"] = 0.1
  127. sampling_kwargs_main["repetition_penalty"] = 1.0
  128. codebooks = [
  129. sample(
  130. x.logits,
  131. previous_tokens=None, # Disable repetition penalty for the token codebook
  132. **sampling_kwargs_main,
  133. )[0]
  134. ]
  135. for i in range(model.config.num_codebooks):
  136. codebooks.append(
  137. sample(
  138. x.codebook_logits[:, :, i],
  139. previous_tokens=(
  140. previous_tokens[i + 1] if previous_tokens is not None else None
  141. ),
  142. **sampling_kwargs,
  143. )[0]
  144. )
  145. return torch.stack(codebooks, dim=0)
  146. def decode_n_tokens(
  147. model: NaiveTransformer,
  148. cur_token: torch.Tensor,
  149. input_pos: torch.Tensor,
  150. num_new_tokens: int,
  151. im_end_id: int = 4,
  152. decode_one_token=decode_one_token_naive,
  153. semantic_id: int = 0,
  154. **sampling_kwargs,
  155. ):
  156. previous_tokens = torch.zeros(
  157. (model.config.num_codebooks + 1, model.config.max_seq_len),
  158. dtype=torch.int,
  159. device=cur_token.device,
  160. )
  161. for i in tqdm(range(num_new_tokens)):
  162. # We need to get windowed repeat penalty
  163. win_size = 16
  164. if i < win_size:
  165. window = previous_tokens[:, :win_size]
  166. else:
  167. window = previous_tokens[:, i - win_size : i]
  168. with (
  169. torch.backends.cuda.sdp_kernel(
  170. enable_flash=False, enable_mem_efficient=False, enable_math=True
  171. )
  172. if torch.cuda.is_available()
  173. else nullcontext()
  174. ): # Actually better for Inductor to codegen attention here
  175. next_token = decode_one_token(
  176. model=model,
  177. x=cur_token,
  178. input_pos=input_pos,
  179. previous_tokens=window,
  180. semantic_id=semantic_id,
  181. **sampling_kwargs,
  182. )
  183. input_pos += 1
  184. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  185. previous_tokens[:, i : i + 1] = next_token.view(
  186. model.config.num_codebooks + 1, -1
  187. )
  188. if cur_token[0, 0, -1] == im_end_id:
  189. break
  190. return previous_tokens[:, : i + 1]
  191. @torch.no_grad()
  192. @torch.inference_mode()
  193. def generate(
  194. *,
  195. model: NaiveTransformer,
  196. prompt: torch.Tensor,
  197. max_new_tokens: int,
  198. im_end_id: int = 4,
  199. decode_one_token=decode_one_token_naive,
  200. **sampling_kwargs,
  201. ) -> torch.Tensor:
  202. """
  203. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  204. """
  205. # create an empty tensor of the expected final shape and fill in the current tokens
  206. T = prompt.size(1)
  207. semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
  208. if max_new_tokens:
  209. if T + max_new_tokens > model.config.max_seq_len:
  210. max_new_tokens = model.config.max_seq_len - T
  211. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  212. T_new = T + max_new_tokens
  213. else:
  214. T_new = model.config.max_seq_len
  215. max_new_tokens = T_new - T
  216. device, dtype = prompt.device, prompt.dtype
  217. codebook_dim = 1 + model.config.num_codebooks
  218. # create an empty tensor of the expected final shape and fill in the current tokens
  219. empty = torch.empty(
  220. (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
  221. )
  222. empty[:, :T] = prompt
  223. seq = empty
  224. input_pos = torch.arange(0, T, device=device)
  225. # Use non-accelerated version for now, to avoid compilation overhead
  226. prefill_decode = (
  227. decode_one_token_naive
  228. if isinstance(model, NaiveTransformer)
  229. else decode_one_token_ar
  230. )
  231. next_token = prefill_decode(
  232. model,
  233. prompt.view(1, codebook_dim, -1),
  234. input_pos,
  235. semantic_id=semantic_id,
  236. **sampling_kwargs,
  237. )
  238. seq[:, T : T + 1] = next_token
  239. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  240. x = decode_n_tokens(
  241. model,
  242. next_token.view(1, codebook_dim, -1),
  243. input_pos,
  244. max_new_tokens - 1,
  245. im_end_id=im_end_id,
  246. decode_one_token=decode_one_token,
  247. semantic_id=semantic_id,
  248. **sampling_kwargs,
  249. )
  250. # x = torch.cat(generated_tokens, dim=1)
  251. seq = seq[:, : T + 1 + x.size(1)]
  252. seq[:, T + 1 :] = x
  253. return seq
  254. def encode_tokens(
  255. tokenizer,
  256. string,
  257. device="cuda",
  258. prompt_tokens=None,
  259. num_codebooks=4,
  260. ):
  261. string = clean_text(string)
  262. string = f"<|im_start|>user\nSpeak: {string}<|im_end|><|im_start|>assistant\n"
  263. new_tokens = tokenizer.encode(
  264. string,
  265. add_special_tokens=False,
  266. max_length=10**6,
  267. truncation=False,
  268. )
  269. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  270. # Codebooks
  271. zeros = (
  272. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  273. * CODEBOOK_PAD_TOKEN_ID
  274. )
  275. prompt = torch.cat((tokens, zeros), dim=0)
  276. if prompt_tokens is None:
  277. return prompt
  278. # Get prompt tokens
  279. if prompt_tokens.ndim == 3:
  280. assert (
  281. prompt_tokens.shape[0] == 1
  282. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  283. prompt_tokens = prompt_tokens[0]
  284. assert prompt_tokens.ndim == 2
  285. data = prompt_tokens + 1
  286. if prompt_tokens.shape[0] > num_codebooks:
  287. logger.warning(
  288. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  289. )
  290. data = data[:num_codebooks]
  291. # Add pad token for each codebook
  292. data = torch.cat(
  293. (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
  294. dim=1,
  295. )
  296. # Since 1.0, we use <|semantic|>
  297. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  298. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  299. main_token_ids = (
  300. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  301. )
  302. main_token_ids[0, -1] = end_token_id
  303. data = torch.cat((main_token_ids, data), dim=0)
  304. prompt = torch.cat((prompt, data), dim=1)
  305. return prompt
  306. def load_model(checkpoint_path, device, precision, compile=False):
  307. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  308. checkpoint_path, load_weights=True
  309. )
  310. model = model.to(device=device, dtype=precision)
  311. logger.info(f"Restored model from checkpoint")
  312. if isinstance(model, DualARTransformer):
  313. decode_one_token = decode_one_token_ar
  314. logger.info("Using DualARTransformer")
  315. else:
  316. decode_one_token = decode_one_token_naive
  317. logger.info("Using NaiveTransformer")
  318. if compile:
  319. logger.info("Compiling function...")
  320. decode_one_token = torch.compile(
  321. decode_one_token,
  322. fullgraph=True,
  323. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  324. mode="reduce-overhead" if torch.cuda.is_available() else None,
  325. )
  326. return model.eval(), decode_one_token
  327. @dataclass
  328. class GenerateResponse:
  329. action: Literal["sample", "next"]
  330. codes: Optional[torch.Tensor] = None
  331. text: Optional[str] = None
  332. def generate_long(
  333. *,
  334. model,
  335. device: str | torch.device,
  336. decode_one_token: callable,
  337. text: str,
  338. num_samples: int = 1,
  339. max_new_tokens: int = 0,
  340. top_p: int = 0.7,
  341. repetition_penalty: float = 1.5,
  342. temperature: float = 0.7,
  343. compile: bool = False,
  344. iterative_prompt: bool = True,
  345. max_length: int = 2048,
  346. chunk_length: int = 150,
  347. prompt_text: Optional[str | list[str]] = None,
  348. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  349. ):
  350. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  351. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  352. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  353. use_prompt = prompt_text is not None and prompt_tokens is not None
  354. if use_prompt and isinstance(prompt_text, str):
  355. prompt_text = [prompt_text]
  356. prompt_tokens = [prompt_tokens]
  357. assert use_prompt is False or len(prompt_text) == len(
  358. prompt_tokens
  359. ), "Prompt text and tokens must have the same length"
  360. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  361. tokenizer = model.tokenizer
  362. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  363. encoded = []
  364. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  365. encoded_prompts = []
  366. if use_prompt:
  367. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  368. encoded_prompts.append(
  369. encode_tokens(
  370. tokenizer,
  371. string=t,
  372. device=device,
  373. prompt_tokens=c,
  374. num_codebooks=model.config.num_codebooks,
  375. )
  376. )
  377. for idx, text in enumerate(texts):
  378. encoded.append(
  379. encode_tokens(
  380. tokenizer,
  381. string=text,
  382. device=device,
  383. num_codebooks=model.config.num_codebooks,
  384. )
  385. )
  386. logger.info(f"Encoded text: {text}")
  387. # Move temperature, top_p, repetition_penalty to device
  388. # This is important so that changing params doesn't trigger recompile
  389. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  390. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  391. repetition_penalty = torch.tensor(
  392. repetition_penalty, device=device, dtype=torch.float
  393. )
  394. for sample_idx in range(num_samples):
  395. if torch.cuda.is_available():
  396. torch.cuda.synchronize()
  397. global_encoded = []
  398. seg_idx = 0
  399. while seg_idx < len(encoded):
  400. logger.info(
  401. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  402. )
  403. seg = encoded[seg_idx]
  404. global_encoded.append(seg)
  405. lengths = reversed([seg.size(1) for seg in global_encoded])
  406. # Pick last 2000 tokens
  407. count = 0
  408. for i, length in enumerate(lengths):
  409. count += length
  410. if count + length > max_length - 1024 - sum(
  411. t.shape[1] for t in encoded_prompts
  412. ):
  413. break
  414. if i != 0 and i % 2 == 0:
  415. i -= 1
  416. # Rotate the list, always make sure first segment is included to avoid drift
  417. if i < len(global_encoded) - 2:
  418. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  419. else:
  420. partial_encoded = global_encoded
  421. if use_prompt:
  422. partial_encoded = encoded_prompts + partial_encoded
  423. cat_encoded = torch.cat(partial_encoded, dim=1)
  424. prompt_length = cat_encoded.size(1)
  425. t0 = time.perf_counter()
  426. y = generate(
  427. model=model,
  428. prompt=cat_encoded,
  429. max_new_tokens=max_new_tokens,
  430. im_end_id=im_end_id,
  431. decode_one_token=decode_one_token,
  432. temperature=temperature,
  433. top_p=top_p,
  434. repetition_penalty=repetition_penalty,
  435. )
  436. if sample_idx == 0 and seg_idx == 0 and compile:
  437. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  438. if torch.cuda.is_available():
  439. torch.cuda.synchronize()
  440. t = time.perf_counter() - t0
  441. tokens_generated = y.size(1) - prompt_length
  442. tokens_sec = tokens_generated / t
  443. logger.info(
  444. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  445. )
  446. logger.info(
  447. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  448. )
  449. if torch.cuda.is_available():
  450. logger.info(
  451. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  452. )
  453. # Put the generated tokens
  454. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  455. codes = y[1:, prompt_length:-1].clone()
  456. codes = codes - 1
  457. assert (codes >= 0).all(), f"Negative code found"
  458. decoded = y[:, prompt_length:-1].clone()
  459. # But for global encoding, we should keep the <im_end> token
  460. global_encoded.append(decoded)
  461. assert (codes >= 0).all(), f"Negative code found: {codes}"
  462. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  463. seg_idx += 1
  464. # This indicates the end of the current sample
  465. yield GenerateResponse(action="next")
  466. @dataclass
  467. class WrappedGenerateResponse:
  468. status: Literal["success", "error"]
  469. response: Optional[GenerateResponse | Exception] = None
  470. @dataclass
  471. class GenerateRequest:
  472. request: dict
  473. response_queue: queue.Queue
  474. def launch_thread_safe_queue(
  475. checkpoint_path,
  476. device,
  477. precision,
  478. compile: bool = False,
  479. ):
  480. input_queue = queue.Queue()
  481. init_event = threading.Event()
  482. def worker():
  483. model, decode_one_token = load_model(
  484. checkpoint_path, device, precision, compile=compile
  485. )
  486. with torch.device(device):
  487. model.setup_caches(
  488. max_batch_size=1,
  489. max_seq_len=model.config.max_seq_len,
  490. dtype=next(model.parameters()).dtype,
  491. )
  492. init_event.set()
  493. while True:
  494. item: GenerateRequest | None = input_queue.get()
  495. if item is None:
  496. break
  497. kwargs = item.request
  498. response_queue = item.response_queue
  499. try:
  500. for chunk in generate_long(
  501. model=model, decode_one_token=decode_one_token, **kwargs
  502. ):
  503. response_queue.put(
  504. WrappedGenerateResponse(status="success", response=chunk)
  505. )
  506. except Exception as e:
  507. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  508. threading.Thread(target=worker, daemon=True).start()
  509. init_event.wait()
  510. return input_queue
  511. @click.command()
  512. @click.option(
  513. "--text",
  514. type=str,
  515. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  516. )
  517. @click.option("--prompt-text", type=str, default=None, multiple=True)
  518. @click.option(
  519. "--prompt-tokens",
  520. type=click.Path(path_type=Path, exists=True),
  521. default=None,
  522. multiple=True,
  523. )
  524. @click.option("--num-samples", type=int, default=1)
  525. @click.option("--max-new-tokens", type=int, default=0)
  526. @click.option("--top-p", type=float, default=0.7)
  527. @click.option("--repetition-penalty", type=float, default=1.2)
  528. @click.option("--temperature", type=float, default=0.7)
  529. @click.option(
  530. "--checkpoint-path",
  531. type=click.Path(path_type=Path, exists=True),
  532. default="checkpoints/fish-speech-1.4",
  533. )
  534. @click.option("--device", type=str, default="cuda")
  535. @click.option("--compile/--no-compile", default=False)
  536. @click.option("--seed", type=int, default=42)
  537. @click.option("--half/--no-half", default=False)
  538. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  539. @click.option("--chunk-length", type=int, default=100)
  540. def main(
  541. text: str,
  542. prompt_text: Optional[list[str]],
  543. prompt_tokens: Optional[list[Path]],
  544. num_samples: int,
  545. max_new_tokens: int,
  546. top_p: int,
  547. repetition_penalty: float,
  548. temperature: float,
  549. checkpoint_path: Path,
  550. device: str,
  551. compile: bool,
  552. seed: int,
  553. half: bool,
  554. iterative_prompt: bool,
  555. chunk_length: int,
  556. ) -> None:
  557. precision = torch.half if half else torch.bfloat16
  558. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  559. raise ValueError(
  560. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  561. )
  562. logger.info("Loading model ...")
  563. t0 = time.time()
  564. model, decode_one_token = load_model(
  565. checkpoint_path, device, precision, compile=compile
  566. )
  567. with torch.device(device):
  568. model.setup_caches(
  569. max_batch_size=1,
  570. max_seq_len=model.config.max_seq_len,
  571. dtype=next(model.parameters()).dtype,
  572. )
  573. if torch.cuda.is_available():
  574. torch.cuda.synchronize()
  575. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  576. if prompt_tokens is not None:
  577. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  578. torch.manual_seed(seed)
  579. if torch.cuda.is_available():
  580. torch.cuda.manual_seed(seed)
  581. generator = generate_long(
  582. model=model,
  583. device=device,
  584. decode_one_token=decode_one_token,
  585. text=text,
  586. num_samples=num_samples,
  587. max_new_tokens=max_new_tokens,
  588. top_p=top_p,
  589. repetition_penalty=repetition_penalty,
  590. temperature=temperature,
  591. compile=compile,
  592. iterative_prompt=iterative_prompt,
  593. chunk_length=chunk_length,
  594. prompt_text=prompt_text,
  595. prompt_tokens=prompt_tokens,
  596. )
  597. idx = 0
  598. codes = []
  599. for response in generator:
  600. if response.action == "sample":
  601. codes.append(response.codes)
  602. logger.info(f"Sampled text: {response.text}")
  603. elif response.action == "next":
  604. if codes:
  605. np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
  606. logger.info(f"Saved codes to codes_{idx}.npy")
  607. logger.info(f"Next sample")
  608. codes = []
  609. idx += 1
  610. else:
  611. logger.error(f"Error: {response}")
  612. if __name__ == "__main__":
  613. main()