generate.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from pathlib import Path
  6. from typing import Optional, Tuple, Union
  7. import click
  8. import hydra
  9. import numpy as np
  10. import torch
  11. import torch._dynamo.config
  12. import torch._inductor.config
  13. from hydra import compose, initialize
  14. from hydra.utils import instantiate
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from transformers import AutoTokenizer
  18. from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
  19. from fish_speech.text.clean import clean_text
  20. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  21. torch._inductor.config.coordinate_descent_tuning = True
  22. torch._inductor.config.triton.unique_kernel_names = True
  23. if hasattr(torch._inductor.config, "fx_graph_cache"):
  24. # Experimental feature to reduce compilation times, will be on by default in future
  25. torch._inductor.config.fx_graph_cache = True
  26. from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
  27. def multinomial_sample_one_no_sync(
  28. probs_sort,
  29. ): # Does multinomial sampling without a cuda synchronization
  30. q = torch.empty_like(probs_sort).exponential_(1)
  31. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  32. def logits_to_probs(
  33. logits,
  34. previous_tokens: Optional[torch.Tensor] = None,
  35. temperature: float = 1.0,
  36. top_k: Optional[int] = None,
  37. top_p: Optional[int] = None,
  38. repetition_penalty: float = 1.0,
  39. ):
  40. if previous_tokens is not None and repetition_penalty != 1.0:
  41. previous_tokens = previous_tokens.long()
  42. score = torch.gather(logits, dim=0, index=previous_tokens)
  43. score = torch.where(
  44. score < 0, score * repetition_penalty, score / repetition_penalty
  45. )
  46. logits.scatter_(dim=0, index=previous_tokens, src=score)
  47. if top_p is not None and top_p < 1.0:
  48. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  49. cum_probs = torch.cumsum(
  50. torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
  51. )
  52. sorted_indices_to_remove = cum_probs > top_p
  53. sorted_indices_to_remove[0] = False # keep at least one option
  54. indices_to_remove = sorted_indices_to_remove.scatter(
  55. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  56. )
  57. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  58. logits = logits / max(temperature, 1e-5)
  59. if top_k is not None:
  60. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  61. pivot = v.select(-1, -1).unsqueeze(-1)
  62. logits = torch.where(logits < pivot, -float("Inf"), logits)
  63. probs = torch.nn.functional.softmax(logits, dim=-1)
  64. return probs
  65. def sample(
  66. logits,
  67. previous_tokens: Optional[torch.Tensor] = None,
  68. **sampling_kwargs,
  69. ) -> Tuple[torch.Tensor, torch.Tensor]:
  70. probs = logits_to_probs(
  71. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  72. )
  73. idx_next = multinomial_sample_one_no_sync(probs)
  74. return idx_next, probs
  75. def decode_one_token_ar(
  76. model: DualARTransformer,
  77. x: torch.Tensor,
  78. input_pos: torch.Tensor,
  79. previous_tokens: torch.Tensor = None,
  80. **sampling_kwargs,
  81. ) -> torch.Tensor:
  82. x = model.forward_generate(x, input_pos)
  83. codebooks = [
  84. sample(
  85. x.logits,
  86. previous_tokens=None, # Disable repetition penalty for the token codebook
  87. **sampling_kwargs,
  88. )[0]
  89. ]
  90. x = x.hidden_states
  91. # Cleanup the cache
  92. for layer in model.fast_layers:
  93. layer.attention.kv_cache.k_cache.fill_(0)
  94. layer.attention.kv_cache.v_cache.fill_(0)
  95. for codebook_idx in range(model.config.num_codebooks):
  96. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  97. logits = model.forward_generate_fast(x, input_pos)
  98. a = sample(
  99. logits,
  100. previous_tokens=(
  101. previous_tokens[codebook_idx + 1]
  102. if previous_tokens is not None
  103. else None
  104. ),
  105. **sampling_kwargs,
  106. )[0]
  107. x = model.fast_embeddings(a)
  108. codebooks.append(a)
  109. return torch.stack(codebooks, dim=0)
  110. def decode_one_token_naive(
  111. model: NaiveTransformer,
  112. x: torch.Tensor,
  113. input_pos: torch.Tensor,
  114. previous_tokens: torch.Tensor = None,
  115. **sampling_kwargs,
  116. ) -> torch.Tensor:
  117. x = model.forward_generate(x, input_pos)
  118. codebooks = [
  119. sample(
  120. x.token_logits,
  121. previous_tokens=None, # Disable repetition penalty for the token codebook
  122. **sampling_kwargs,
  123. )[0]
  124. ]
  125. for i in range(model.config.num_codebooks):
  126. codebooks.append(
  127. sample(
  128. x.codebook_logits[:, :, i],
  129. previous_tokens=(
  130. previous_tokens[i + 1] if previous_tokens is not None else None
  131. ),
  132. **sampling_kwargs,
  133. )[0]
  134. )
  135. return torch.stack(codebooks, dim=0)
  136. def decode_n_tokens(
  137. model: NaiveTransformer,
  138. cur_token: torch.Tensor,
  139. input_pos: torch.Tensor,
  140. num_new_tokens: int,
  141. eos_token_id: int = 2,
  142. im_end_id: int = 4,
  143. decode_one_token=decode_one_token_naive,
  144. **sampling_kwargs,
  145. ):
  146. previous_tokens = torch.zeros(
  147. (model.config.num_codebooks + 1, model.config.max_seq_len),
  148. dtype=torch.int,
  149. device=cur_token.device,
  150. )
  151. for i in tqdm(range(num_new_tokens)):
  152. # We need to get windowed repeat penalty
  153. win_size = 16
  154. if i < win_size:
  155. window = previous_tokens[:, :win_size]
  156. else:
  157. window = previous_tokens[:, i - win_size : i]
  158. with torch.backends.cuda.sdp_kernel(
  159. enable_flash=False, enable_mem_efficient=False, enable_math=True
  160. ): # Actually better for Inductor to codegen attention here
  161. next_token = decode_one_token(
  162. model=model,
  163. x=cur_token,
  164. input_pos=input_pos,
  165. previous_tokens=window,
  166. **sampling_kwargs,
  167. )
  168. input_pos += 1
  169. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  170. previous_tokens[:, i : i + 1] = next_token.view(
  171. model.config.num_codebooks + 1, -1
  172. )
  173. if (
  174. cur_token[0, 0, -1] == eos_token_id
  175. or cur_token[0, 0, -1] == im_end_id
  176. or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
  177. ):
  178. break
  179. return previous_tokens[:, : i + 1]
  180. @torch.no_grad()
  181. @torch.inference_mode()
  182. def generate(
  183. *,
  184. model: NaiveTransformer,
  185. prompt: torch.Tensor,
  186. max_new_tokens: int,
  187. eos_token_id: int = 2,
  188. im_end_id: int = 4,
  189. decode_one_token=decode_one_token_naive,
  190. **sampling_kwargs,
  191. ) -> torch.Tensor:
  192. """
  193. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  194. """
  195. # create an empty tensor of the expected final shape and fill in the current tokens
  196. T = prompt.size(1)
  197. if max_new_tokens:
  198. if T + max_new_tokens > model.config.max_seq_len:
  199. max_new_tokens = model.config.max_seq_len - T
  200. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  201. T_new = T + max_new_tokens
  202. else:
  203. T_new = model.config.max_seq_len
  204. max_new_tokens = T_new - T
  205. device, dtype = prompt.device, prompt.dtype
  206. with torch.device(device):
  207. model.setup_caches(
  208. max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
  209. )
  210. codebook_dim = 1 + model.config.num_codebooks
  211. # create an empty tensor of the expected final shape and fill in the current tokens
  212. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  213. empty[:, :T] = prompt
  214. seq = empty
  215. input_pos = torch.arange(0, T, device=device)
  216. # Use non-accelerated version for now, to avoid compilation overhead
  217. prefill_decode = (
  218. decode_one_token_naive
  219. if isinstance(model, NaiveTransformer)
  220. else decode_one_token_ar
  221. )
  222. next_token = prefill_decode(
  223. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  224. )
  225. seq[:, T : T + 1] = next_token
  226. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  227. x = decode_n_tokens(
  228. model,
  229. next_token.view(1, codebook_dim, -1),
  230. input_pos,
  231. max_new_tokens - 1,
  232. eos_token_id=eos_token_id,
  233. im_end_id=im_end_id,
  234. decode_one_token=decode_one_token,
  235. **sampling_kwargs,
  236. )
  237. # x = torch.cat(generated_tokens, dim=1)
  238. seq = seq[:, : T + 1 + x.size(1)]
  239. seq[:, T + 1 :] = x
  240. return seq
  241. def encode_tokens(
  242. tokenizer,
  243. string,
  244. bos=True,
  245. device="cuda",
  246. prompt_tokens=None,
  247. speaker=None,
  248. num_codebooks=4,
  249. ):
  250. string = clean_text(string)
  251. if speaker is None:
  252. speaker = "assistant"
  253. string = (
  254. f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
  255. )
  256. if bos:
  257. string = f"<|begin_of_sequence|>{string}"
  258. new_tokens = tokenizer.encode(
  259. string,
  260. add_special_tokens=False,
  261. max_length=10**6,
  262. truncation=False,
  263. )
  264. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  265. # Codebooks
  266. zeros = (
  267. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  268. * CODEBOOK_PAD_TOKEN_ID
  269. )
  270. prompt = torch.cat((tokens, zeros), dim=0)
  271. if prompt_tokens is None:
  272. return prompt
  273. # Get prompt tokens
  274. if prompt_tokens.ndim == 3:
  275. assert (
  276. prompt_tokens.shape[0] == 1
  277. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  278. prompt_tokens = prompt_tokens[0]
  279. assert prompt_tokens.ndim == 2
  280. data = prompt_tokens + 2
  281. if prompt_tokens.shape[0] > num_codebooks:
  282. logger.warning(
  283. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  284. )
  285. data = data[:num_codebooks]
  286. # Add eos token for each codebook
  287. data = torch.cat(
  288. (
  289. data,
  290. torch.ones((data.size(0), 1), dtype=torch.int, device=device)
  291. * CODEBOOK_EOS_TOKEN_ID,
  292. ),
  293. dim=1,
  294. )
  295. # Since 1.0, we use <|semantic|>
  296. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  297. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  298. main_token_ids = (
  299. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  300. )
  301. main_token_ids[0, -1] = end_token_id
  302. data = torch.cat((main_token_ids, data), dim=0)
  303. prompt = torch.cat((prompt, data), dim=1)
  304. return prompt
  305. def load_model(
  306. config_name, checkpoint_path, device, precision, max_length, compile=False
  307. ):
  308. hydra.core.global_hydra.GlobalHydra.instance().clear()
  309. with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
  310. cfg = compose(
  311. config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
  312. )
  313. model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)
  314. if "int8" in str(checkpoint_path):
  315. logger.info("Using int8 weight-only quantization!")
  316. from quantize import WeightOnlyInt8QuantHandler
  317. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  318. model = simple_quantizer.convert_for_runtime()
  319. if "int4" in str(checkpoint_path):
  320. logger.info("Using int4 quantization!")
  321. path_comps = checkpoint_path.name.split(".")
  322. assert path_comps[-2].startswith("g")
  323. groupsize = int(path_comps[-2][1:])
  324. from quantize import WeightOnlyInt4QuantHandler
  325. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  326. model = simple_quantizer.convert_for_runtime()
  327. checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
  328. if "state_dict" in checkpoint:
  329. checkpoint = checkpoint["state_dict"]
  330. if any(k.startswith("model.") for k in checkpoint):
  331. checkpoint = {
  332. k.replace("model.", ""): v
  333. for k, v in checkpoint.items()
  334. if k.startswith("model.")
  335. }
  336. model.load_state_dict(checkpoint, assign=True)
  337. model = model.to(device=device, dtype=precision)
  338. logger.info("Restored model from checkpoint")
  339. if isinstance(model, DualARTransformer):
  340. decode_one_token = decode_one_token_ar
  341. logger.info("Using DualARTransformer")
  342. else:
  343. decode_one_token = decode_one_token_naive
  344. logger.info("Using NaiveTransformer")
  345. if compile:
  346. logger.info("Compiling function...")
  347. decode_one_token = torch.compile(
  348. decode_one_token, mode="reduce-overhead", fullgraph=True
  349. )
  350. return model.eval(), decode_one_token
  351. def split_text(text, min_length):
  352. text = clean_text(text)
  353. segments = []
  354. curr = ""
  355. for char in text:
  356. curr += char
  357. if char not in [".", ",", "!", "?"]:
  358. continue
  359. if len(curr) >= min_length:
  360. segments.append(curr)
  361. curr = ""
  362. if curr:
  363. segments.append(curr)
  364. return segments
  365. def generate_long(
  366. *,
  367. model,
  368. tokenizer: callable,
  369. device: str | torch.device,
  370. decode_one_token: callable,
  371. text: str,
  372. num_samples: int = 1,
  373. max_new_tokens: int = 0,
  374. top_k: int = None,
  375. top_p: int = 0.7,
  376. repetition_penalty: float = 1.5,
  377. temperature: float = 0.7,
  378. compile: bool = False,
  379. iterative_prompt: bool = True,
  380. max_length: int = 2048,
  381. chunk_length: int = 30,
  382. speaker: Optional[str] = None,
  383. prompt_text: Optional[str] = None,
  384. prompt_tokens: Optional[torch.Tensor] = None,
  385. is_streaming: bool = False,
  386. ):
  387. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  388. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  389. use_prompt = prompt_text is not None and prompt_tokens is not None
  390. encoded = []
  391. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  392. if use_prompt:
  393. encoded_prompts = encode_tokens(
  394. tokenizer,
  395. prompt_text,
  396. prompt_tokens=prompt_tokens,
  397. bos=True,
  398. device=device,
  399. speaker=speaker,
  400. num_codebooks=model.config.num_codebooks,
  401. )
  402. for idx, text in enumerate(texts):
  403. encoded.append(
  404. encode_tokens(
  405. tokenizer,
  406. string=text,
  407. bos=idx == 0 and not use_prompt,
  408. device=device,
  409. speaker=speaker,
  410. num_codebooks=model.config.num_codebooks,
  411. )
  412. )
  413. logger.info(f"Encoded text: {text}")
  414. for sample_idx in range(num_samples):
  415. torch.cuda.synchronize()
  416. global_encoded = []
  417. all_codes = []
  418. seg_idx = 0
  419. while seg_idx < len(encoded):
  420. logger.info(
  421. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  422. )
  423. seg = encoded[seg_idx]
  424. global_encoded.append(seg)
  425. lengths = reversed([seg.size(1) for seg in global_encoded])
  426. # Pick last 2000 tokens
  427. count = 0
  428. for i, length in enumerate(lengths):
  429. count += length
  430. if count + length > max_length - 1024:
  431. break
  432. if i != 0 and i % 2 == 0:
  433. i -= 1
  434. # Rotate the list, always make sure first segment is included to avoid drift
  435. if i < len(global_encoded) - 2:
  436. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  437. else:
  438. partial_encoded = global_encoded
  439. if use_prompt:
  440. partial_encoded = [encoded_prompts] + partial_encoded
  441. cat_encoded = torch.cat(partial_encoded, dim=1)
  442. prompt_length = cat_encoded.size(1)
  443. t0 = time.perf_counter()
  444. y = generate(
  445. model=model,
  446. prompt=cat_encoded,
  447. max_new_tokens=max_new_tokens,
  448. eos_token_id=tokenizer.eos_token_id,
  449. im_end_id=im_end_id,
  450. decode_one_token=decode_one_token,
  451. temperature=temperature,
  452. top_k=top_k,
  453. top_p=top_p,
  454. repetition_penalty=repetition_penalty,
  455. )
  456. if sample_idx == 0 and seg_idx == 0 and compile:
  457. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  458. torch.cuda.synchronize()
  459. t = time.perf_counter() - t0
  460. tokens_generated = y.size(1) - prompt_length
  461. tokens_sec = tokens_generated / t
  462. logger.info(
  463. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  464. )
  465. logger.info(
  466. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  467. )
  468. logger.info(
  469. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  470. )
  471. # Put the generated tokens
  472. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  473. codes = y[1:, prompt_length:-2].clone()
  474. codes = codes - 2
  475. assert (codes >= 0).all(), f"Negative code found"
  476. decoded = y[:, prompt_length:-1].clone()
  477. if decoded[0, -1] != im_end_id: # <im_end>
  478. val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
  479. decoded = torch.cat(
  480. (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
  481. )
  482. # But for global encoding, we should keep the <im_end> token
  483. global_encoded.append(decoded)
  484. if is_streaming:
  485. assert (codes >= 0).all(), f"Negative code found: {codes}"
  486. yield codes
  487. else:
  488. all_codes.append(codes)
  489. seg_idx += 1
  490. if is_streaming:
  491. # This indicates the end of the current sample
  492. yield "next"
  493. else:
  494. all_codes = torch.cat(all_codes, dim=1)
  495. assert (all_codes >= 0).all(), f"Negative code found: {codes}"
  496. yield all_codes
  497. def launch_thread_safe_queue(
  498. config_name,
  499. checkpoint_path,
  500. device,
  501. precision,
  502. max_length,
  503. compile=False,
  504. ):
  505. input_queue = queue.Queue()
  506. init_event = threading.Event()
  507. def worker():
  508. model, decode_one_token = load_model(
  509. config_name, checkpoint_path, device, precision, max_length, compile=compile
  510. )
  511. init_event.set()
  512. while True:
  513. item = input_queue.get()
  514. if item is None:
  515. break
  516. kwargs = item["request"]
  517. response_queue = item["response_queue"]
  518. try:
  519. item["success"] = True
  520. for chunk in generate_long(
  521. model=model, decode_one_token=decode_one_token, **kwargs
  522. ):
  523. response_queue.put(chunk)
  524. response_queue.put("done")
  525. except Exception as e:
  526. item["success"] = False
  527. item["response"] = e
  528. response_queue.put("done")
  529. threading.Thread(target=worker, daemon=True).start()
  530. init_event.wait()
  531. return input_queue
  532. @click.command()
  533. @click.option(
  534. "--text",
  535. type=str,
  536. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  537. )
  538. @click.option("--prompt-text", type=str, default=None)
  539. @click.option(
  540. "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
  541. )
  542. @click.option("--num-samples", type=int, default=1)
  543. @click.option("--max-new-tokens", type=int, default=0)
  544. @click.option("--top-k", type=int, default=None)
  545. @click.option("--top-p", type=float, default=0.7)
  546. @click.option("--repetition-penalty", type=float, default=1.5)
  547. @click.option("--temperature", type=float, default=0.7)
  548. @click.option(
  549. "--checkpoint-path",
  550. type=click.Path(path_type=Path, exists=True),
  551. default="results/text2semantic_400m_finetune/step_000002000.pth",
  552. )
  553. @click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
  554. @click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
  555. @click.option("--compile/--no-compile", default=False)
  556. @click.option("--seed", type=int, default=42)
  557. @click.option("--speaker", type=str, default=None)
  558. @click.option("--half/--no-half", default=False)
  559. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  560. @click.option("--max-length", type=int, default=2048)
  561. @click.option("--chunk-length", type=int, default=30)
  562. def main(
  563. text: str,
  564. prompt_text: Optional[str],
  565. prompt_tokens: Optional[Path],
  566. num_samples: int,
  567. max_new_tokens: int,
  568. top_k: int,
  569. top_p: int,
  570. repetition_penalty: float,
  571. temperature: float,
  572. checkpoint_path: Path,
  573. config_name: str,
  574. tokenizer: str,
  575. compile: bool,
  576. seed: int,
  577. speaker: Optional[str],
  578. half: bool,
  579. iterative_prompt: bool,
  580. max_length: int,
  581. chunk_length: int,
  582. ) -> None:
  583. device = "cuda"
  584. precision = torch.half if half else torch.bfloat16
  585. logger.info("Loading model ...")
  586. t0 = time.time()
  587. model, decode_one_token = load_model(
  588. config_name, checkpoint_path, device, precision, max_length, compile=compile
  589. )
  590. torch.cuda.synchronize()
  591. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  592. prompt_tokens = (
  593. torch.from_numpy(np.load(prompt_tokens)).to(device)
  594. if prompt_tokens is not None
  595. else None
  596. )
  597. tokenizer = AutoTokenizer.from_pretrained(tokenizer)
  598. torch.manual_seed(seed)
  599. torch.cuda.manual_seed(seed)
  600. generator = generate_long(
  601. model=model,
  602. device=device,
  603. decode_one_token=decode_one_token,
  604. text=text,
  605. num_samples=num_samples,
  606. max_new_tokens=max_new_tokens,
  607. top_k=top_k,
  608. top_p=top_p,
  609. repetition_penalty=repetition_penalty,
  610. temperature=temperature,
  611. tokenizer=tokenizer,
  612. compile=compile,
  613. speaker=speaker,
  614. iterative_prompt=iterative_prompt,
  615. max_length=max_length,
  616. chunk_length=chunk_length,
  617. prompt_text=prompt_text,
  618. prompt_tokens=prompt_tokens,
  619. )
  620. for idx, codes in enumerate(generator):
  621. np.save(f"codes_{idx}.npy", codes.cpu().numpy())
  622. logger.info(f"Saved codes to codes_{idx}.npy")
  623. if __name__ == "__main__":
  624. main()