# generate.py — text2semantic inference script (copy/paste line-number gutter removed)
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from pathlib import Path
  6. from typing import Optional, Tuple, Union
  7. import click
  8. import hydra
  9. import numpy as np
  10. import torch
  11. import torch._dynamo.config
  12. import torch._inductor.config
  13. from hydra import compose, initialize
  14. from hydra.utils import instantiate
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from transformers import AutoTokenizer
  18. from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
  19. from fish_speech.text.clean import clean_text
# Disable HF tokenizers' internal parallelism; this script uses its own threads
# and the env var must be set before any tokenizer is constructed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Inductor knobs: better autotuned kernels and readable kernel names.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
  26. from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
  27. def multinomial_sample_one_no_sync(
  28. probs_sort,
  29. ): # Does multinomial sampling without a cuda synchronization
  30. q = torch.empty_like(probs_sort).exponential_(1)
  31. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  32. def logits_to_probs(
  33. logits,
  34. previous_tokens: Optional[torch.Tensor] = None,
  35. temperature: torch.Tensor = 1.0,
  36. top_p: torch.Tensor = 1.0,
  37. repetition_penalty: torch.Tensor = 1.0,
  38. ) -> torch.Tensor:
  39. # Apply repetition penalty
  40. if previous_tokens is not None:
  41. previous_tokens = previous_tokens.long()
  42. score = torch.gather(logits, dim=0, index=previous_tokens)
  43. score = torch.where(
  44. score < 0, score * repetition_penalty, score / repetition_penalty
  45. )
  46. logits.scatter_(dim=0, index=previous_tokens, src=score)
  47. # Apply top-p sampling
  48. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  49. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  50. sorted_indices_to_remove = cum_probs > top_p
  51. sorted_indices_to_remove[0] = False # keep at least one option
  52. indices_to_remove = sorted_indices_to_remove.scatter(
  53. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  54. )
  55. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  56. logits = logits / max(temperature, 1e-5)
  57. probs = torch.nn.functional.softmax(logits, dim=-1)
  58. return probs
  59. def sample(
  60. logits,
  61. previous_tokens: Optional[torch.Tensor] = None,
  62. **sampling_kwargs,
  63. ) -> Tuple[torch.Tensor, torch.Tensor]:
  64. probs = logits_to_probs(
  65. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  66. )
  67. idx_next = multinomial_sample_one_no_sync(probs)
  68. return idx_next, probs
  69. def decode_one_token_ar(
  70. model: DualARTransformer,
  71. x: torch.Tensor,
  72. input_pos: torch.Tensor,
  73. previous_tokens: torch.Tensor = None,
  74. **sampling_kwargs,
  75. ) -> torch.Tensor:
  76. x = model.forward_generate(x, input_pos)
  77. codebooks = [
  78. sample(
  79. x.logits,
  80. previous_tokens=None, # Disable repetition penalty for the token codebook
  81. **sampling_kwargs,
  82. )[0]
  83. ]
  84. x = x.hidden_states
  85. # Cleanup the cache
  86. for layer in model.fast_layers:
  87. layer.attention.kv_cache.k_cache.fill_(0)
  88. layer.attention.kv_cache.v_cache.fill_(0)
  89. for codebook_idx in range(model.config.num_codebooks):
  90. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  91. logits = model.forward_generate_fast(x, input_pos)
  92. a = sample(
  93. logits,
  94. previous_tokens=(
  95. previous_tokens[codebook_idx + 1]
  96. if previous_tokens is not None
  97. else None
  98. ),
  99. **sampling_kwargs,
  100. )[0]
  101. x = model.fast_embeddings(a)
  102. codebooks.append(a)
  103. return torch.stack(codebooks, dim=0)
  104. def decode_one_token_naive(
  105. model: NaiveTransformer,
  106. x: torch.Tensor,
  107. input_pos: torch.Tensor,
  108. previous_tokens: torch.Tensor = None,
  109. **sampling_kwargs,
  110. ) -> torch.Tensor:
  111. x = model.forward_generate(x, input_pos)
  112. codebooks = [
  113. sample(
  114. x.token_logits,
  115. previous_tokens=None, # Disable repetition penalty for the token codebook
  116. **sampling_kwargs,
  117. )[0]
  118. ]
  119. for i in range(model.config.num_codebooks):
  120. codebooks.append(
  121. sample(
  122. x.codebook_logits[:, :, i],
  123. previous_tokens=(
  124. previous_tokens[i + 1] if previous_tokens is not None else None
  125. ),
  126. **sampling_kwargs,
  127. )[0]
  128. )
  129. return torch.stack(codebooks, dim=0)
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    """Autoregressively decode up to ``num_new_tokens`` steps.

    Stops early when the semantic row emits ``eos_token_id``/``im_end_id`` or
    any codebook row emits ``CODEBOOK_EOS_TOKEN_ID``. Returns a
    ``(num_codebooks + 1, n_generated)`` tensor (the stop token included).
    """
    # Buffer of every generated token, consulted for the windowed
    # repetition penalty below.
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )
    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            # NOTE(review): for the first few steps this window spans the
            # first win_size slots, including still-zero (unfilled) entries —
            # confirm that penalizing token id 0 here is intended/harmless.
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                **sampling_kwargs,
            )
        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )
        # Stop on end-of-sequence in the semantic row or any codebook row.
        if (
            cur_token[0, 0, -1] == eos_token_id
            or cur_token[0, 0, -1] == im_end_id
            or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
        ):
            break
    # `i` survives the loop (Python scoping); the slice includes the stop token.
    return previous_tokens[:, : i + 1]
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Args:
        model: transformer exposing ``setup_caches``/``config`` (Naive or DualAR).
        prompt: ``(1 + num_codebooks, T)`` tensor of prompt tokens.
        max_new_tokens: cap on generated tokens; 0 means "fill max_seq_len".
        eos_token_id / im_end_id: stop ids checked on the semantic row.
        decode_one_token: step function used after the prefill step.

    Returns:
        ``(1 + num_codebooks, T + 1 + n_generated)`` tensor including the prompt.
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            # Clamp so prompt + generation fits the context window.
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
        T_new = T + max_new_tokens
    else:
        # max_new_tokens == 0: generate until the context window is full.
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T
    device, dtype = prompt.device, prompt.dtype
    # (Re)allocate KV caches sized for the full target sequence length.
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
        )
    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)
    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )
    # Prefill: process the whole prompt at once and sample the first new token.
    next_token = prefill_decode(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token
    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,  # one token was already produced by the prefill
        eos_token_id=eos_token_id,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    # Trim the preallocated buffer to prompt + prefill token + generated tokens.
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x
    return seq
  235. def encode_tokens(
  236. tokenizer,
  237. string,
  238. bos=True,
  239. device="cuda",
  240. prompt_tokens=None,
  241. speaker=None,
  242. num_codebooks=4,
  243. ):
  244. string = clean_text(string)
  245. if speaker is None:
  246. speaker = "assistant"
  247. string = (
  248. f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
  249. )
  250. if bos:
  251. string = f"<|begin_of_sequence|>{string}"
  252. new_tokens = tokenizer.encode(
  253. string,
  254. add_special_tokens=False,
  255. max_length=10**6,
  256. truncation=False,
  257. )
  258. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  259. # Codebooks
  260. zeros = (
  261. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  262. * CODEBOOK_PAD_TOKEN_ID
  263. )
  264. prompt = torch.cat((tokens, zeros), dim=0)
  265. if prompt_tokens is None:
  266. return prompt
  267. # Get prompt tokens
  268. if prompt_tokens.ndim == 3:
  269. assert (
  270. prompt_tokens.shape[0] == 1
  271. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  272. prompt_tokens = prompt_tokens[0]
  273. assert prompt_tokens.ndim == 2
  274. data = prompt_tokens + 2
  275. if prompt_tokens.shape[0] > num_codebooks:
  276. logger.warning(
  277. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  278. )
  279. data = data[:num_codebooks]
  280. # Add eos token for each codebook
  281. data = torch.cat(
  282. (
  283. data,
  284. torch.ones((data.size(0), 1), dtype=torch.int, device=device)
  285. * CODEBOOK_EOS_TOKEN_ID,
  286. ),
  287. dim=1,
  288. )
  289. # Since 1.0, we use <|semantic|>
  290. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  291. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  292. main_token_ids = (
  293. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  294. )
  295. main_token_ids[0, -1] = end_token_id
  296. data = torch.cat((main_token_ids, data), dim=0)
  297. prompt = torch.cat((prompt, data), dim=1)
  298. return prompt
def load_model(
    config_name, checkpoint_path, device, precision, max_length, compile=False
):
    """Instantiate the transformer from its hydra config, load the checkpoint
    (quantization mode inferred from the filename), and return
    ``(model.eval(), decode_one_token_fn)``.
    """
    # Reset hydra so repeated calls (e.g. from a long-lived process) work.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
        cfg = compose(
            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
        )
    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)

    # Quantization is selected by substring of the checkpoint path.
    if "int8" in str(checkpoint_path):
        logger.info("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()
    if "int4" in str(checkpoint_path):
        logger.info("Using int4 quantization!")
        # Group size is encoded in the filename, e.g. "model.int4.g32.pth".
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    # Strip a "model." key prefix if present — presumably from a training
    # wrapper (e.g. Lightning); keys without the prefix are dropped.
    if any(k.startswith("model.") for k in checkpoint):
        checkpoint = {
            k.replace("model.", ""): v
            for k, v in checkpoint.items()
            if k.startswith("model.")
        }
    model.load_state_dict(checkpoint, assign=True)
    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    # Pick the step function matching the architecture.
    if isinstance(model, DualARTransformer):
        decode_one_token = decode_one_token_ar
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = decode_one_token_naive
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )
    return model.eval(), decode_one_token
  345. def split_text(text, min_length):
  346. text = clean_text(text)
  347. segments = []
  348. curr = ""
  349. for char in text:
  350. curr += char
  351. if char not in [".", "!", "?"]:
  352. continue
  353. if len(curr) >= min_length:
  354. segments.append(curr)
  355. curr = ""
  356. if curr:
  357. segments.append(curr)
  358. return segments
def generate_long(
    *,
    model,
    tokenizer: callable,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 30,
    speaker: Optional[str] = None,
    prompt_text: Optional[str] = None,
    prompt_tokens: Optional[torch.Tensor] = None,
    is_streaming: bool = False,
):
    """Generate semantic codes for ``text``, sentence by sentence.

    When ``is_streaming`` is true, yields one code tensor per segment and the
    string "next" at the end of each sample; otherwise yields one concatenated
    code tensor per sample.
    """
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    # Parameter count, used for the bandwidth log line below.
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    use_prompt = prompt_text is not None and prompt_tokens is not None
    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]

    if use_prompt:
        # Reference (voice) prompt, prepended to every segment's context.
        encoded_prompts = encode_tokens(
            tokenizer,
            prompt_text,
            prompt_tokens=prompt_tokens,
            bos=True,
            device=device,
            speaker=speaker,
            num_codebooks=model.config.num_codebooks,
        )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                bos=idx == 0 and not use_prompt,
                device=device,
                speaker=speaker,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    # Move temperature, top_p, repetition_penalty to device
    # This is important so that changing params doesn't trigger recompile
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        global_encoded = []
        all_codes = []
        seg_idx = 0

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )
            seg = encoded[seg_idx]
            global_encoded.append(seg)

            # Walk segment lengths newest-first to find how much history fits
            # in the context budget (max_length minus generation headroom).
            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Pick last 2000 tokens
            count = 0
            for i, length in enumerate(lengths):
                count += length
                # NOTE(review): `count` already includes this `length`, so the
                # check budgets it twice — looks conservative; confirm intended.
                if count + length > max_length - 1024:
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list, always make sure first segment is included to avoid drift
            # NOTE(review): when i == 0, `global_encoded[-i:]` is the WHOLE
            # list, so the first two segments are duplicated — verify this
            # edge case is intended.
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            if use_prompt:
                partial_encoded = [encoded_prompts] + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                eos_token_id=tokenizer.eos_token_id,
                im_end_id=im_end_id,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            # The very first call pays the torch.compile cost.
            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t = time.perf_counter() - t0
            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )
            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Put the generated tokens
            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
            codes = y[1:, prompt_length:-2].clone()
            # Undo the +2 offset applied by encode_tokens.
            codes = codes - 2
            assert (codes >= 0).all(), f"Negative code found"

            decoded = y[:, prompt_length:-1].clone()
            if decoded[0, -1] != im_end_id:  # <im_end>
                # Force a well-formed ending if generation stopped without one.
                val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
                decoded = torch.cat(
                    (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
                )

            # But for global encoding, we should keep the <im_end> token
            global_encoded.append(decoded)

            if is_streaming:
                assert (codes >= 0).all(), f"Negative code found: {codes}"
                yield codes
            else:
                all_codes.append(codes)

            seg_idx += 1

        if is_streaming:
            # This indicates the end of the current sample
            yield "next"
        else:
            all_codes = torch.cat(all_codes, dim=1)
            assert (all_codes >= 0).all(), f"Negative code found: {codes}"
            yield all_codes
  502. def launch_thread_safe_queue(
  503. config_name,
  504. checkpoint_path,
  505. device,
  506. precision,
  507. max_length,
  508. compile=False,
  509. ):
  510. input_queue = queue.Queue()
  511. init_event = threading.Event()
  512. def worker():
  513. model, decode_one_token = load_model(
  514. config_name, checkpoint_path, device, precision, max_length, compile=compile
  515. )
  516. init_event.set()
  517. while True:
  518. item = input_queue.get()
  519. if item is None:
  520. break
  521. kwargs = item["request"]
  522. response_queue = item["response_queue"]
  523. try:
  524. item["success"] = True
  525. for chunk in generate_long(
  526. model=model, decode_one_token=decode_one_token, **kwargs
  527. ):
  528. response_queue.put(chunk)
  529. response_queue.put("done")
  530. except Exception as e:
  531. item["success"] = False
  532. item["response"] = e
  533. response_queue.put("done")
  534. threading.Thread(target=worker, daemon=True).start()
  535. init_event.wait()
  536. return input_queue
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None)
@click.option(
    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.5)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="results/text2semantic_400m_finetune/step_000002000.pth",
)
@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--speaker", type=str, default=None)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--max-length", type=int, default=2048)
@click.option("--chunk-length", type=int, default=30)
def main(
    text: str,
    prompt_text: Optional[str],
    prompt_tokens: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    config_name: str,
    tokenizer: str,
    compile: bool,
    seed: int,
    speaker: Optional[str],
    half: bool,
    iterative_prompt: bool,
    max_length: int,
    chunk_length: int,
) -> None:
    """CLI entry point: load the model, generate semantic codes for ``text``,
    and save each yielded code tensor to ``codes_{idx}.npy``.
    """
    device = "cuda"
    precision = torch.half if half else torch.bfloat16

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        config_name, checkpoint_path, device, precision, max_length, compile=compile
    )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    # Optional reference codes, loaded from a previously saved .npy file.
    prompt_tokens = (
        torch.from_numpy(np.load(prompt_tokens)).to(device)
        if prompt_tokens is not None
        else None
    )

    tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    # Seed both CPU and CUDA RNGs for reproducible sampling.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        tokenizer=tokenizer,
        compile=compile,
        speaker=speaker,
        iterative_prompt=iterative_prompt,
        max_length=max_length,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    for idx, codes in enumerate(generator):
        np.save(f"codes_{idx}.npy", codes.cpu().numpy())
        logger.info(f"Saved codes to codes_{idx}.npy")


if __name__ == "__main__":
    main()