generate.py

import os
import queue
import string
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple, Union

import click
import hydra
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer

from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
from fish_speech.text import clean_text, split_text

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True

from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer


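# The sampler below avoids torch.multinomial, which would force a CPU/GPU sync.
# It relies on the exponential-race trick: dividing the probabilities by i.i.d.
# Exponential(1) noise and taking the argmax selects index i with probability
# proportional to probs_sort[i], so the result is a valid multinomial draw.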
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    # Apply repetition penalty
    if previous_tokens is not None:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=0, index=previous_tokens, src=score)

    # Apply top-p sampling
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(
        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
    )
    sorted_indices_to_remove = cum_probs > top_p
    sorted_indices_to_remove[0] = False  # keep at least one option
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=0, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


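# sample() ties the two helpers together: it takes the logits of the last
# position from a (1, seq_len, vocab) tensor, turns them into a filtered
# probability distribution, and draws a single token id without synchronizing.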
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)
    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]
    x = x.hidden_states

    # Cleanup the cache
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
        logits = model.forward_generate_fast(x, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        x = model.fast_embeddings(a)
        codebooks.append(a)

    return torch.stack(codebooks, dim=0)


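# Naive decoding: unlike the dual-AR path above, a single forward pass yields
# both the text-token logits and the logits for every codebook at the current
# position, so no inner per-codebook loop over a "fast" transformer is needed.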
def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)

    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)


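# Incremental decoding loop. The repetition penalty only looks at a fixed-size
# window of recently generated tokens (win_size positions), which keeps the
# window tensor's shape constant; that is helpful when the per-token decode
# function is wrapped in torch.compile.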
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        if (
            cur_token[0, 0, -1] == eos_token_id
            or cur_token[0, 0, -1] == im_end_id
            or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
        ):
            break

    return previous_tokens[:, : i + 1]


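# generate() runs one prefill step over the full prompt and then hands off to
# decode_n_tokens for token-by-token decoding. The uncompiled decode function
# is used for the prefill step to avoid paying compilation overhead for every
# new prompt length.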
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
        )

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )

    next_token = prefill_decode(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        eos_token_id=eos_token_id,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq


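# Prompt layout built by encode_tokens(): row 0 carries text token ids in a
# chat-style template, rows 1..num_codebooks carry acoustic codebook ids
# (CODEBOOK_PAD_TOKEN_ID over pure-text spans). When reference prompt_tokens
# are supplied, they are shifted by +2, terminated with a CODEBOOK_EOS_TOKEN_ID
# column, and row 0 over that span is filled with <|semantic|> ids ending in
# <|im_end|>.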
def encode_tokens(
    tokenizer,
    string,
    bos=True,
    device="cuda",
    prompt_tokens=None,
    speaker=None,
    num_codebooks=4,
):
    string = clean_text(string)

    if speaker is None:
        speaker = "assistant"

    string = (
        f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
    )
    if bos:
        string = f"<|begin_of_sequence|>{string}"

    new_tokens = tokenizer.encode(
        string,
        add_special_tokens=False,
        max_length=10**6,
        truncation=False,
    )
    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

    # Codebooks
    zeros = (
        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
        * CODEBOOK_PAD_TOKEN_ID
    )
    prompt = torch.cat((tokens, zeros), dim=0)

    if prompt_tokens is None:
        return prompt

    # Get prompt tokens
    if prompt_tokens.ndim == 3:
        assert (
            prompt_tokens.shape[0] == 1
        ), "3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
        prompt_tokens = prompt_tokens[0]

    assert prompt_tokens.ndim == 2
    data = prompt_tokens + 2

    if prompt_tokens.shape[0] > num_codebooks:
        logger.warning(
            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
        )
        data = data[:num_codebooks]

    # Add eos token for each codebook
    data = torch.cat(
        (
            data,
            torch.ones((data.size(0), 1), dtype=torch.int, device=device)
            * CODEBOOK_EOS_TOKEN_ID,
        ),
        dim=1,
    )

    # Since 1.0, we use <|semantic|>
    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    main_token_ids = (
        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
    )
    main_token_ids[0, -1] = end_token_id

    data = torch.cat((main_token_ids, data), dim=0)
    prompt = torch.cat((prompt, data), dim=1)

    return prompt


def load_model(
    config_name, checkpoint_path, device, precision, max_length, compile=False
):
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
        cfg = compose(
            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
        )

    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)

    if "int8" in str(checkpoint_path):
        logger.info("Using int8 weight-only quantization!")
        from .quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        logger.info("Using int4 quantization!")
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from .quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if any(k.startswith("model.") for k in checkpoint):
        checkpoint = {
            k.replace("model.", ""): v
            for k, v in checkpoint.items()
            if k.startswith("model.")
        }

    model.load_state_dict(checkpoint, assign=True)
    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
        decode_one_token = decode_one_token_ar
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = decode_one_token_naive
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )

    return model.eval(), decode_one_token


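# A minimal loading sketch (config name and checkpoint path are illustrative,
# mirroring the CLI defaults below, not the only valid values):
#   model, decode_one_token = load_model(
#       "dual_ar_2_codebook_medium",
#       Path("checkpoints/text2semantic-sft-medium-v1-4k.pth"),
#       device="cuda", precision=torch.bfloat16, max_length=2048, compile=False,
#   )
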
@dataclass
class GenerateResponse:
    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None
    text: Optional[str] = None


def generate_long(
    *,
    model,
    tokenizer: callable,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 150,
    speaker: Optional[str] = None,
    prompt_text: Optional[str] = None,
    prompt_tokens: Optional[torch.Tensor] = None,
):
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    use_prompt = prompt_text is not None and prompt_tokens is not None
    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]

    if use_prompt:
        encoded_prompts = encode_tokens(
            tokenizer,
            prompt_text,
            prompt_tokens=prompt_tokens,
            bos=True,
            device=device,
            speaker=speaker,
            num_codebooks=model.config.num_codebooks,
        )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                bos=idx == 0 and not use_prompt,
                device=device,
                speaker=speaker,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    # Move temperature, top_p, repetition_penalty to device
    # This is important so that changing params doesn't trigger recompile
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        global_encoded = []
        seg_idx = 0

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Pick last 2000 tokens
            count = 0
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024:
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list, always make sure first segment is included to avoid drift
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            if use_prompt:
                partial_encoded = [encoded_prompts] + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                eos_token_id=tokenizer.eos_token_id,
                im_end_id=im_end_id,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Put the generated tokens
            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
            codes = y[1:, prompt_length:-2].clone()
            codes = codes - 2
            assert (codes >= 0).all(), f"Negative code found"

            decoded = y[:, prompt_length:-1].clone()
            if decoded[0, -1] != im_end_id:  # <im_end>
                val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
                decoded = torch.cat(
                    (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
                )

            # But for global encoding, we should keep the <im_end> token
            global_encoded.append(decoded)

            assert (codes >= 0).all(), f"Negative code found: {codes}"
            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
            seg_idx += 1

        # This indicates the end of the current sample
        yield GenerateResponse(action="next")


@dataclass
class WrappedGenerateResponse:
    status: Literal["success", "error"]
    response: Optional[GenerateResponse | Exception] = None


@dataclass
class GenerateRequest:
    request: dict
    response_queue: queue.Queue


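# launch_thread_safe_queue() confines the model to a single worker thread:
# callers put a GenerateRequest (the kwargs for generate_long plus a private
# response queue) onto the returned queue and read WrappedGenerateResponse
# items back, so multiple producers (e.g. an API server) can share one model.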
def launch_thread_safe_queue(
    config_name,
    checkpoint_path,
    device,
    precision,
    max_length: int,
    compile: bool = False,
):
    input_queue = queue.Queue()
    init_event = threading.Event()

    def worker():
        model, decode_one_token = load_model(
            config_name, checkpoint_path, device, precision, max_length, compile=compile
        )
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
            except Exception as e:
                response_queue.put(WrappedGenerateResponse(status="error", response=e))

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue


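# Example CLI invocation (a sketch; the script path and checkpoint location
# depend on your checkout):
#   python generate.py --text "Hello world" \
#       --checkpoint-path checkpoints/text2semantic-sft-medium-v1-4k.pth \
#       --config-name dual_ar_2_codebook_medium --compile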
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None)
@click.option(
    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.5)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
)
@click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--speaker", type=str, default=None)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--max-length", type=int, default=2048)
@click.option("--chunk-length", type=int, default=30)
def main(
    text: str,
    prompt_text: Optional[str],
    prompt_tokens: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    config_name: str,
    tokenizer: str,
    compile: bool,
    seed: int,
    speaker: Optional[str],
    half: bool,
    iterative_prompt: bool,
    max_length: int,
    chunk_length: int,
) -> None:
    device = "cuda"
    precision = torch.half if half else torch.bfloat16

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        config_name, checkpoint_path, device, precision, max_length, compile=compile
    )

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    prompt_tokens = (
        torch.from_numpy(np.load(prompt_tokens)).to(device)
        if prompt_tokens is not None
        else None
    )

    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        tokenizer=tokenizer,
        compile=compile,
        speaker=speaker,
        iterative_prompt=iterative_prompt,
        max_length=max_length,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    idx = 0
    codes = []

    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
                logger.info(f"Saved codes to codes_{idx}.npy")
            logger.info("Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")


if __name__ == "__main__":
    main()