inference.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import numpy as np
  11. import torch
  12. import torch._dynamo.config
  13. import torch._inductor.config
  14. from loguru import logger
  15. from tqdm import tqdm
  16. from transformers import AutoTokenizer
  17. from fish_speech.content_sequence import (
  18. ContentSequence,
  19. TextPart,
  20. VQPart,
  21. )
  22. from fish_speech.models.text2semantic.llama import BaseModelArgs
  23. from fish_speech.text import clean_text, split_text
  24. from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
  25. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  26. torch._inductor.config.coordinate_descent_tuning = True
  27. torch._inductor.config.triton.unique_kernel_names = True
  28. if hasattr(torch._inductor.config, "fx_graph_cache"):
  29. # Experimental feature to reduce compilation times, will be on by default in future
  30. torch._inductor.config.fx_graph_cache = True
  31. from torch.nn.attention import SDPBackend, sdpa_kernel
  32. from fish_speech.models.text2semantic.llama import (
  33. BaseTransformer,
  34. DualARTransformer,
  35. NaiveTransformer,
  36. )
  37. def multinomial_sample_one_no_sync(
  38. probs_sort,
  39. ): # Does multinomial sampling without a cuda synchronization
  40. q = torch.empty_like(probs_sort).exponential_(1)
  41. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  42. def logits_to_probs(
  43. logits,
  44. previous_tokens: Optional[torch.Tensor] = None,
  45. temperature: torch.Tensor = 1.0,
  46. top_p: torch.Tensor = 1.0,
  47. repetition_penalty: torch.Tensor = 1.0,
  48. ) -> torch.Tensor:
  49. # Apply repetition penalty
  50. if previous_tokens is not None:
  51. previous_tokens = previous_tokens.long()
  52. score = torch.gather(logits, dim=0, index=previous_tokens)
  53. score = torch.where(
  54. score < 0, score * repetition_penalty, score / repetition_penalty
  55. )
  56. logits.scatter_(dim=0, index=previous_tokens, src=score)
  57. # Apply top-p sampling
  58. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  59. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  60. sorted_indices_to_remove = cum_probs > top_p
  61. sorted_indices_to_remove[0] = False # keep at least one option
  62. indices_to_remove = sorted_indices_to_remove.scatter(
  63. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  64. )
  65. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  66. logits = logits / max(temperature, 1e-5)
  67. probs = torch.nn.functional.softmax(logits, dim=-1)
  68. return probs
  69. def sample(
  70. logits,
  71. previous_tokens: Optional[torch.Tensor] = None,
  72. **sampling_kwargs,
  73. ) -> Tuple[torch.Tensor, torch.Tensor]:
  74. probs = logits_to_probs(
  75. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  76. )
  77. idx_next = multinomial_sample_one_no_sync(probs)
  78. return idx_next, probs
  79. def decode_one_token_ar(
  80. model: DualARTransformer,
  81. x: torch.Tensor,
  82. input_pos: torch.Tensor,
  83. semantic_ids: list,
  84. previous_tokens: torch.Tensor = None,
  85. **sampling_kwargs,
  86. ) -> torch.Tensor:
  87. x = model.forward_generate(x, input_pos)
  88. sampling_kwargs_main = sampling_kwargs.copy()
  89. # sampling_kwargs_main["temperature"] = 0.1
  90. # sampling_kwargs_main["top_p"] = 0.1
  91. # sampling_kwargs_main["repetition_penalty"] = 1.0
  92. codebooks = [
  93. sample(
  94. x.logits,
  95. previous_tokens=(
  96. previous_tokens[0] if previous_tokens is not None else None
  97. ), # Disable repetition penalty for the token codebook
  98. **sampling_kwargs_main,
  99. )[0]
  100. ]
  101. hidden_states = x.hidden_states
  102. # Cleanup the cache
  103. for layer in model.fast_layers:
  104. layer.attention.kv_cache.k_cache.fill_(0)
  105. layer.attention.kv_cache.v_cache.fill_(0)
  106. input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
  107. model.forward_generate_fast(hidden_states, input_pos)
  108. a = codebooks[0] - model.tokenizer.semantic_begin_id
  109. a[a < 0] = 0
  110. hidden_states = model.fast_embeddings(a)
  111. codebooks.append(a)
  112. for codebook_idx in range(1, model.config.num_codebooks):
  113. input_pos = torch.tensor(
  114. [codebook_idx], device=hidden_states.device, dtype=torch.long
  115. )
  116. logits = model.forward_generate_fast(hidden_states, input_pos)
  117. chunked_logits = logits[..., :1024]
  118. a = sample(
  119. chunked_logits,
  120. previous_tokens=(
  121. previous_tokens[codebook_idx + 1]
  122. if previous_tokens is not None
  123. else None
  124. ),
  125. **sampling_kwargs,
  126. )[0]
  127. hidden_states = model.fast_embeddings(a)
  128. codebooks.append(a)
  129. codebooks = torch.stack(codebooks, dim=0)
  130. # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
  131. # codebooks[1:, :] = torch.masked_fill(
  132. # codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
  133. # )
  134. # print(codebooks)
  135. return codebooks
  136. def decode_n_tokens(
  137. model: NaiveTransformer,
  138. cur_token: torch.Tensor,
  139. input_pos: torch.Tensor,
  140. num_new_tokens: int,
  141. semantic_ids: list,
  142. decode_one_token=decode_one_token_ar,
  143. **sampling_kwargs,
  144. ):
  145. previous_tokens = torch.zeros(
  146. (model.config.num_codebooks + 1, model.config.max_seq_len),
  147. dtype=torch.int,
  148. device=cur_token.device,
  149. )
  150. for i in tqdm(range(num_new_tokens)):
  151. # We need to get windowed repeat penalty
  152. win_size = 16
  153. if i < win_size:
  154. window = previous_tokens[:, :win_size]
  155. else:
  156. window = previous_tokens[:, i - win_size : i]
  157. with (
  158. torch.backends.cuda.sdp_kernel(
  159. enable_flash=False, enable_mem_efficient=False, enable_math=True
  160. )
  161. if torch.cuda.is_available()
  162. else nullcontext()
  163. ): # Actually better for Inductor to codegen attention here
  164. next_token = decode_one_token(
  165. model=model,
  166. x=cur_token,
  167. input_pos=input_pos,
  168. previous_tokens=window,
  169. semantic_ids=semantic_ids,
  170. **sampling_kwargs,
  171. )
  172. input_pos += 1
  173. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  174. previous_tokens[:, i : i + 1] = next_token.view(
  175. model.config.num_codebooks + 1, -1
  176. )
  177. if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
  178. break
  179. return previous_tokens[:, : i + 1]
  180. @torch.no_grad()
  181. @torch.inference_mode()
  182. def generate(
  183. *,
  184. model: NaiveTransformer,
  185. prompt: torch.Tensor,
  186. max_new_tokens: int,
  187. decode_one_token=decode_one_token_ar,
  188. **sampling_kwargs,
  189. ) -> torch.Tensor:
  190. """
  191. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  192. """
  193. # create an empty tensor of the expected final shape and fill in the current tokens
  194. T = prompt.size(1)
  195. # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
  196. semantic_ids = [
  197. model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
  198. ]
  199. if max_new_tokens:
  200. if T + max_new_tokens > model.config.max_seq_len:
  201. max_new_tokens = model.config.max_seq_len - T
  202. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  203. T_new = T + max_new_tokens
  204. else:
  205. T_new = model.config.max_seq_len
  206. max_new_tokens = T_new - T
  207. device, dtype = prompt.device, prompt.dtype
  208. codebook_dim = 1 + model.config.num_codebooks
  209. # create an empty tensor of the expected final shape and fill in the current tokens
  210. empty = torch.empty(
  211. (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
  212. )
  213. empty[:, :T] = prompt
  214. seq = empty
  215. input_pos = torch.arange(0, T, device=device)
  216. # Use non-accelerated version for now, to avoid compilation overhead
  217. prefill_decode = decode_one_token_ar
  218. next_token = prefill_decode(
  219. model,
  220. prompt.view(1, codebook_dim, -1),
  221. input_pos,
  222. semantic_ids=semantic_ids,
  223. **sampling_kwargs,
  224. )
  225. seq[:, T : T + 1] = next_token
  226. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  227. x = decode_n_tokens(
  228. model,
  229. next_token.view(1, codebook_dim, -1),
  230. input_pos,
  231. max_new_tokens - 1,
  232. decode_one_token=decode_one_token,
  233. semantic_ids=semantic_ids,
  234. **sampling_kwargs,
  235. )
  236. # x = torch.cat(generated_tokens, dim=1)
  237. seq = seq[:, : T + 1 + x.size(1)]
  238. seq[:, T + 1 :] = x
  239. return seq
  240. def load_model(checkpoint_path, device, precision, compile=False):
  241. model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
  242. model = model.to(device=device, dtype=precision)
  243. logger.info(f"Restored model from checkpoint")
  244. if isinstance(model, DualARTransformer):
  245. decode_one_token = decode_one_token_ar
  246. logger.info("Using DualARTransformer")
  247. else:
  248. raise ValueError("Model is not a DualARTransformer")
  249. if compile:
  250. logger.info("Compiling function...")
  251. decode_one_token = torch.compile(
  252. decode_one_token,
  253. fullgraph=True,
  254. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  255. mode="reduce-overhead" if torch.cuda.is_available() else None,
  256. )
  257. return model.eval(), decode_one_token
  258. @dataclass
  259. class GenerateResponse:
  260. action: Literal["sample", "next"]
  261. codes: Optional[torch.Tensor] = None
  262. text: Optional[str] = None
  263. def generate_long(
  264. *,
  265. model,
  266. device: str | torch.device,
  267. decode_one_token: callable,
  268. text: str,
  269. num_samples: int = 1,
  270. max_new_tokens: int = 0,
  271. top_p: int = 0.8,
  272. repetition_penalty: float = 1.1,
  273. temperature: float = 0.8,
  274. compile: bool = False,
  275. iterative_prompt: bool = True,
  276. chunk_length: int = 150,
  277. prompt_text: Optional[str | list[str]] = None,
  278. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  279. ):
  280. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  281. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  282. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  283. use_prompt = prompt_text is not None and prompt_tokens is not None
  284. if use_prompt and isinstance(prompt_text, str):
  285. prompt_text = [prompt_text]
  286. prompt_tokens = [prompt_tokens]
  287. assert use_prompt is False or len(prompt_text) == len(
  288. prompt_tokens
  289. ), "Prompt text and tokens must have the same length"
  290. prompt_tokens = [i.cpu() for i in prompt_tokens]
  291. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  292. tokenizer = model.tokenizer
  293. base_content_sequence = ContentSequence(modality="interleave")
  294. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  295. max_length = model.config.max_seq_len
  296. if use_prompt:
  297. for t, c in zip(prompt_text, prompt_tokens):
  298. base_content_sequence.append(
  299. [
  300. TextPart(text=t),
  301. VQPart(codes=c),
  302. ],
  303. add_end=True,
  304. )
  305. encoded_prompts = base_content_sequence.encode_for_inference(
  306. tokenizer, num_codebooks=model.config.num_codebooks
  307. )
  308. if encoded_prompts.size(1) > max_length - 2048:
  309. raise ValueError(
  310. f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}"
  311. )
  312. encoded = []
  313. for text in texts:
  314. content_sequence = ContentSequence(modality=None)
  315. content_sequence.append(TextPart(text=text))
  316. encoded.append(
  317. content_sequence.encode_for_inference(
  318. tokenizer, num_codebooks=model.config.num_codebooks
  319. )
  320. )
  321. logger.info(f"Encoded text: {text}")
  322. # Move temperature, top_p, repetition_penalty to device
  323. # This is important so that changing params doesn't trigger recompile
  324. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  325. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  326. repetition_penalty = torch.tensor(
  327. repetition_penalty, device=device, dtype=torch.float
  328. )
  329. for sample_idx in range(num_samples):
  330. if torch.cuda.is_available():
  331. torch.cuda.synchronize()
  332. global_encoded = []
  333. seg_idx = 0
  334. while seg_idx < len(encoded):
  335. logger.info(
  336. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  337. )
  338. seg = encoded[seg_idx]
  339. global_encoded.append(seg)
  340. # Do not use previous segments to generate current segment for now
  341. # lengths = reversed([seg.size(1) for seg in global_encoded])
  342. # # Pick last 2000 tokens
  343. # count = 0
  344. # for i, length in enumerate(lengths):
  345. # count += length
  346. # if count + length > max_length - 2048 - encoded_prompts.size(1):
  347. # break
  348. # if i != 0 and i % 2 == 0:
  349. # i -= 1
  350. # # Rotate the list, always make sure first segment is included to avoid drift
  351. # if i < len(global_encoded) - 2:
  352. # partial_encoded = global_encoded[:2] + global_encoded[-i:]
  353. # else:
  354. # partial_encoded = global_encoded
  355. # cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1)
  356. if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
  357. cat_encoded = torch.cat(
  358. [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
  359. )
  360. else:
  361. cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
  362. cat_encoded = cat_encoded.to(device=device)
  363. prompt_length = cat_encoded.size(1)
  364. t0 = time.perf_counter()
  365. y = generate(
  366. model=model,
  367. prompt=cat_encoded,
  368. max_new_tokens=max_new_tokens,
  369. decode_one_token=decode_one_token,
  370. temperature=temperature,
  371. top_p=top_p,
  372. repetition_penalty=repetition_penalty,
  373. )
  374. if sample_idx == 0 and seg_idx == 0 and compile:
  375. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  376. if torch.cuda.is_available():
  377. torch.cuda.synchronize()
  378. t = time.perf_counter() - t0
  379. tokens_generated = y.size(1) - prompt_length
  380. tokens_sec = tokens_generated / t
  381. logger.info(
  382. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  383. )
  384. logger.info(
  385. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  386. )
  387. if torch.cuda.is_available():
  388. logger.info(
  389. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  390. )
  391. # Put the generated tokens
  392. # since there is <im_end>, we remove last token
  393. codes = y[1:, prompt_length:-1].clone()
  394. assert (codes >= 0).all(), f"Negative code found"
  395. decoded = y[:, prompt_length:].clone()
  396. # But for global encoding, we should keep the <im_end> token
  397. global_encoded.append(decoded.cpu())
  398. assert (codes >= 0).all(), f"Negative code found: {codes}"
  399. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  400. seg_idx += 1
  401. # This indicates the end of the current sample
  402. yield GenerateResponse(action="next")
  403. @dataclass
  404. class WrappedGenerateResponse:
  405. status: Literal["success", "error"]
  406. response: Optional[GenerateResponse | Exception] = None
  407. @dataclass
  408. class GenerateRequest:
  409. request: dict
  410. response_queue: queue.Queue
  411. def launch_thread_safe_queue(
  412. checkpoint_path,
  413. device,
  414. precision,
  415. compile: bool = False,
  416. ):
  417. input_queue = queue.Queue()
  418. init_event = threading.Event()
  419. def worker():
  420. model, decode_one_token = load_model(
  421. checkpoint_path, device, precision, compile=compile
  422. )
  423. with torch.device(device):
  424. model.setup_caches(
  425. max_batch_size=1,
  426. max_seq_len=model.config.max_seq_len,
  427. dtype=next(model.parameters()).dtype,
  428. )
  429. init_event.set()
  430. while True:
  431. item: GenerateRequest | None = input_queue.get()
  432. if item is None:
  433. break
  434. kwargs = item.request
  435. response_queue = item.response_queue
  436. try:
  437. for chunk in generate_long(
  438. model=model, decode_one_token=decode_one_token, **kwargs
  439. ):
  440. response_queue.put(
  441. WrappedGenerateResponse(status="success", response=chunk)
  442. )
  443. except Exception as e:
  444. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  445. threading.Thread(target=worker, daemon=True).start()
  446. init_event.wait()
  447. return input_queue
  448. def launch_thread_safe_queue_agent(
  449. checkpoint_path,
  450. device,
  451. precision,
  452. compile: bool = False,
  453. ):
  454. input_queue = queue.Queue()
  455. init_event = threading.Event()
  456. tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
  457. config = BaseModelArgs.from_pretrained(checkpoint_path)
  458. def worker():
  459. model, decode_one_token = load_model(
  460. checkpoint_path, device, precision, compile=compile, is_agent=True
  461. )
  462. with torch.device(device):
  463. model.setup_caches(
  464. max_batch_size=1,
  465. max_seq_len=model.config.max_seq_len,
  466. dtype=next(model.parameters()).dtype,
  467. )
  468. init_event.set()
  469. while True:
  470. item: GenerateRequest | None = input_queue.get()
  471. if item is None:
  472. break
  473. kwargs = item.request
  474. response_queue = item.response_queue
  475. try:
  476. for token in generate_agent(
  477. model=model,
  478. decode_one_token=decode_one_token,
  479. **kwargs,
  480. ):
  481. response_queue.put(token)
  482. response_queue.put("stop")
  483. except Exception as e:
  484. import traceback
  485. logger.exception(f"Error in worker: {traceback.format_exc()}")
  486. response_queue.put("error")
  487. threading.Thread(target=worker, daemon=True).start()
  488. init_event.wait()
  489. return input_queue, tokenizer, config
  490. @click.command()
  491. @click.option(
  492. "--text",
  493. type=str,
  494. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  495. )
  496. @click.option("--prompt-text", type=str, default=None, multiple=True)
  497. @click.option(
  498. "--prompt-tokens",
  499. type=click.Path(path_type=Path, exists=True),
  500. default=None,
  501. multiple=True,
  502. )
  503. @click.option("--num-samples", type=int, default=1)
  504. @click.option("--max-new-tokens", type=int, default=0)
  505. @click.option("--top-p", type=float, default=0.8)
  506. @click.option("--repetition-penalty", type=float, default=1.1)
  507. @click.option("--temperature", type=float, default=0.8)
  508. @click.option(
  509. "--checkpoint-path",
  510. type=click.Path(path_type=Path, exists=True),
  511. default="checkpoints/openaudio-s1-mini",
  512. )
  513. @click.option("--device", type=str, default="cuda")
  514. @click.option("--compile/--no-compile", default=False)
  515. @click.option("--seed", type=int, default=42)
  516. @click.option("--half/--no-half", default=False)
  517. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  518. @click.option("--chunk-length", type=int, default=300)
  519. @click.option("--output-dir", type=Path, default="temp")
  520. def main(
  521. text: str,
  522. prompt_text: Optional[list[str]],
  523. prompt_tokens: Optional[list[Path]],
  524. num_samples: int,
  525. max_new_tokens: int,
  526. top_p: int,
  527. repetition_penalty: float,
  528. temperature: float,
  529. checkpoint_path: Path,
  530. device: str,
  531. compile: bool,
  532. seed: int,
  533. half: bool,
  534. iterative_prompt: bool,
  535. chunk_length: int,
  536. output_dir: Path,
  537. ) -> None:
  538. os.makedirs(output_dir, exist_ok=True)
  539. precision = torch.half if half else torch.bfloat16
  540. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  541. raise ValueError(
  542. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  543. )
  544. logger.info("Loading model ...")
  545. t0 = time.time()
  546. model, decode_one_token = load_model(
  547. checkpoint_path, device, precision, compile=compile
  548. )
  549. with torch.device(device):
  550. model.setup_caches(
  551. max_batch_size=1,
  552. max_seq_len=model.config.max_seq_len,
  553. dtype=next(model.parameters()).dtype,
  554. )
  555. if torch.cuda.is_available():
  556. torch.cuda.synchronize()
  557. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  558. if prompt_tokens is not None:
  559. prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
  560. torch.manual_seed(seed)
  561. if torch.cuda.is_available():
  562. torch.cuda.manual_seed(seed)
  563. generator = generate_long(
  564. model=model,
  565. device=device,
  566. decode_one_token=decode_one_token,
  567. text=text,
  568. num_samples=num_samples,
  569. max_new_tokens=max_new_tokens,
  570. top_p=top_p,
  571. repetition_penalty=repetition_penalty,
  572. temperature=temperature,
  573. compile=compile,
  574. iterative_prompt=iterative_prompt,
  575. chunk_length=chunk_length,
  576. prompt_text=prompt_text,
  577. prompt_tokens=prompt_tokens,
  578. )
  579. idx = 0
  580. codes = []
  581. for response in generator:
  582. if response.action == "sample":
  583. codes.append(response.codes)
  584. logger.info(f"Sampled text: {response.text}")
  585. elif response.action == "next":
  586. if codes:
  587. codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
  588. np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
  589. logger.info(f"Saved codes to {codes_npy_path}")
  590. logger.info(f"Next sample")
  591. codes = []
  592. idx += 1
  593. else:
  594. logger.error(f"Error: {response}")
  595. if __name__ == "__main__":
  596. main()