import os
import time
from pathlib import Path
from typing import Optional, Tuple, Union

import click
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer

from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
from fish_speech.text.clean import clean_text

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True

from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer

def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

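# Sampling via argmax(probs / q) with q ~ Exponential(1) is the exponential-race
# form of the Gumbel-max trick: index i wins with probability proportional to
# probs[i], so the draw matches torch.multinomial without forcing a CUDA sync.
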
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
):
    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=0, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=0, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

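# Filtering order above: (1) scale logits of recently seen tokens by the
# repetition penalty, (2) nucleus (top-p) masking, (3) temperature scaling
# (clamped to >= 1e-5), (4) top-k cutoff, then a final softmax.
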
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

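# Illustrative use of sample() -- a sketch only, with a made-up vocabulary size;
# real callers pass model logits of shape (1, seq_len, vocab_size):
#
#   fake_logits = torch.randn(1, 1, 32)
#   token, probs = sample(fake_logits, temperature=0.7, top_p=0.9, top_k=8)
#   # token: 1-element int tensor, probs: filtered distribution over the 32 ids
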
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)
    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]
    x = x.hidden_states

    # Cleanup the cache
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
        logits = model.forward_generate_fast(x, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        x = model.fast_embeddings(a)
        codebooks.append(a)

    return torch.stack(codebooks, dim=0)

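# Dual-AR decoding: the slow (main) transformer advances one position and yields
# logits for the text/semantic row plus a hidden state; the small fast
# transformer then autoregresses over that hidden state to emit one token per
# codebook, with its KV cache zeroed at the start of every step. The result is
# a (num_codebooks + 1, 1) column of token ids.
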
def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    x = model.forward_generate(x, input_pos)

    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)

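# The naive variant predicts every codebook from a single forward pass
# (token_logits plus codebook_logits), so the codebook loop above only indexes
# precomputed logits instead of autoregressing through a fast transformer as
# decode_one_token_ar does.
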
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        if (
            cur_token[0, 0, -1] == eos_token_id
            or cur_token[0, 0, -1] == im_end_id
            or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
        ):
            break

    return previous_tokens[:, : i + 1]

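# decode_n_tokens drives step-by-step decoding: the repetition penalty only sees
# a 16-position window of the running token history, and generation stops early
# once the text row emits <eos>/<im_end> or any codebook row emits
# CODEBOOK_EOS_TOKEN_ID. The returned tensor has shape
# (num_codebooks + 1, steps_actually_generated).
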
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    precision: torch.dtype = torch.bfloat16,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = decode_one_token(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        eos_token_id=eos_token_id,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq

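# Illustrative call -- a sketch, not the pipeline itself; `encoded_prompt` stands
# for whatever encode_tokens() returns, i.e. an int tensor of shape
# (1 + num_codebooks, T) on the model's device (see main() for the real call):
#
#   y = generate(
#       model=model,
#       prompt=encoded_prompt,
#       max_new_tokens=0,  # 0 means "fill up to model.config.max_seq_len"
#       eos_token_id=tokenizer.eos_token_id,
#       im_end_id=im_end_id,
#       decode_one_token=decode_one_token_ar,
#       temperature=0.7,
#       top_p=0.7,
#       repetition_penalty=1.5,
#   )
#   # y: (1 + num_codebooks, prompt_len + generated_len)
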
def encode_tokens(
    tokenizer,
    string,
    bos=True,
    device="cuda",
    prompt_tokens=None,
    speaker=None,
    num_codebooks=4,
):
    string = clean_text(string)

    if speaker is not None:
        string = f"[SPK: {speaker}] {string}"

    string = (
        f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>assistant<|im_sep|>"
    )
    if bos:
        string = f"<|begin_of_sequence|>{string}"

    new_tokens = tokenizer.encode(
        string,
        add_special_tokens=False,
        max_length=10**6,
        truncation=False,
    )
    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

    # Codebooks
    zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
    prompt = torch.cat((tokens, zeros), dim=0)

    if prompt_tokens is None:
        return prompt

    # Get prompt tokens
    if prompt_tokens.ndim == 3:
        assert (
            prompt_tokens.shape[0] == 1
        ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
        prompt_tokens = prompt_tokens[0]

    assert prompt_tokens.ndim == 2
    data = prompt_tokens + 2

    if prompt_tokens.shape[0] > num_codebooks:
        logger.warning(
            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
        )
        data = data[:num_codebooks]

    # Since 1.0, we use <|semantic|>
    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
    main_token_ids = torch.tensor(
        [[s0_token_id] * data.size(1)],
        dtype=torch.int,
        device=device,
    )

    data = torch.cat((main_token_ids, data), dim=0)
    prompt = torch.cat((prompt, data), dim=1)

    return prompt

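# encode_tokens builds a (1 + num_codebooks, T) prompt: row 0 holds the text
# token ids of the chat-formatted string, rows 1..num_codebooks hold semantic
# codes (all zeros for plain text). When reference prompt_tokens are supplied,
# their codes are offset by +2 and row 0 is filled with the <|semantic|>
# placeholder token over that span before being concatenated onto the prompt.
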
def load_model(config_name, checkpoint_path, device, precision, max_length):
    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
        cfg = compose(
            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
        )

    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)

    if "int8" in str(checkpoint_path):
        logger.info("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        logger.info("Using int4 quantization!")
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if any(k.startswith("model.") for k in checkpoint):
        checkpoint = {
            k.replace("model.", ""): v
            for k, v in checkpoint.items()
            if k.startswith("model.")
        }

    model.load_state_dict(checkpoint, assign=True)

    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    return model.eval(), cfg

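# Note that the quantization variant is inferred purely from the checkpoint
# filename: a path containing "int8" selects weight-only int8, and one
# containing "int4" additionally expects a ".g<groupsize>." name component
# (e.g. "model.int4.g32.pth" -- a hypothetical name, shown only to illustrate
# the parsing above).
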
def split_text(text, min_length):
    text = clean_text(text)
    segments = []
    curr = ""
    for char in text:
        curr += char
        if char not in [".", ",", "!", "?"]:
            continue

        if len(curr) >= min_length:
            segments.append(curr)
            curr = ""

    if curr:
        segments.append(curr)

    return segments

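# Illustrative behaviour (ignoring any normalization clean_text may apply):
#
#   split_text("Hi, there. All good?", min_length=4)
#   # -> ["Hi, there.", " All good?"]
#
# i.e. the text is cut after sentence punctuation, but only once the pending
# chunk has reached min_length characters.
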
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None)
@click.option(
    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-k", type=int, default=None)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.5)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="results/text2semantic_400m_finetune/step_000002000.pth",
)
@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--speaker", type=str, default=None)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=False)
@click.option("--max-length", type=int, default=2048)
@click.option("--chunk-length", type=int, default=30)
def main(
    text: str,
    prompt_text: Optional[str],
    prompt_tokens: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    config_name: str,
    tokenizer: str,
    compile: bool,
    seed: int,
    speaker: Optional[str],
    half: bool,
    iterative_prompt: bool,
    max_length: int,
    chunk_length: int,
) -> None:
    device = "cuda"

    precision = torch.half if half else torch.bfloat16

    logger.info("Loading model ...")
    t0 = time.time()
    model, cfg = load_model(config_name, checkpoint_path, device, precision, max_length)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)

    torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    prompt_tokens = (
        torch.from_numpy(np.load(prompt_tokens)).to(device)
        if prompt_tokens is not None
        else None
    )

    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    use_prompt = prompt_text is not None and prompt_tokens is not None
    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]
    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                bos=idx == 0 and not use_prompt,
                device=device,
                speaker=None,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    if use_prompt:
        encoded_prompt = encode_tokens(
            tokenizer,
            prompt_text,
            prompt_tokens=prompt_tokens,
            bos=True,
            device=device,
            speaker=speaker,
            num_codebooks=model.config.num_codebooks,
        )

        encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if isinstance(model, DualARTransformer):
        decode_one_token = decode_one_token_ar
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = decode_one_token_naive
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )
    for idx in range(num_samples):
        torch.cuda.synchronize()
        global_encoded = []
        all_codes = []
        seg_idx = 0

        while seg_idx < len(encoded):
            seg = encoded[seg_idx]
            global_encoded.append(seg)
            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Pick last 2000 tokens
            count = 0
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024:
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                eos_token_id=tokenizer.eos_token_id,
                im_end_id=im_end_id,
                decode_one_token=decode_one_token,
                precision=precision,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            torch.cuda.synchronize()
            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )
            logger.info(
                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
            )

            # Put the generated tokens
            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
            codes = y[1:, prompt_length:-2].clone()

            codes = codes - 2
            if not (codes >= 0).all():
                global_encoded.pop()
                logger.warning(f"Negative code found: {codes}, retrying ...")
                continue

            decoded = y[:, prompt_length:-1].clone()
            if decoded[0, -1] != im_end_id:  # <im_end>
                val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
                decoded = torch.cat(
                    (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
                )

            # But for global encoding, we should keep the <im_end> token
            global_encoded.append(decoded)
            all_codes.append(codes)
            seg_idx += 1

        codes = torch.cat(all_codes, dim=1)
        assert (codes >= 0).all(), f"Negative code found: {codes}"

        np.save(f"codes_{idx}.npy", codes.cpu().numpy())
        logger.info(f"Saved codes to codes_{idx}.npy")

if __name__ == "__main__":
    main()