generate.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648
  1. import os
  2. import time
  3. from pathlib import Path
  4. from typing import Optional, Tuple
  5. import click
  6. import numpy as np
  7. import torch
  8. import torch._dynamo.config
  9. import torch._inductor.config
  10. from hydra import compose, initialize
  11. from hydra.utils import instantiate
  12. from loguru import logger
  13. from tqdm import tqdm
  14. from transformers import AutoTokenizer
  15. from fish_speech.text.parser import clean_text
  16. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  17. torch._inductor.config.coordinate_descent_tuning = True
  18. torch._inductor.config.triton.unique_kernel_names = True
  19. if hasattr(torch._inductor.config, "fx_graph_cache"):
  20. # Experimental feature to reduce compilation times, will be on by default in future
  21. torch._inductor.config.fx_graph_cache = True
  22. from fish_speech.models.text2semantic.llama import Transformer
  23. from fish_speech.text import g2p
  24. from fish_speech.text.symbols import pad as pad_symbol
  25. from fish_speech.text.symbols import pu_symbols
  26. def multinomial_sample_one_no_sync(
  27. probs_sort,
  28. ): # Does multinomial sampling without a cuda synchronization
  29. q = torch.empty_like(probs_sort).exponential_(1)
  30. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  31. def logits_to_probs(
  32. logits,
  33. previous_tokens: Optional[torch.Tensor] = None,
  34. temperature: float = 1.0,
  35. top_k: Optional[int] = None,
  36. top_p: Optional[int] = None,
  37. repetition_penalty: float = 1.0,
  38. ):
  39. if previous_tokens is not None and repetition_penalty != 1.0:
  40. previous_tokens = previous_tokens.long()
  41. score = torch.gather(logits, dim=0, index=previous_tokens)
  42. score = torch.where(
  43. score < 0, score * repetition_penalty, score / repetition_penalty
  44. )
  45. logits.scatter_(dim=0, index=previous_tokens, src=score)
  46. if top_p is not None and top_p < 1.0:
  47. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  48. cum_probs = torch.cumsum(
  49. torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
  50. )
  51. sorted_indices_to_remove = cum_probs > top_p
  52. sorted_indices_to_remove[0] = False # keep at least one option
  53. indices_to_remove = sorted_indices_to_remove.scatter(
  54. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  55. )
  56. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  57. logits = logits / max(temperature, 1e-5)
  58. if top_k is not None:
  59. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  60. pivot = v.select(-1, -1).unsqueeze(-1)
  61. logits = torch.where(logits < pivot, -float("Inf"), logits)
  62. probs = torch.nn.functional.softmax(logits, dim=-1)
  63. return probs
  64. def sample(
  65. logits,
  66. previous_tokens: Optional[torch.Tensor] = None,
  67. **sampling_kwargs,
  68. ) -> Tuple[torch.Tensor, torch.Tensor]:
  69. probs = logits_to_probs(
  70. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  71. )
  72. idx_next = multinomial_sample_one_no_sync(probs)
  73. return idx_next, probs
  74. def decode_one_token(
  75. model: Transformer,
  76. x: torch.Tensor,
  77. input_pos: torch.Tensor,
  78. previous_tokens: torch.Tensor = None,
  79. **sampling_kwargs,
  80. ) -> torch.Tensor:
  81. assert input_pos.shape[-1] == 1
  82. x, logits = model.forward_generate_slow(x, input_pos)
  83. codebooks = [
  84. sample(
  85. logits,
  86. previous_tokens=None, # Disable repetition penalty for the token codebook
  87. **sampling_kwargs,
  88. )[0]
  89. ]
  90. # Cleanup the cache
  91. for layer in model.fast_layers:
  92. layer.attention.kv_cache.k_cache.fill_(0)
  93. layer.attention.kv_cache.v_cache.fill_(0)
  94. for codebook_idx in range(model.config.num_codebooks):
  95. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  96. logits = model.forward_generate_fast(x, input_pos)
  97. # print(x.shape, logits.shape)
  98. a = sample(
  99. logits,
  100. previous_tokens=(
  101. previous_tokens[codebook_idx + 1]
  102. if previous_tokens is not None
  103. else None
  104. ),
  105. **sampling_kwargs,
  106. )[0]
  107. x = model.fast_embeddings(a)
  108. codebooks.append(a)
  109. # x = torch.cat(buffer, dim=1)
  110. # logits = model.forward_fast(x)[:, -1:, :]
  111. # a = sample(
  112. # logits,
  113. # previous_tokens=(
  114. # previous_tokens[codebook_idx + 1]
  115. # if previous_tokens is not None
  116. # else None
  117. # ),
  118. # **sampling_kwargs,
  119. # )[0]
  120. # x = model.fast_embeddings(a)
  121. # codebooks.append(a)
  122. # buffer.append(x.view(1, 1, -1))
  123. return torch.stack(codebooks, dim=0)
  124. def prefill(
  125. model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
  126. ) -> torch.Tensor:
  127. # input_pos: [B, S]
  128. x, logits = model.forward_generate_slow(x, input_pos)
  129. codebooks = [
  130. sample(
  131. logits,
  132. previous_tokens=None,
  133. **sampling_kwargs,
  134. )[0]
  135. ]
  136. # Cleanup the cache
  137. for layer in model.fast_layers:
  138. layer.attention.kv_cache.k_cache.fill_(0)
  139. layer.attention.kv_cache.v_cache.fill_(0)
  140. for codebook_idx in range(model.config.num_codebooks):
  141. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  142. logits = model.forward_generate_fast(x, input_pos)
  143. # print(x.shape, logits.shape)
  144. a = sample(
  145. logits,
  146. previous_tokens=None,
  147. **sampling_kwargs,
  148. )[0]
  149. x = model.fast_embeddings(a)
  150. codebooks.append(a)
  151. # x = torch.cat(buffer, dim=1)
  152. # logits = model.forward_fast(x)[:, -1:, :]
  153. # a = sample(
  154. # logits,
  155. # **sampling_kwargs,
  156. # )[0]
  157. # x = model.fast_embeddings(a)
  158. # codebooks.append(a)
  159. # buffer.append(x.view(1, 1, -1))
  160. return torch.stack(codebooks, dim=0)
  161. def decode_n_tokens(
  162. model: Transformer,
  163. cur_token: torch.Tensor,
  164. input_pos: torch.Tensor,
  165. num_new_tokens: int,
  166. eos_token_id: int = 2,
  167. **sampling_kwargs,
  168. ):
  169. previous_tokens = torch.zeros(
  170. (model.config.num_codebooks + 1, model.config.max_seq_len),
  171. dtype=torch.int,
  172. device=cur_token.device,
  173. )
  174. for i in tqdm(range(num_new_tokens)):
  175. # We need to get windowed repeat penalty
  176. win_size = 16
  177. if i < win_size:
  178. window = previous_tokens[:, :win_size]
  179. else:
  180. window = previous_tokens[:, i - win_size : i]
  181. with torch.backends.cuda.sdp_kernel(
  182. enable_flash=False, enable_mem_efficient=False, enable_math=True
  183. ): # Actually better for Inductor to codegen attention here
  184. next_token = decode_one_token(
  185. model=model,
  186. x=cur_token,
  187. input_pos=input_pos,
  188. previous_tokens=window,
  189. **sampling_kwargs,
  190. )
  191. input_pos += 1
  192. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  193. previous_tokens[:, i : i + 1] = next_token.view(
  194. model.config.num_codebooks + 1, -1
  195. )
  196. # TODO: use tokenizer's eos
  197. if cur_token[0, 0, -1] == eos_token_id or (cur_token[0, 1:, -1] == 1).any():
  198. break
  199. return previous_tokens[:, : i + 1]
  200. @torch.no_grad()
  201. @torch.inference_mode()
  202. def generate(
  203. *,
  204. model: Transformer,
  205. prompt: torch.Tensor,
  206. max_new_tokens: int,
  207. eos_token_id: int = 2,
  208. precision: torch.dtype = torch.bfloat16,
  209. **sampling_kwargs,
  210. ) -> torch.Tensor:
  211. """
  212. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  213. """
  214. # create an empty tensor of the expected final shape and fill in the current tokens
  215. T = prompt.size(1)
  216. if max_new_tokens:
  217. if T + max_new_tokens > model.config.max_seq_len:
  218. max_new_tokens = model.config.max_seq_len - T
  219. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  220. T_new = T + max_new_tokens
  221. else:
  222. T_new = model.config.max_seq_len
  223. max_new_tokens = T_new - T
  224. device, dtype = prompt.device, prompt.dtype
  225. with torch.device(device):
  226. model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)
  227. codebook_dim = 1 + model.config.num_codebooks
  228. # create an empty tensor of the expected final shape and fill in the current tokens
  229. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  230. empty[:, :T] = prompt
  231. seq = empty
  232. input_pos = torch.arange(0, T, device=device)
  233. next_token = prefill(
  234. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  235. )
  236. seq[:, T : T + 1] = next_token
  237. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  238. x = decode_n_tokens(
  239. model,
  240. next_token.view(1, codebook_dim, -1),
  241. input_pos,
  242. max_new_tokens - 1,
  243. eos_token_id=eos_token_id,
  244. **sampling_kwargs,
  245. )
  246. # x = torch.cat(generated_tokens, dim=1)
  247. seq = seq[:, : T + 1 + x.size(1)]
  248. seq[:, T + 1 :] = x
  249. return seq
  250. def encode_tokens(
  251. tokenizer,
  252. string,
  253. bos=True,
  254. device="cuda",
  255. prompt_tokens=None,
  256. use_g2p=False,
  257. speaker=None,
  258. order="zh,jp,en",
  259. num_codebooks=4,
  260. ):
  261. if use_g2p:
  262. order = order.split(",")
  263. prompt = g2p(string, order=order)
  264. prompt = [
  265. (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
  266. for _, i in prompt
  267. ]
  268. string = " ".join(prompt)
  269. else:
  270. string = clean_text(string)
  271. if speaker is not None:
  272. string = f"[SPK: {speaker}] {string}"
  273. string = f"[INST] {string} [/INST]"
  274. # Handle English less frequent words
  275. # TODO: update tokenizer to handle this
  276. # sub_strings = string.split(" ")
  277. # new_tokens = []
  278. # for i, string in enumerate(sub_strings):
  279. # tokens = tokenizer.encode(
  280. # string,
  281. # add_special_tokens=i == 0 and bos,
  282. # max_length=10**6,
  283. # truncation=False,
  284. # )
  285. # new_tokens.extend(tokens)
  286. new_tokens = tokenizer.encode(
  287. string,
  288. add_special_tokens=bos,
  289. max_length=10**6,
  290. truncation=False,
  291. )
  292. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  293. # Codebooks
  294. zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  295. prompt = torch.cat((tokens, zeros), dim=0)
  296. if prompt_tokens is None:
  297. return prompt
  298. # Get prompt tokens
  299. if prompt_tokens.ndim == 3:
  300. assert (
  301. prompt_tokens.shape[0] == 1
  302. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  303. prompt_tokens = prompt_tokens[0]
  304. assert prompt_tokens.ndim == 2
  305. data = prompt_tokens + 2
  306. if prompt_tokens.shape[0] > num_codebooks:
  307. logger.warning(
  308. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  309. )
  310. data = data[:num_codebooks]
  311. # Since 1.0, we use <s:xxx> to replace <semantic>
  312. s0_token_id = tokenizer.convert_tokens_to_ids("<s:0>")
  313. main_token_ids = torch.tensor(
  314. # TODO: replace this
  315. [[s0_token_id] * data.size(1)],
  316. dtype=torch.int,
  317. device=device,
  318. )
  319. data = torch.cat((main_token_ids, data), dim=0)
  320. prompt = torch.cat((prompt, data), dim=1)
  321. return prompt
  322. def load_model(config_name, checkpoint_path, device, precision):
  323. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  324. cfg = compose(config_name=config_name)
  325. model: Transformer = instantiate(cfg.model).model
  326. if "int8" in str(checkpoint_path):
  327. logger.info("Using int8 weight-only quantization!")
  328. from quantize import WeightOnlyInt8QuantHandler
  329. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  330. model = simple_quantizer.convert_for_runtime()
  331. if "int4" in str(checkpoint_path):
  332. logger.info("Using int4 quantization!")
  333. path_comps = checkpoint_path.name.split(".")
  334. assert path_comps[-2].startswith("g")
  335. groupsize = int(path_comps[-2][1:])
  336. from quantize import WeightOnlyInt4QuantHandler
  337. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  338. model = simple_quantizer.convert_for_runtime()
  339. checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
  340. if "state_dict" in checkpoint:
  341. checkpoint = checkpoint["state_dict"]
  342. if any(k.startswith("model.") for k in checkpoint):
  343. checkpoint = {
  344. k.replace("model.", ""): v
  345. for k, v in checkpoint.items()
  346. if k.startswith("model.")
  347. }
  348. model.load_state_dict(checkpoint, assign=True)
  349. model = model.to(device=device, dtype=precision)
  350. logger.info("Restored model from checkpoint")
  351. return model.eval(), cfg
  352. def split_text(text, min_length):
  353. text = clean_text(text)
  354. segments = []
  355. curr = ""
  356. for char in text:
  357. curr += char
  358. if char not in [".", ",", "!", "?"]:
  359. continue
  360. if len(curr) >= min_length:
  361. segments.append(curr)
  362. curr = ""
  363. if curr:
  364. segments.append(curr)
  365. return segments
  366. @click.command()
  367. @click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
  368. @click.option("--prompt-text", type=str, default=None)
  369. @click.option(
  370. "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
  371. )
  372. @click.option("--num-samples", type=int, default=1)
  373. @click.option("--max-new-tokens", type=int, default=0)
  374. @click.option("--top-k", type=int, default=None)
  375. @click.option("--top-p", type=float, default=0.9)
  376. @click.option("--repetition-penalty", type=float, default=1.2)
  377. @click.option("--temperature", type=float, default=0.7)
  378. @click.option(
  379. "--checkpoint-path",
  380. type=click.Path(path_type=Path, exists=True),
  381. default="results/text2semantic_400m_finetune/step_000002000.pth",
  382. )
  383. @click.option("--config-name", type=str, default="text2semantic_finetune")
  384. @click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
  385. @click.option("--compile/--no-compile", default=False)
  386. @click.option("--use-g2p/--no-g2p", default=True)
  387. @click.option("--seed", type=int, default=42)
  388. @click.option("--speaker", type=str, default=None)
  389. @click.option("--order", type=str, default="zh,jp,en")
  390. @click.option("--half/--no-half", default=False)
  391. @click.option("--iterative-prompt/--no-iterative-prompt", default=False)
  392. def main(
  393. text: str,
  394. prompt_text: Optional[str],
  395. prompt_tokens: Optional[Path],
  396. num_samples: int,
  397. max_new_tokens: int,
  398. top_k: int,
  399. top_p: int,
  400. repetition_penalty: float,
  401. temperature: float,
  402. checkpoint_path: Path,
  403. config_name: str,
  404. tokenizer: str,
  405. compile: bool,
  406. use_g2p: bool,
  407. seed: int,
  408. speaker: Optional[str],
  409. order: str,
  410. half: bool,
  411. iterative_prompt: bool,
  412. ) -> None:
  413. device = "cuda"
  414. precision = torch.half if half else torch.bfloat16
  415. logger.info("Loading model ...")
  416. t0 = time.time()
  417. model, cfg = load_model(config_name, checkpoint_path, device, precision)
  418. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  419. torch.cuda.synchronize()
  420. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  421. tokenizer = AutoTokenizer.from_pretrained(tokenizer)
  422. prompt_tokens = (
  423. torch.from_numpy(np.load(prompt_tokens)).to(device)
  424. if prompt_tokens is not None
  425. else None
  426. )
  427. use_prompt = prompt_text is not None and prompt_tokens is not None
  428. encoded = []
  429. texts = split_text(text, 30) if iterative_prompt else [text]
  430. for idx, text in enumerate(texts):
  431. encoded.append(
  432. encode_tokens(
  433. tokenizer,
  434. string=text,
  435. bos=idx == 0 and not use_prompt,
  436. device=device,
  437. use_g2p=use_g2p,
  438. speaker=None,
  439. order=order,
  440. num_codebooks=model.config.num_codebooks,
  441. )
  442. )
  443. print(f"Encoded text: {text}")
  444. if use_prompt:
  445. encoded_prompt = encode_tokens(
  446. tokenizer,
  447. prompt_text,
  448. prompt_tokens=prompt_tokens,
  449. bos=True,
  450. device=device,
  451. use_g2p=use_g2p,
  452. speaker=speaker,
  453. order=order,
  454. num_codebooks=model.config.num_codebooks,
  455. )
  456. encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
  457. # prompt_length = encoded.size(1)
  458. # logger.info(f"Encoded prompt shape: {encoded.shape}")
  459. torch.manual_seed(seed)
  460. torch.cuda.manual_seed(seed)
  461. if compile:
  462. global decode_one_token
  463. decode_one_token = torch.compile(
  464. decode_one_token, mode="reduce-overhead", fullgraph=True
  465. )
  466. for idx in range(num_samples):
  467. torch.cuda.synchronize()
  468. global_encoded = []
  469. all_codes = []
  470. seg_idx = 0
  471. while seg_idx < len(encoded):
  472. seg = encoded[seg_idx]
  473. global_encoded.append(seg)
  474. lengths = reversed([seg.size(1) for seg in global_encoded])
  475. # Pick last 2000 tokens
  476. count = 0
  477. for i, length in enumerate(lengths):
  478. count += length
  479. if count >= 2000:
  480. break
  481. if i != 0 and i % 2 == 0:
  482. i -= 1
  483. if i < len(global_encoded) - 2:
  484. partial_encoded = global_encoded[-i:]
  485. print(f"Loaded partial encoded")
  486. else:
  487. partial_encoded = global_encoded
  488. print(f"Using full encoded")
  489. cat_encoded = torch.cat(partial_encoded, dim=1)
  490. prompt_length = cat_encoded.size(1)
  491. t0 = time.perf_counter()
  492. y = generate(
  493. model=model,
  494. prompt=cat_encoded,
  495. max_new_tokens=max_new_tokens,
  496. eos_token_id=tokenizer.eos_token_id,
  497. precision=precision,
  498. temperature=temperature,
  499. top_k=top_k,
  500. top_p=top_p,
  501. repetition_penalty=repetition_penalty,
  502. )
  503. if idx == 0 and compile:
  504. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  505. torch.cuda.synchronize()
  506. t = time.perf_counter() - t0
  507. tokens_generated = y.size(1) - prompt_length
  508. tokens_sec = tokens_generated / t
  509. logger.info(
  510. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  511. )
  512. logger.info(
  513. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  514. )
  515. logger.info(
  516. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  517. )
  518. # Put the generated tokens
  519. codes = y[1:, prompt_length:-1].clone()
  520. # if getattr(cfg, "use_delay_pattern", True):
  521. # new_codes = []
  522. # for j, code in enumerate(codes):
  523. # new_codes.append(
  524. # code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
  525. # )
  526. # codes = torch.stack(new_codes, dim=0)
  527. codes = codes - 2
  528. if not (codes >= 0).all():
  529. global_encoded.pop()
  530. logger.warning(f"Negative code found: {codes}, retrying ...")
  531. continue
  532. global_encoded.append(y[:, prompt_length:-1].clone())
  533. all_codes.append(codes)
  534. seg_idx += 1
  535. codes = torch.cat(all_codes, dim=1)
  536. assert (codes >= 0).all(), f"Negative code found: {codes}"
  537. print(codes)
  538. np.save(f"codes_{idx}.npy", codes.cpu().numpy())
  539. logger.info(f"Saved codes to codes_{idx}.npy")
# Script entry point: run the click command only when executed directly.
if __name__ == "__main__":
    main()