generate.py 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762
  1. import os
  2. import queue
  3. import string
  4. import threading
  5. import time
  6. from pathlib import Path
  7. from typing import Optional, Tuple, Union
  8. import click
  9. import hydra
  10. import numpy as np
  11. import torch
  12. import torch._dynamo.config
  13. import torch._inductor.config
  14. from hydra import compose, initialize
  15. from hydra.utils import instantiate
  16. from loguru import logger
  17. from tqdm import tqdm
  18. from transformers import AutoTokenizer
  19. from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
  20. from fish_speech.text.clean import clean_text
  21. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  22. torch._inductor.config.coordinate_descent_tuning = True
  23. torch._inductor.config.triton.unique_kernel_names = True
  24. if hasattr(torch._inductor.config, "fx_graph_cache"):
  25. # Experimental feature to reduce compilation times, will be on by default in future
  26. torch._inductor.config.fx_graph_cache = True
  27. from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
  def multinomial_sample_one_no_sync(
      probs_sort,
  ):
      """Draw one index along the last dim of ``probs_sort`` without a CUDA sync.

      Uses the exponential-race trick: each probability is divided by an
      independent Exp(1) draw, and the argmax of the ratios picks index ``k``
      with probability proportional to ``probs_sort[..., k]``. Every op stays
      on-device, so no host synchronization is needed.

      Returns:
          ``torch.int`` (int32) tensor holding the sampled index, with the last
          dimension kept at size 1.
      """
      noise = torch.empty_like(probs_sort)
      noise.exponential_(1)
      winners = (probs_sort / noise).argmax(dim=-1, keepdim=True)
      return winners.to(dtype=torch.int)
  def logits_to_probs(
      logits,
      previous_tokens: Optional[torch.Tensor] = None,
      temperature: torch.Tensor = 1.0,
      top_p: torch.Tensor = 1.0,
      repetition_penalty: torch.Tensor = 1.0,
  ) -> torch.Tensor:
      """Convert a logit vector into sampling probabilities.

      Steps, in order:
        1. Repetition penalty: for every id in ``previous_tokens``, its logit
           is divided by ``repetition_penalty`` if non-negative, otherwise
           multiplied by it. NOTE: this writes into ``logits`` in place.
        2. Top-p: tokens are ranked by probability, and every token whose
           running cumulative probability is strictly greater than ``top_p``
           is masked to -inf. The single most likely token is always kept.
           (This keeps the prefix whose mass stays <= top_p, not the minimal
           prefix whose mass reaches top_p.)
        3. Temperature: the filtered logits are divided by
           ``max(temperature, 1e-5)``, then a softmax is applied.

      The dim=0 gather/scatter and the ``[0]`` keep-one indexing assume
      ``logits`` and ``previous_tokens`` are 1-D.
      """
      if previous_tokens is not None:
          seen_ids = previous_tokens.long()
          seen_logits = torch.gather(logits, dim=0, index=seen_ids)
          penalised = torch.where(
              seen_logits < 0,
              seen_logits * repetition_penalty,
              seen_logits / repetition_penalty,
          )
          logits.scatter_(dim=0, index=seen_ids, src=penalised)

      # Nucleus filtering is done in ranked order, then mapped back to vocab order.
      ranked_logits, ranked_ids = torch.sort(logits, descending=True)
      ranked_mass = torch.cumsum(
          torch.nn.functional.softmax(ranked_logits, dim=-1), dim=-1
      )
      drop_ranked = ranked_mass > top_p
      drop_ranked[0] = False  # keep at least one option
      drop = drop_ranked.scatter(dim=0, index=ranked_ids, src=drop_ranked)
      filtered = logits.masked_fill(drop, -float("Inf"))

      scaled = filtered / max(temperature, 1e-5)
      return torch.nn.functional.softmax(scaled, dim=-1)
  def sample(
      logits,
      previous_tokens: Optional[torch.Tensor] = None,
      **sampling_kwargs,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
      """Sample one token from the logits at ``logits[0, -1]``.

      That slice is turned into probabilities by ``logits_to_probs``
      (forwarding ``previous_tokens`` and the sampling kwargs). One index is
      then drawn with ``multinomial_sample_one_no_sync``.

      Returns:
          ``(token, probs)``: the sampled id tensor and the probability vector
          it was drawn from.
      """
      last_step_logits = logits[0, -1]
      probs = logits_to_probs(
          logits=last_step_logits,
          previous_tokens=previous_tokens,
          **sampling_kwargs,
      )
      token = multinomial_sample_one_no_sync(probs)
      return token, probs
  def decode_one_token_ar(
      model: DualARTransformer,
      x: torch.Tensor,
      input_pos: torch.Tensor,
      previous_tokens: Optional[torch.Tensor] = None,
      **sampling_kwargs,
  ) -> torch.Tensor:
      """Decode one step with the dual-AR model (slow + fast transformer).

      The slow model runs once at ``input_pos`` and yields the text-row token
      (sampled with no repetition penalty). Its hidden state then seeds the
      fast transformer, which emits one token per codebook in sequence. Each
      sampled codebook token is embedded and fed back as the next fast input.
      The fast layers' KV caches are zeroed before that inner loop.

      ``previous_tokens[k + 1]`` (if given) is the penalty window for codebook
      ``k``; row 0 is never used here.

      Returns:
          ``torch.stack`` of ``1 + num_codebooks`` sampled tokens along dim 0.
      """
      slow_out = model.forward_generate(x, input_pos)
      # The text/semantic row is sampled without repetition penalty.
      first_token, _ = sample(
          slow_out.logits,
          previous_tokens=None,
          **sampling_kwargs,
      )
      codebooks = [first_token]
      hidden = slow_out.hidden_states

      # Reset the fast transformer's KV cache before the per-codebook loop.
      for layer in model.fast_layers:
          kv = layer.attention.kv_cache
          kv.k_cache.fill_(0)
          kv.v_cache.fill_(0)

      for codebook_idx in range(model.config.num_codebooks):
          fast_pos = torch.tensor([codebook_idx], device=hidden.device, dtype=torch.long)
          fast_logits = model.forward_generate_fast(hidden, fast_pos)
          if previous_tokens is None:
              window = None
          else:
              window = previous_tokens[codebook_idx + 1]
          token, _ = sample(
              fast_logits,
              previous_tokens=window,
              **sampling_kwargs,
          )
          hidden = model.fast_embeddings(token)
          codebooks.append(token)

      return torch.stack(codebooks, dim=0)
  def decode_one_token_naive(
      model: NaiveTransformer,
      x: torch.Tensor,
      input_pos: torch.Tensor,
      previous_tokens: Optional[torch.Tensor] = None,
      **sampling_kwargs,
  ) -> torch.Tensor:
      """Decode one step with the single-transformer model.

      One forward pass yields ``token_logits`` for the text row and
      ``codebook_logits`` for every codebook; each is sampled independently.
      The text row gets no repetition penalty, and codebook ``i`` uses
      ``previous_tokens[i + 1]`` as its penalty window when provided.

      Returns:
          ``torch.stack`` of ``1 + num_codebooks`` sampled tokens along dim 0.
      """
      out = model.forward_generate(x, input_pos)
      tokens = [
          sample(
              out.token_logits,
              previous_tokens=None,  # Disable repetition penalty for the token codebook
              **sampling_kwargs,
          )[0]
      ]

      for i in range(model.config.num_codebooks):
          history = previous_tokens[i + 1] if previous_tokens is not None else None
          picked, _ = sample(
              out.codebook_logits[:, :, i],
              previous_tokens=history,
              **sampling_kwargs,
          )
          tokens.append(picked)

      return torch.stack(tokens, dim=0)
  def decode_n_tokens(
      model: NaiveTransformer,
      cur_token: torch.Tensor,
      input_pos: torch.Tensor,
      num_new_tokens: int,
      eos_token_id: int = 2,
      im_end_id: int = 4,
      decode_one_token=decode_one_token_naive,
      **sampling_kwargs,
  ):
      """Run up to ``num_new_tokens`` incremental decoding steps.

      Each step feeds ``cur_token`` at ``input_pos`` through
      ``decode_one_token`` and records the result. Decoding stops early when
      the text row emits ``eos_token_id`` or ``im_end_id``, or when any codebook
      row emits ``CODEBOOK_EOS_TOKEN_ID``. The stopping step is included in the
      output.

      Note: ``input_pos`` is advanced in place (``+= 1``) on every step.

      Returns:
          int tensor of shape ``(num_codebooks + 1, n)`` with the generated
          tokens, where ``n`` is the number of steps actually run. ``n`` is 0
          (an empty slice) when ``num_new_tokens <= 0``.
      """
      num_rows = model.config.num_codebooks + 1
      previous_tokens = torch.zeros(
          (num_rows, model.config.max_seq_len),
          dtype=torch.int,
          device=cur_token.device,
      )
      # Repetition penalty only looks at a window of recent tokens.
      win_size = 16
      # Steps completed so far. It stays 0 when the loop body never runs, so the
      # final slice is valid even for num_new_tokens <= 0.
      num_generated = 0

      for i in tqdm(range(num_new_tokens)):
          # The window is always win_size columns wide. Early on it includes
          # still-zero slots (presumably so a compiled decode_one_token always
          # sees the same shape).
          if i < win_size:
              window = previous_tokens[:, :win_size]
          else:
              window = previous_tokens[:, i - win_size : i]

          with torch.backends.cuda.sdp_kernel(
              enable_flash=False, enable_mem_efficient=False, enable_math=True
          ):  # Actually better for Inductor to codegen attention here
              next_token = decode_one_token(
                  model=model,
                  x=cur_token,
                  input_pos=input_pos,
                  previous_tokens=window,
                  **sampling_kwargs,
              )

          input_pos += 1
          cur_token = next_token.view(1, num_rows, -1)
          previous_tokens[:, i : i + 1] = next_token.view(num_rows, -1)
          num_generated = i + 1

          if (
              cur_token[0, 0, -1] == eos_token_id
              or cur_token[0, 0, -1] == im_end_id
              or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
          ):
              break

      return previous_tokens[:, :num_generated]
  @torch.no_grad()
  @torch.inference_mode()
  def generate(
      *,
      model: NaiveTransformer,
      prompt: torch.Tensor,
      max_new_tokens: int,
      eos_token_id: int = 2,
      im_end_id: int = 4,
      decode_one_token=decode_one_token_naive,
      **sampling_kwargs,
  ) -> torch.Tensor:
      """
      Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

      Args:
          model: the text2semantic model; its KV caches are (re)allocated here
              for the final sequence length.
          prompt: ``(1 + num_codebooks, T)`` token grid.
          max_new_tokens: number of tokens to generate. A falsy value means
              "fill up to ``model.config.max_seq_len``". Requests that would
              exceed ``max_seq_len`` are truncated, with a log message.
          eos_token_id, im_end_id: text-row ids that stop decoding early.
          decode_one_token: step function used after the prefill step (may be a
              ``torch.compile``-d version).

      Returns:
          ``(1 + num_codebooks, T + n)`` tensor: the prompt followed by the
          ``n`` generated tokens (``n >= 1``).

      Raises:
          ValueError: if the prompt already fills ``max_seq_len``.
      """
      T = prompt.size(1)
      max_seq_len = model.config.max_seq_len
      if T >= max_seq_len:
          raise ValueError(
              f"Prompt length {T} leaves no room for generation "
              f"(max_seq_len={max_seq_len})"
          )

      if max_new_tokens:
          if T + max_new_tokens > max_seq_len:
              max_new_tokens = max_seq_len - T
              logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
          T_new = T + max_new_tokens
      else:
          T_new = max_seq_len
          max_new_tokens = T_new - T

      device, dtype = prompt.device, prompt.dtype
      with torch.device(device):
          model.setup_caches(
              max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
          )

      codebook_dim = 1 + model.config.num_codebooks
      # Preallocate the full output and copy the prompt into its head.
      seq = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
      seq[:, :T] = prompt

      # Use non-accelerated version for now, to avoid compilation overhead
      prefill_decode = (
          decode_one_token_naive
          if isinstance(model, NaiveTransformer)
          else decode_one_token_ar
      )
      input_pos = torch.arange(0, T, device=device)
      next_token = prefill_decode(
          model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
      )
      seq[:, T : T + 1] = next_token

      remaining = max_new_tokens - 1
      if remaining <= 0:
          # The prefill step already produced the only token that fits/was asked for.
          return seq[:, : T + 1]

      input_pos = torch.tensor([T], device=device, dtype=torch.int)
      x = decode_n_tokens(
          model,
          next_token.view(1, codebook_dim, -1),
          input_pos,
          remaining,
          eos_token_id=eos_token_id,
          im_end_id=im_end_id,
          decode_one_token=decode_one_token,
          **sampling_kwargs,
      )

      # Trim to what was actually generated (decoding may stop early).
      seq = seq[:, : T + 1 + x.size(1)]
      seq[:, T + 1 :] = x
      return seq
  def encode_tokens(
      tokenizer,
      string,
      bos=True,
      device="cuda",
      prompt_tokens=None,
      speaker=None,
      num_codebooks=4,
  ):
      """Build the ``(num_codebooks + 1, L)`` int token grid for one chat turn.

      Row 0 holds the tokenizer ids of the cleaned text, wrapped as
      ``<|im_start|>user<|im_sep|>{text}<|im_end|><|im_start|>{speaker}<|im_sep|>``
      (``speaker`` defaults to ``"assistant"``). It is prefixed with
      ``<|begin_of_sequence|>`` when ``bos`` is true. The remaining
      ``num_codebooks`` rows are filled with ``CODEBOOK_PAD_TOKEN_ID``.

      If ``prompt_tokens`` is given (2-D, or 3-D with a leading dim of 1), extra
      columns are appended after the text:
        - codebook rows: the codes shifted by +2 (presumably to make room for
          reserved special ids; confirm), then one ``CODEBOOK_EOS_TOKEN_ID``
          column;
        - row 0: ``<|semantic|>`` ids, with ``<|im_end|>`` in the last column.
      Rows beyond ``num_codebooks`` are dropped with a warning.
      """
      string = clean_text(string)
      role = "assistant" if speaker is None else speaker
      chat = (
          f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{role}<|im_sep|>"
      )
      if bos:
          chat = f"<|begin_of_sequence|>{chat}"

      token_ids = tokenizer.encode(
          chat,
          add_special_tokens=False,
          max_length=10**6,
          truncation=False,
      )
      text_row = torch.tensor([token_ids], dtype=torch.int, device=device)
      pad_rows = (
          torch.ones((num_codebooks, text_row.size(1)), dtype=torch.int, device=device)
          * CODEBOOK_PAD_TOKEN_ID
      )
      prompt = torch.cat((text_row, pad_rows), dim=0)

      if prompt_tokens is None:
          return prompt

      if prompt_tokens.ndim == 3:
          assert (
              prompt_tokens.shape[0] == 1
          ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
          prompt_tokens = prompt_tokens[0]
      assert prompt_tokens.ndim == 2

      codes = prompt_tokens + 2
      if prompt_tokens.shape[0] > num_codebooks:
          logger.warning(
              f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
          )
          codes = codes[:num_codebooks]

      eos_column = (
          torch.ones((codes.size(0), 1), dtype=torch.int, device=device)
          * CODEBOOK_EOS_TOKEN_ID
      )
      codes = torch.cat((codes, eos_column), dim=1)

      # Since 1.0, we use <|semantic|>
      s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
      end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
      semantic_row = (
          torch.ones((1, codes.size(1)), dtype=torch.int, device=device) * s0_token_id
      )
      semantic_row[0, -1] = end_token_id

      speech_block = torch.cat((semantic_row, codes), dim=0)
      return torch.cat((prompt, speech_block), dim=1)
  def load_model(
      config_name, checkpoint_path, device, precision, max_length, compile=False
  ):
      """Build the text2semantic model from a Hydra config and load its weights.

      Args:
          config_name: Hydra config name, resolved against
              ``../../fish_speech/configs/model`` relative to this file.
          checkpoint_path: checkpoint file (``str`` or ``Path``). If the path
              contains ``int8`` or ``int4``, the model is first converted with the
              matching weight-only quantization handler. For int4, the group
              size is parsed from the second-to-last dotted component of the
              file name, e.g. ``name.g32.pth`` -> 32.
          device, precision: target device and dtype for the loaded model.
          max_length: overrides ``config.max_seq_len``.
          compile: wrap the single-step decoder in ``torch.compile``
              (``mode="reduce-overhead"``, ``fullgraph=True``).

      Returns:
          ``(model.eval(), decode_one_token)``, where ``decode_one_token`` is
          ``decode_one_token_ar`` for ``DualARTransformer`` and
          ``decode_one_token_naive`` otherwise.
      """
      hydra.core.global_hydra.GlobalHydra.instance().clear()
      with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
          cfg = compose(
              config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
          )

      model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)

      if "int8" in str(checkpoint_path):
          logger.info("Using int8 weight-only quantization!")
          from .quantize import WeightOnlyInt8QuantHandler

          simple_quantizer = WeightOnlyInt8QuantHandler(model)
          model = simple_quantizer.convert_for_runtime()

      if "int4" in str(checkpoint_path):
          logger.info("Using int4 quantization!")
          # Path() so that plain-string checkpoint paths work as well.
          path_comps = Path(checkpoint_path).name.split(".")
          assert path_comps[-2].startswith("g")
          groupsize = int(path_comps[-2][1:])
          from .quantize import WeightOnlyInt4QuantHandler

          simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
          model = simple_quantizer.convert_for_runtime()

      # SECURITY: torch.load unpickles the file. Only load trusted checkpoints.
      checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
      if "state_dict" in checkpoint:
          checkpoint = checkpoint["state_dict"]

      # Some checkpoints prefix every key with "model." (presumably saved from a
      # training wrapper). Strip only that leading prefix; str.replace would
      # also mangle any inner "model." occurrences.
      prefix = "model."
      if any(k.startswith(prefix) for k in checkpoint):
          checkpoint = {
              k[len(prefix) :]: v
              for k, v in checkpoint.items()
              if k.startswith(prefix)
          }

      model.load_state_dict(checkpoint, assign=True)
      model = model.to(device=device, dtype=precision)
      logger.info("Restored model from checkpoint")

      if isinstance(model, DualARTransformer):
          decode_one_token = decode_one_token_ar
          logger.info("Using DualARTransformer")
      else:
          decode_one_token = decode_one_token_naive
          logger.info("Using NaiveTransformer")

      if compile:
          logger.info("Compiling function...")
          decode_one_token = torch.compile(
              decode_one_token, mode="reduce-overhead", fullgraph=True
          )

      return model.eval(), decode_one_token
  def split_text(text, min_length):
      """Split ``text`` (after ``clean_text``) into sentence-like segments.

      Characters accumulate until a ``.``, ``!`` or ``?`` is seen *and* the
      accumulated chunk is at least ``min_length`` characters long; the chunk
      is then emitted. Shorter sentences are therefore merged into the next
      one. Trailing text without a terminator becomes the final segment.
      Chunks that are empty, or contain only whitespace and ASCII punctuation,
      are dropped. Kept chunks are stripped of surrounding whitespace.
      """
      text = clean_text(text)
      terminators = ".!?"
      segments = []

      def clean_add(chunk):
          chunk = chunk.strip()
          if not chunk:
              return
          if all(c.isspace() or c in string.punctuation for c in chunk):
              return
          segments.append(chunk)

      pending = ""
      for char in text:
          pending += char
          if char in terminators and len(pending) >= min_length:
              clean_add(pending)
              pending = ""
      clean_add(pending)
      return segments
  def _select_context(global_encoded, max_length):
      """Pick which encoded segments form the prompt for the next generation.

      Walks the history from newest to oldest, accumulating lengths until the
      budget ``max_length - 1024`` is exceeded. The first two segments are
      always kept (to avoid drift), plus the most recent ones that fit.
      ``global_encoded`` must be non-empty. Its newest entry is the segment
      about to be generated, so it is always included.
      """
      lengths = reversed([seg.size(1) for seg in global_encoded])
      count = 0
      for i, length in enumerate(lengths):
          count += length
          # NOTE(review): `length` is already in `count`, so it is counted twice
          # here. This is conservative; kept to preserve existing context sizes.
          if count + length > max_length - 1024:
              break

      # Keep an odd count so the window starts on a text segment. The history
      # alternates text / generated codes and ends with the current text.
      if i != 0 and i % 2 == 0:
          i -= 1
      # i == 0 means even the newest segment exceeds the budget; it must still
      # be kept. Also, `global_encoded[-0:]` would return the entire history.
      i = max(i, 1)

      # Rotate the list, always make sure first segment is included to avoid drift
      if i < len(global_encoded) - 2:
          return global_encoded[:2] + global_encoded[-i:]
      return global_encoded


  def generate_long(
      *,
      model,
      tokenizer: callable,
      device: str | torch.device,
      decode_one_token: callable,
      text: str,
      num_samples: int = 1,
      max_new_tokens: int = 0,
      top_p: float = 0.7,
      repetition_penalty: float = 1.5,
      temperature: float = 0.7,
      compile: bool = False,
      iterative_prompt: bool = True,
      max_length: int = 2048,
      chunk_length: int = 30,
      speaker: Optional[str] = None,
      prompt_text: Optional[str] = None,
      prompt_tokens: Optional[torch.Tensor] = None,
      is_streaming: bool = False,
  ):
      """Generate semantic codes for arbitrarily long text, segment by segment.

      The text is split with ``split_text`` (or kept whole when
      ``iterative_prompt`` is false). Each segment is generated with previous
      segments and their generated codes as context (see
      ``_select_context``). If both ``prompt_text`` and ``prompt_tokens`` are
      given, they are prepended to every context as a reference turn.

      Yields, for each of ``num_samples`` samples:
        - streaming: one codes tensor per segment, then the string ``"next"``;
        - non-streaming: a single tensor of all segments' codes concatenated
          along dim 1.
      Codes are ``y[1:]`` of the generated columns, shifted down by 2.

      Raises:
          ValueError: in non-streaming mode, if the text yields no segments.
      """
      assert 0 < top_p <= 1, "top_p must be in (0, 1]"
      assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
      assert 0 < temperature < 2, "temperature must be in (0, 2)"

      # Bytes of parameters (requires_grad=True), used for the bandwidth log.
      model_size = sum(
          p.numel() * p.element_size() for p in model.parameters() if p.requires_grad
      )
      im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

      use_prompt = prompt_text is not None and prompt_tokens is not None
      encoded = []
      texts = split_text(text, chunk_length) if iterative_prompt else [text]

      if use_prompt:
          encoded_prompts = encode_tokens(
              tokenizer,
              prompt_text,
              prompt_tokens=prompt_tokens,
              bos=True,
              device=device,
              speaker=speaker,
              num_codebooks=model.config.num_codebooks,
          )

      for idx, segment_text in enumerate(texts):
          encoded.append(
              encode_tokens(
                  tokenizer,
                  string=segment_text,
                  bos=idx == 0 and not use_prompt,
                  device=device,
                  speaker=speaker,
                  num_codebooks=model.config.num_codebooks,
              )
          )
          logger.info(f"Encoded text: {segment_text}")

      # Move temperature, top_p, repetition_penalty to device
      # This is important so that changing params doesn't trigger recompile
      temperature = torch.tensor(temperature, device=device, dtype=torch.float)
      top_p = torch.tensor(top_p, device=device, dtype=torch.float)
      repetition_penalty = torch.tensor(
          repetition_penalty, device=device, dtype=torch.float
      )

      for sample_idx in range(num_samples):
          if torch.cuda.is_available():
              torch.cuda.synchronize()

          global_encoded = []
          all_codes = []
          seg_idx = 0

          while seg_idx < len(encoded):
              logger.info(
                  f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
              )

              seg = encoded[seg_idx]
              global_encoded.append(seg)

              partial_encoded = _select_context(global_encoded, max_length)
              if use_prompt:
                  partial_encoded = [encoded_prompts] + partial_encoded

              cat_encoded = torch.cat(partial_encoded, dim=1)
              prompt_length = cat_encoded.size(1)

              t0 = time.perf_counter()
              y = generate(
                  model=model,
                  prompt=cat_encoded,
                  max_new_tokens=max_new_tokens,
                  eos_token_id=tokenizer.eos_token_id,
                  im_end_id=im_end_id,
                  decode_one_token=decode_one_token,
                  temperature=temperature,
                  top_p=top_p,
                  repetition_penalty=repetition_penalty,
              )

              if sample_idx == 0 and seg_idx == 0 and compile:
                  logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

              if torch.cuda.is_available():
                  torch.cuda.synchronize()

              t = time.perf_counter() - t0
              tokens_generated = y.size(1) - prompt_length
              tokens_sec = tokens_generated / t
              logger.info(
                  f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
              )
              logger.info(
                  f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
              )

              if torch.cuda.is_available():
                  logger.info(
                      f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                  )

              # Put the generated tokens
              # since there is <im_end> and <eos> tokens, we remove last 2 tokens
              codes = y[1:, prompt_length:-2].clone()
              codes = codes - 2
              assert (codes >= 0).all(), f"Negative code found"

              decoded = y[:, prompt_length:-1].clone()
              if decoded[0, -1] != im_end_id:  # <im_end>
                  val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
                  decoded = torch.cat(
                      (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
                  )

              # But for global encoding, we should keep the <im_end> token
              global_encoded.append(decoded)

              if is_streaming:
                  yield codes
              else:
                  all_codes.append(codes)

              seg_idx += 1

          if is_streaming:
              # This indicates the end of the current sample
              yield "next"
          else:
              if not all_codes:
                  raise ValueError(
                      "No text segments to generate: the input text is empty "
                      "after cleaning and splitting"
                  )
              all_codes = torch.cat(all_codes, dim=1)
              assert (all_codes >= 0).all(), f"Negative code found: {all_codes}"
              yield all_codes
  def launch_thread_safe_queue(
      config_name,
      checkpoint_path,
      device,
      precision,
      max_length,
      compile=False,
  ):
      """Start a daemon worker thread that owns the model and serves requests.

      The model is loaded inside the worker via ``load_model``. This call
      blocks until loading has finished.

      Protocol: put dicts with keys ``"request"`` (kwargs forwarded to
      ``generate_long``) and ``"response_queue"`` on the returned queue. The
      worker puts every yielded chunk on ``response_queue`` and then
      ``"done"``. On failure it sets ``item["success"] = False`` and
      ``item["response"]`` to the exception before putting ``"done"``. Putting
      ``None`` stops the worker.

      Raises:
          Whatever ``load_model`` raised, if model loading fails in the worker
          (previously the caller would wait forever).
      """
      input_queue = queue.Queue()
      init_event = threading.Event()
      init_error = []  # exception raised while loading the model, if any

      def worker():
          try:
              model, decode_one_token = load_model(
                  config_name, checkpoint_path, device, precision, max_length, compile=compile
              )
          except BaseException as e:
              # Hand the failure to the waiting caller instead of dying silently.
              init_error.append(e)
              return
          finally:
              # Always release the caller, whether loading succeeded or not.
              init_event.set()

          while True:
              item = input_queue.get()
              if item is None:
                  break

              kwargs = item["request"]
              response_queue = item["response_queue"]

              try:
                  item["success"] = True
                  for chunk in generate_long(
                      model=model, decode_one_token=decode_one_token, **kwargs
                  ):
                      response_queue.put(chunk)
                  response_queue.put("done")
              except Exception as e:
                  logger.exception("Generation request failed")
                  item["success"] = False
                  item["response"] = e
                  response_queue.put("done")

      threading.Thread(target=worker, daemon=True).start()
      init_event.wait()
      if init_error:
          raise init_error[0]
      return input_queue
  @click.command()
  @click.option(
      "--text",
      type=str,
      default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  )
  @click.option("--prompt-text", type=str, default=None)
  @click.option(
      "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
  )
  @click.option("--num-samples", type=int, default=1)
  @click.option("--max-new-tokens", type=int, default=0)
  @click.option("--top-p", type=float, default=0.7)
  @click.option("--repetition-penalty", type=float, default=1.5)
  @click.option("--temperature", type=float, default=0.7)
  @click.option(
      "--checkpoint-path",
      type=click.Path(path_type=Path, exists=True),
      default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
  )
  @click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
  @click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
  @click.option("--compile/--no-compile", default=False)
  @click.option("--seed", type=int, default=42)
  @click.option("--speaker", type=str, default=None)
  @click.option("--half/--no-half", default=False)
  @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  @click.option("--max-length", type=int, default=2048)
  @click.option("--chunk-length", type=int, default=30)
  def main(
      text: str,
      prompt_text: Optional[str],
      prompt_tokens: Optional[Path],
      num_samples: int,
      max_new_tokens: int,
      top_p: float,
      repetition_penalty: float,
      temperature: float,
      checkpoint_path: Path,
      config_name: str,
      tokenizer: str,
      compile: bool,
      seed: int,
      speaker: Optional[str],
      half: bool,
      iterative_prompt: bool,
      max_length: int,
      chunk_length: int,
  ) -> None:
      """Generate semantic codes for TEXT and save them as codes_{idx}.npy files."""
      # Fall back to CPU on hosts without CUDA. Every other CUDA-specific call
      # in this script is already guarded by torch.cuda.is_available().
      device = "cuda" if torch.cuda.is_available() else "cpu"
      precision = torch.half if half else torch.bfloat16

      logger.info("Loading model ...")
      t0 = time.time()
      model, decode_one_token = load_model(
          config_name, checkpoint_path, device, precision, max_length, compile=compile
      )

      if torch.cuda.is_available():
          torch.cuda.synchronize()

      logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

      prompt_tokens = (
          torch.from_numpy(np.load(prompt_tokens)).to(device)
          if prompt_tokens is not None
          else None
      )

      tokenizer = AutoTokenizer.from_pretrained(tokenizer)
      torch.manual_seed(seed)

      if torch.cuda.is_available():
          torch.cuda.manual_seed(seed)

      generator = generate_long(
          model=model,
          device=device,
          decode_one_token=decode_one_token,
          text=text,
          num_samples=num_samples,
          max_new_tokens=max_new_tokens,
          top_p=top_p,
          repetition_penalty=repetition_penalty,
          temperature=temperature,
          tokenizer=tokenizer,
          compile=compile,
          speaker=speaker,
          iterative_prompt=iterative_prompt,
          max_length=max_length,
          chunk_length=chunk_length,
          prompt_text=prompt_text,
          prompt_tokens=prompt_tokens,
      )

      # Non-streaming mode: one concatenated codes tensor per sample.
      for idx, codes in enumerate(generator):
          np.save(f"codes_{idx}.npy", codes.cpu().numpy())
          logger.info(f"Saved codes to codes_{idx}.npy")


  if __name__ == "__main__":
      main()