generate.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610
  1. import os
  2. import time
  3. from pathlib import Path
  4. from typing import Optional, Tuple, Union
  5. import click
  6. import numpy as np
  7. import torch
  8. import torch._dynamo.config
  9. import torch._inductor.config
  10. from hydra import compose, initialize
  11. from hydra.utils import instantiate
  12. from loguru import logger
  13. from tqdm import tqdm
  14. from transformers import AutoTokenizer
  15. from fish_speech.text.parser import clean_text
  16. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  17. torch._inductor.config.coordinate_descent_tuning = True
  18. torch._inductor.config.triton.unique_kernel_names = True
  19. if hasattr(torch._inductor.config, "fx_graph_cache"):
  20. # Experimental feature to reduce compilation times, will be on by default in future
  21. torch._inductor.config.fx_graph_cache = True
  22. from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
  23. from fish_speech.text import g2p
  24. from fish_speech.text.symbols import pad as pad_symbol
  25. from fish_speech.text.symbols import pu_symbols
  26. def multinomial_sample_one_no_sync(
  27. probs_sort,
  28. ): # Does multinomial sampling without a cuda synchronization
  29. q = torch.empty_like(probs_sort).exponential_(1)
  30. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  31. def logits_to_probs(
  32. logits,
  33. previous_tokens: Optional[torch.Tensor] = None,
  34. temperature: float = 1.0,
  35. top_k: Optional[int] = None,
  36. top_p: Optional[int] = None,
  37. repetition_penalty: float = 1.0,
  38. ):
  39. if previous_tokens is not None and repetition_penalty != 1.0:
  40. previous_tokens = previous_tokens.long()
  41. score = torch.gather(logits, dim=0, index=previous_tokens)
  42. score = torch.where(
  43. score < 0, score * repetition_penalty, score / repetition_penalty
  44. )
  45. logits.scatter_(dim=0, index=previous_tokens, src=score)
  46. if top_p is not None and top_p < 1.0:
  47. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  48. cum_probs = torch.cumsum(
  49. torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
  50. )
  51. sorted_indices_to_remove = cum_probs > top_p
  52. sorted_indices_to_remove[0] = False # keep at least one option
  53. indices_to_remove = sorted_indices_to_remove.scatter(
  54. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  55. )
  56. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  57. logits = logits / max(temperature, 1e-5)
  58. if top_k is not None:
  59. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  60. pivot = v.select(-1, -1).unsqueeze(-1)
  61. logits = torch.where(logits < pivot, -float("Inf"), logits)
  62. probs = torch.nn.functional.softmax(logits, dim=-1)
  63. return probs
  64. def sample(
  65. logits,
  66. previous_tokens: Optional[torch.Tensor] = None,
  67. **sampling_kwargs,
  68. ) -> Tuple[torch.Tensor, torch.Tensor]:
  69. probs = logits_to_probs(
  70. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  71. )
  72. idx_next = multinomial_sample_one_no_sync(probs)
  73. return idx_next, probs
  74. def decode_one_token_ar(
  75. model: DualARTransformer,
  76. x: torch.Tensor,
  77. input_pos: torch.Tensor,
  78. previous_tokens: torch.Tensor = None,
  79. **sampling_kwargs,
  80. ) -> torch.Tensor:
  81. x = model.forward_generate(x, input_pos)
  82. codebooks = [
  83. sample(
  84. x.logits,
  85. previous_tokens=None, # Disable repetition penalty for the token codebook
  86. **sampling_kwargs,
  87. )[0]
  88. ]
  89. x = x.hidden_states
  90. # Cleanup the cache
  91. for layer in model.fast_layers:
  92. layer.attention.kv_cache.k_cache.fill_(0)
  93. layer.attention.kv_cache.v_cache.fill_(0)
  94. for codebook_idx in range(model.config.num_codebooks):
  95. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  96. logits = model.forward_generate_fast(x, input_pos)
  97. a = sample(
  98. logits,
  99. previous_tokens=(
  100. previous_tokens[codebook_idx + 1]
  101. if previous_tokens is not None
  102. else None
  103. ),
  104. **sampling_kwargs,
  105. )[0]
  106. x = model.fast_embeddings(a)
  107. codebooks.append(a)
  108. return torch.stack(codebooks, dim=0)
  109. def decode_one_token_naive(
  110. model: NaiveTransformer,
  111. x: torch.Tensor,
  112. input_pos: torch.Tensor,
  113. previous_tokens: torch.Tensor = None,
  114. **sampling_kwargs,
  115. ) -> torch.Tensor:
  116. x = model.forward_generate(x, input_pos)
  117. codebooks = [
  118. sample(
  119. x.token_logits,
  120. previous_tokens=None, # Disable repetition penalty for the token codebook
  121. **sampling_kwargs,
  122. )[0]
  123. ]
  124. for i in range(model.config.num_codebooks):
  125. codebooks.append(
  126. sample(
  127. x.codebook_logits[:, :, i],
  128. previous_tokens=previous_tokens[i + 1]
  129. if previous_tokens is not None
  130. else None,
  131. **sampling_kwargs,
  132. )[0]
  133. )
  134. return torch.stack(codebooks, dim=0)
  135. def decode_n_tokens(
  136. model: NaiveTransformer,
  137. cur_token: torch.Tensor,
  138. input_pos: torch.Tensor,
  139. num_new_tokens: int,
  140. eos_token_id: int = 2,
  141. decode_one_token=decode_one_token_naive,
  142. **sampling_kwargs,
  143. ):
  144. previous_tokens = torch.zeros(
  145. (model.config.num_codebooks + 1, model.config.max_seq_len),
  146. dtype=torch.int,
  147. device=cur_token.device,
  148. )
  149. for i in tqdm(range(num_new_tokens)):
  150. # We need to get windowed repeat penalty
  151. win_size = 16
  152. if i < win_size:
  153. window = previous_tokens[:, :win_size]
  154. else:
  155. window = previous_tokens[:, i - win_size : i]
  156. with torch.backends.cuda.sdp_kernel(
  157. enable_flash=False, enable_mem_efficient=False, enable_math=True
  158. ): # Actually better for Inductor to codegen attention here
  159. next_token = decode_one_token(
  160. model=model,
  161. x=cur_token,
  162. input_pos=input_pos,
  163. previous_tokens=window,
  164. **sampling_kwargs,
  165. )
  166. input_pos += 1
  167. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  168. previous_tokens[:, i : i + 1] = next_token.view(
  169. model.config.num_codebooks + 1, -1
  170. )
  171. # TODO: use tokenizer's eos
  172. if cur_token[0, 0, -1] == eos_token_id or (cur_token[0, 1:, -1] == 1).any():
  173. break
  174. return previous_tokens[:, : i + 1]
  175. @torch.no_grad()
  176. @torch.inference_mode()
  177. def generate(
  178. *,
  179. model: NaiveTransformer,
  180. prompt: torch.Tensor,
  181. max_new_tokens: int,
  182. eos_token_id: int = 2,
  183. decode_one_token=decode_one_token_naive,
  184. precision: torch.dtype = torch.bfloat16,
  185. **sampling_kwargs,
  186. ) -> torch.Tensor:
  187. """
  188. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  189. """
  190. # create an empty tensor of the expected final shape and fill in the current tokens
  191. T = prompt.size(1)
  192. if max_new_tokens:
  193. if T + max_new_tokens > model.config.max_seq_len:
  194. max_new_tokens = model.config.max_seq_len - T
  195. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  196. T_new = T + max_new_tokens
  197. else:
  198. T_new = model.config.max_seq_len
  199. max_new_tokens = T_new - T
  200. device, dtype = prompt.device, prompt.dtype
  201. with torch.device(device):
  202. model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)
  203. codebook_dim = 1 + model.config.num_codebooks
  204. # create an empty tensor of the expected final shape and fill in the current tokens
  205. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  206. empty[:, :T] = prompt
  207. seq = empty
  208. input_pos = torch.arange(0, T, device=device)
  209. next_token = decode_one_token(
  210. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  211. )
  212. seq[:, T : T + 1] = next_token
  213. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  214. x = decode_n_tokens(
  215. model,
  216. next_token.view(1, codebook_dim, -1),
  217. input_pos,
  218. max_new_tokens - 1,
  219. eos_token_id=eos_token_id,
  220. decode_one_token=decode_one_token,
  221. **sampling_kwargs,
  222. )
  223. # x = torch.cat(generated_tokens, dim=1)
  224. seq = seq[:, : T + 1 + x.size(1)]
  225. seq[:, T + 1 :] = x
  226. return seq
  227. def encode_tokens(
  228. tokenizer,
  229. string,
  230. bos=True,
  231. device="cuda",
  232. prompt_tokens=None,
  233. use_g2p=False,
  234. speaker=None,
  235. order="zh,jp,en",
  236. num_codebooks=4,
  237. ):
  238. if use_g2p:
  239. order = order.split(",")
  240. prompt = g2p(string, order=order)
  241. prompt = [
  242. (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
  243. for _, i in prompt
  244. ]
  245. string = " ".join(prompt)
  246. else:
  247. string = clean_text(string)
  248. if speaker is not None:
  249. string = f"[SPK: {speaker}] {string}"
  250. string = f"[INST] {string} [/INST]"
  251. new_tokens = tokenizer.encode(
  252. string,
  253. add_special_tokens=bos,
  254. max_length=10**6,
  255. truncation=False,
  256. )
  257. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  258. # Codebooks
  259. zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  260. prompt = torch.cat((tokens, zeros), dim=0)
  261. if prompt_tokens is None:
  262. return prompt
  263. # Get prompt tokens
  264. if prompt_tokens.ndim == 3:
  265. assert (
  266. prompt_tokens.shape[0] == 1
  267. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  268. prompt_tokens = prompt_tokens[0]
  269. assert prompt_tokens.ndim == 2
  270. data = prompt_tokens + 2
  271. if prompt_tokens.shape[0] > num_codebooks:
  272. logger.warning(
  273. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  274. )
  275. data = data[:num_codebooks]
  276. # Since 1.0, we use <s:xxx> to replace <semantic>
  277. s0_token_id = tokenizer.convert_tokens_to_ids("<s:0>")
  278. main_token_ids = torch.tensor(
  279. # TODO: replace this
  280. [[s0_token_id] * data.size(1)],
  281. dtype=torch.int,
  282. device=device,
  283. )
  284. data = torch.cat((main_token_ids, data), dim=0)
  285. prompt = torch.cat((prompt, data), dim=1)
  286. return prompt
  287. def load_model(config_name, checkpoint_path, device, precision, max_length):
  288. with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
  289. cfg = compose(
  290. config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
  291. )
  292. model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)
  293. if "int8" in str(checkpoint_path):
  294. logger.info("Using int8 weight-only quantization!")
  295. from quantize import WeightOnlyInt8QuantHandler
  296. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  297. model = simple_quantizer.convert_for_runtime()
  298. if "int4" in str(checkpoint_path):
  299. logger.info("Using int4 quantization!")
  300. path_comps = checkpoint_path.name.split(".")
  301. assert path_comps[-2].startswith("g")
  302. groupsize = int(path_comps[-2][1:])
  303. from quantize import WeightOnlyInt4QuantHandler
  304. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  305. model = simple_quantizer.convert_for_runtime()
  306. checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
  307. if "state_dict" in checkpoint:
  308. checkpoint = checkpoint["state_dict"]
  309. if any(k.startswith("model.") for k in checkpoint):
  310. checkpoint = {
  311. k.replace("model.", ""): v
  312. for k, v in checkpoint.items()
  313. if k.startswith("model.")
  314. }
  315. model.load_state_dict(checkpoint, assign=True)
  316. model = model.to(device=device, dtype=precision)
  317. logger.info("Restored model from checkpoint")
  318. return model.eval(), cfg
  319. def split_text(text, min_length):
  320. text = clean_text(text)
  321. segments = []
  322. curr = ""
  323. for char in text:
  324. curr += char
  325. if char not in [".", ",", "!", "?"]:
  326. continue
  327. if len(curr) >= min_length:
  328. segments.append(curr)
  329. curr = ""
  330. if curr:
  331. segments.append(curr)
  332. return segments
  333. @click.command()
  334. @click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
  335. @click.option("--prompt-text", type=str, default=None)
  336. @click.option(
  337. "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
  338. )
  339. @click.option("--num-samples", type=int, default=1)
  340. @click.option("--max-new-tokens", type=int, default=0)
  341. @click.option("--top-k", type=int, default=None)
  342. @click.option("--top-p", type=float, default=0.9)
  343. @click.option("--repetition-penalty", type=float, default=1.2)
  344. @click.option("--temperature", type=float, default=0.7)
  345. @click.option(
  346. "--checkpoint-path",
  347. type=click.Path(path_type=Path, exists=True),
  348. default="results/text2semantic_400m_finetune/step_000002000.pth",
  349. )
  350. @click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
  351. @click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
  352. @click.option("--compile/--no-compile", default=False)
  353. @click.option("--use-g2p/--no-g2p", default=True)
  354. @click.option("--seed", type=int, default=42)
  355. @click.option("--speaker", type=str, default=None)
  356. @click.option("--order", type=str, default="zh,jp,en")
  357. @click.option("--half/--no-half", default=False)
  358. @click.option("--iterative-prompt/--no-iterative-prompt", default=False)
  359. @click.option("--max-length", type=int, default=2048)
  360. @click.option("--chunk-length", type=int, default=30)
  361. def main(
  362. text: str,
  363. prompt_text: Optional[str],
  364. prompt_tokens: Optional[Path],
  365. num_samples: int,
  366. max_new_tokens: int,
  367. top_k: int,
  368. top_p: int,
  369. repetition_penalty: float,
  370. temperature: float,
  371. checkpoint_path: Path,
  372. config_name: str,
  373. tokenizer: str,
  374. compile: bool,
  375. use_g2p: bool,
  376. seed: int,
  377. speaker: Optional[str],
  378. order: str,
  379. half: bool,
  380. iterative_prompt: bool,
  381. max_length: int,
  382. chunk_length: int,
  383. ) -> None:
  384. device = "cuda"
  385. precision = torch.half if half else torch.bfloat16
  386. logger.info("Loading model ...")
  387. t0 = time.time()
  388. model, cfg = load_model(config_name, checkpoint_path, device, precision, max_length)
  389. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  390. torch.cuda.synchronize()
  391. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  392. tokenizer = AutoTokenizer.from_pretrained(tokenizer)
  393. prompt_tokens = (
  394. torch.from_numpy(np.load(prompt_tokens)).to(device)
  395. if prompt_tokens is not None
  396. else None
  397. )
  398. use_prompt = prompt_text is not None and prompt_tokens is not None
  399. encoded = []
  400. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  401. for idx, text in enumerate(texts):
  402. encoded.append(
  403. encode_tokens(
  404. tokenizer,
  405. string=text,
  406. bos=idx == 0 and not use_prompt,
  407. device=device,
  408. use_g2p=use_g2p,
  409. speaker=None,
  410. order=order,
  411. num_codebooks=model.config.num_codebooks,
  412. )
  413. )
  414. logger.info(f"Encoded text: {text}")
  415. if use_prompt:
  416. encoded_prompt = encode_tokens(
  417. tokenizer,
  418. prompt_text,
  419. prompt_tokens=prompt_tokens,
  420. bos=True,
  421. device=device,
  422. use_g2p=use_g2p,
  423. speaker=speaker,
  424. order=order,
  425. num_codebooks=model.config.num_codebooks,
  426. )
  427. encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
  428. torch.manual_seed(seed)
  429. torch.cuda.manual_seed(seed)
  430. if isinstance(model, DualARTransformer):
  431. decode_one_token = decode_one_token_ar
  432. logger.info("Using DualARTransformer")
  433. else:
  434. decode_one_token = decode_one_token_naive
  435. logger.info("Using NaiveTransformer")
  436. if compile:
  437. logger.info("Compiling function...")
  438. decode_one_token = torch.compile(
  439. decode_one_token, mode="reduce-overhead", fullgraph=True
  440. )
  441. for idx in range(num_samples):
  442. torch.cuda.synchronize()
  443. global_encoded = []
  444. all_codes = []
  445. seg_idx = 0
  446. while seg_idx < len(encoded):
  447. seg = encoded[seg_idx]
  448. global_encoded.append(seg)
  449. lengths = reversed([seg.size(1) for seg in global_encoded])
  450. # Pick last 2000 tokens
  451. count = 0
  452. for i, length in enumerate(lengths):
  453. count += length
  454. if count + length > max_length - 1024:
  455. break
  456. if i != 0 and i % 2 == 0:
  457. i -= 1
  458. # Rotate the list
  459. if i < len(global_encoded) - 2:
  460. partial_encoded = global_encoded[-i:]
  461. else:
  462. partial_encoded = global_encoded
  463. cat_encoded = torch.cat(partial_encoded, dim=1)
  464. prompt_length = cat_encoded.size(1)
  465. t0 = time.perf_counter()
  466. y = generate(
  467. model=model,
  468. prompt=cat_encoded,
  469. max_new_tokens=max_new_tokens,
  470. eos_token_id=tokenizer.eos_token_id,
  471. decode_one_token=decode_one_token,
  472. precision=precision,
  473. temperature=temperature,
  474. top_k=top_k,
  475. top_p=top_p,
  476. repetition_penalty=repetition_penalty,
  477. )
  478. if idx == 0 and seg_idx == 0 and compile:
  479. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  480. torch.cuda.synchronize()
  481. t = time.perf_counter() - t0
  482. tokens_generated = y.size(1) - prompt_length
  483. tokens_sec = tokens_generated / t
  484. logger.info(
  485. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  486. )
  487. logger.info(
  488. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  489. )
  490. logger.info(
  491. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  492. )
  493. # Put the generated tokens
  494. codes = y[1:, prompt_length:-1].clone()
  495. codes = codes - 2
  496. if not (codes >= 0).all():
  497. global_encoded.pop()
  498. logger.warning(f"Negative code found: {codes}, retrying ...")
  499. continue
  500. global_encoded.append(y[:, prompt_length:-1].clone())
  501. all_codes.append(codes)
  502. seg_idx += 1
  503. codes = torch.cat(all_codes, dim=1)
  504. assert (codes >= 0).all(), f"Negative code found: {codes}"
  505. np.save(f"codes_{idx}.npy", codes.cpu().numpy())
  506. logger.info(f"Saved codes to codes_{idx}.npy")
  507. if __name__ == "__main__":
  508. main()