import os
import queue
import threading
import time
from pathlib import Path
from typing import Optional, Tuple, Union

import click
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer

from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
from fish_speech.text.clean import clean_text

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True

from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
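
# Note: the function above samples from a categorical distribution without
# calling torch.multinomial (which forces a CUDA synchronization). Dividing
# each probability by an independent Exp(1) draw and taking the argmax is the
# exponential-clock form of the Gumbel-max trick: argmax_i(p_i / E_i) selects
# index i with probability proportional to p_i.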


def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
):
    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=0, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=0, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
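
# The filters above apply in a fixed order: repetition penalty (written back
# onto the raw logits in place), then nucleus/top-p masking, then temperature
# scaling, then top-k. A minimal sketch of a single draw over a 1-D logit
# vector (the vocab size here is an assumption for illustration):
#
#   logits = torch.randn(4096)                      # [vocab_size]
#   probs = logits_to_probs(logits, temperature=0.7, top_p=0.7)
#   token = multinomial_sample_one_no_sync(probs)   # shape [1], dtype int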


def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)
    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]
    x = x.hidden_states

    # Cleanup the cache
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
        logits = model.forward_generate_fast(x, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        x = model.fast_embeddings(a)
        codebooks.append(a)

    return torch.stack(codebooks, dim=0)
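
# Dual-AR decoding runs in two stages: the slow transformer advances one
# position and yields both the semantic-token logits and a hidden state, then
# the fast transformer autoregressively emits one token per codebook
# conditioned on that hidden state, with its small KV cache zeroed at the
# start of every step. The naive variant below instead predicts all codebooks
# in a single forward pass from per-codebook logit heads.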


def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)
    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)


def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        if (
            cur_token[0, 0, -1] == eos_token_id
            or cur_token[0, 0, -1] == im_end_id
            or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
        ):
            break

    return previous_tokens[:, : i + 1]
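
# The repetition penalty is applied over a sliding 16-token window rather than
# the full history, so an early token stops being penalized once it falls out
# of the window. Decoding stops early when the semantic row emits the
# eos/im_end token or any codebook row emits CODEBOOK_EOS_TOKEN_ID; the
# returned tensor is truncated to the i + 1 tokens actually produced.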


@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
        )

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )
    next_token = prefill_decode(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        eos_token_id=eos_token_id,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
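
# generate() splits work into a prefill pass and a decode loop: the whole
# prompt is pushed through an uncompiled decode function once (avoiding a
# torch.compile recompilation for every new prompt length), after which
# decode_n_tokens steps one position at a time using the possibly compiled
# decode_one_token passed in by the caller.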


def encode_tokens(
    tokenizer,
    string,
    bos=True,
    device="cuda",
    prompt_tokens=None,
    speaker=None,
    num_codebooks=4,
):
    string = clean_text(string)

    if speaker is None:
        speaker = "assistant"

    string = (
        f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
    )
    if bos:
        string = f"<|begin_of_sequence|>{string}"

    new_tokens = tokenizer.encode(
        string,
        add_special_tokens=False,
        max_length=10**6,
        truncation=False,
    )
    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

    # Codebooks
    zeros = (
        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
        * CODEBOOK_PAD_TOKEN_ID
    )
    prompt = torch.cat((tokens, zeros), dim=0)

    if prompt_tokens is None:
        return prompt

    # Get prompt tokens
    if prompt_tokens.ndim == 3:
        assert (
            prompt_tokens.shape[0] == 1
        ), "3-dim prompt tokens should have shape (1, num_codebooks, seq_len)"
        prompt_tokens = prompt_tokens[0]

    assert prompt_tokens.ndim == 2
    data = prompt_tokens + 2

    if prompt_tokens.shape[0] > num_codebooks:
        logger.warning(
            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, keeping only the first {num_codebooks} codebooks"
        )
        data = data[:num_codebooks]

    # Add eos token for each codebook
    data = torch.cat(
        (
            data,
            torch.ones((data.size(0), 1), dtype=torch.int, device=device)
            * CODEBOOK_EOS_TOKEN_ID,
        ),
        dim=1,
    )

    # Since 1.0, we use <|semantic|>
    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    main_token_ids = (
        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
    )
    main_token_ids[0, -1] = end_token_id

    data = torch.cat((main_token_ids, data), dim=0)
    prompt = torch.cat((prompt, data), dim=1)

    return prompt
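
# The returned prompt is a (1 + num_codebooks, seq_len) int tensor. Row 0
# holds the chat-formatted text tokens and, for the reference-audio span,
# <|semantic|> repeated with a final <|im_end|>; rows 1..num_codebooks hold
# codebook ids shifted by +2, padded with CODEBOOK_PAD_TOKEN_ID under the text
# portion and terminated with CODEBOOK_EOS_TOKEN_ID. A hypothetical layout:
#
#   row 0:  <|begin_of_sequence|> ... <|im_sep|> | <|semantic|> ... <|im_end|>
#   row 1:  PAD PAD ............... PAD          | c1+2  c1+2 ...  EOS
#   row N:  PAD PAD ............... PAD          | cN+2  cN+2 ...  EOS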


def load_model(
    config_name, checkpoint_path, device, precision, max_length, compile=False
):
    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
        cfg = compose(
            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
        )

    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)

    if "int8" in str(checkpoint_path):
        logger.info("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        logger.info("Using int4 quantization!")
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if any(k.startswith("model.") for k in checkpoint):
        checkpoint = {
            k.replace("model.", ""): v
            for k, v in checkpoint.items()
            if k.startswith("model.")
        }

    model.load_state_dict(checkpoint, assign=True)
    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
        decode_one_token = decode_one_token_ar
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = decode_one_token_naive
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )

    return model.eval(), decode_one_token


def split_text(text, min_length):
    text = clean_text(text)
    segments = []
    curr = ""
    for char in text:
        curr += char
        if char not in [".", ",", "!", "?"]:
            continue

        if len(curr) >= min_length:
            segments.append(curr)
            curr = ""

    if curr:
        segments.append(curr)

    return segments
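
# split_text greedily accumulates characters and cuts after a sentence
# separator once at least min_length characters have accumulated, so segments
# always end on punctuation. For example (illustrative only, assuming
# clean_text leaves the string unchanged):
#
#   split_text("Hello, world! How are you?", 10)
#   # -> ["Hello, world!", " How are you?"]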


def generate_long(
    *,
    model,
    tokenizer: callable,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_k: Optional[int] = None,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 30,
    speaker: Optional[str] = None,
    prompt_text: Optional[str] = None,
    prompt_tokens: Optional[torch.Tensor] = None,
    is_streaming: bool = False,
):
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    use_prompt = prompt_text is not None and prompt_tokens is not None
    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]

    if use_prompt:
        encoded.append(
            encode_tokens(
                tokenizer,
                prompt_text,
                prompt_tokens=prompt_tokens,
                bos=True,
                device=device,
                speaker=speaker,
                num_codebooks=model.config.num_codebooks,
            )
        )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                bos=idx == 0 and not use_prompt,
                device=device,
                speaker=speaker,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    for sample_idx in range(num_samples):
        torch.cuda.synchronize()
        global_encoded = []
        all_codes = []
        seg_idx = 0

        if use_prompt:
            seg_idx = 1
            global_encoded.append(encoded[0])

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Pick last 2000 tokens
            count = 0
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024:
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list, always make sure first segment is included to avoid drift
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                eos_token_id=tokenizer.eos_token_id,
                im_end_id=im_end_id,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            torch.cuda.synchronize()
            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )
            logger.info(
                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
            )

            # Put the generated tokens
            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
            codes = y[1:, prompt_length:-2].clone()

            codes = codes - 2
            assert (codes >= 0).all(), "Negative code found"

            decoded = y[:, prompt_length:-1].clone()
            if decoded[0, -1] != im_end_id:  # <im_end>
                val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
                decoded = torch.cat(
                    (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
                )

            # But for global encoding, we should keep the <im_end> token
            global_encoded.append(decoded)

            if is_streaming:
                assert (codes >= 0).all(), f"Negative code found: {codes}"
                yield codes
            else:
                all_codes.append(codes)

            seg_idx += 1

        if is_streaming:
            # This indicates the end of the current sample
            yield None
        else:
            all_codes = torch.cat(all_codes, dim=1)
            assert (all_codes >= 0).all(), f"Negative code found: {all_codes}"
            yield all_codes
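
# In streaming mode generate_long yields one codes tensor per text segment and
# then a single None to mark the end of each sample; in batch mode it yields
# one concatenated (num_codebooks, total_len) tensor per sample. Callers that
# multiplex multiple samples can rely on that None as the sample delimiter.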


def launch_thread_safe_queue(
    config_name,
    checkpoint_path,
    device,
    precision,
    max_length,
    compile=False,
):
    input_queue = queue.Queue()

    def worker():
        model, decode_one_token = load_model(
            config_name, checkpoint_path, device, precision, max_length, compile=compile
        )

        while True:
            item = input_queue.get()
            if item is None:
                break

            kwargs = item["request"]
            event = item["event"]

            try:
                item["success"] = True
                item["response"] = list(
                    generate_long(
                        model=model, decode_one_token=decode_one_token, **kwargs
                    )
                )
            except Exception as e:
                item["success"] = False
                item["response"] = e

            event.set()

    threading.Thread(target=worker, daemon=True).start()

    return input_queue
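
# A minimal usage sketch (names and values here are hypothetical, not part of
# this file): each queued item carries generate_long kwargs plus a
# threading.Event that the worker sets once the response is ready. The worker
# injects model and decode_one_token itself, so the request must supply the
# remaining keyword arguments (tokenizer, device, text, ...).
#
#   q = launch_thread_safe_queue("dual_ar_8_codebook_small", Path("ckpt.pth"),
#                                "cuda", torch.bfloat16, 2048)
#   item = {"request": {"tokenizer": tokenizer, "device": "cuda",
#                       "text": "Hello"}, "event": threading.Event()}
#   q.put(item)
#   item["event"].wait()  # item["response"] holds codes or an exception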


@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None)
@click.option(
    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-k", type=int, default=None)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.5)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="results/text2semantic_400m_finetune/step_000002000.pth",
)
@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--speaker", type=str, default=None)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--max-length", type=int, default=2048)
@click.option("--chunk-length", type=int, default=30)
def main(
    text: str,
    prompt_text: Optional[str],
    prompt_tokens: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_k: Optional[int],
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    config_name: str,
    tokenizer: str,
    compile: bool,
    seed: int,
    speaker: Optional[str],
    half: bool,
    iterative_prompt: bool,
    max_length: int,
    chunk_length: int,
) -> None:
    device = "cuda"
    precision = torch.half if half else torch.bfloat16

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        config_name, checkpoint_path, device, precision, max_length, compile=compile
    )
    torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    prompt_tokens = (
        torch.from_numpy(np.load(prompt_tokens)).to(device)
        if prompt_tokens is not None
        else None
    )

    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        tokenizer=tokenizer,
        compile=compile,
        speaker=speaker,
        iterative_prompt=iterative_prompt,
        max_length=max_length,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    for idx, codes in enumerate(generator):
        np.save(f"codes_{idx}.npy", codes.cpu().numpy())
        logger.info(f"Saved codes to codes_{idx}.npy")


if __name__ == "__main__":
    main()