generate.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. import os
  2. import time
  3. from pathlib import Path
  4. from typing import Optional, Tuple
  5. import click
  6. import numpy as np
  7. import torch
  8. import torch._dynamo.config
  9. import torch._inductor.config
  10. from hydra import compose, initialize
  11. from hydra.utils import instantiate
  12. from loguru import logger
  13. from tqdm import tqdm
  14. from transformers import AutoTokenizer
  15. from fish_speech.text.parser import clean_text
  16. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  17. torch._inductor.config.coordinate_descent_tuning = True
  18. torch._inductor.config.triton.unique_kernel_names = True
  19. torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
  20. from fish_speech.models.text2semantic.llama import Transformer
  21. from fish_speech.text import g2p
  22. from fish_speech.text.symbols import pad as pad_symbol
  23. from fish_speech.text.symbols import pu_symbols
  24. def multinomial_sample_one_no_sync(
  25. probs_sort,
  26. ): # Does multinomial sampling without a cuda synchronization
  27. q = torch.empty_like(probs_sort).exponential_(1)
  28. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  29. def logits_to_probs(
  30. logits,
  31. previous_tokens: Optional[torch.Tensor] = None,
  32. temperature: float = 1.0,
  33. top_k: Optional[int] = None,
  34. top_p: Optional[int] = None,
  35. repetition_penalty: float = 1.0,
  36. ):
  37. if previous_tokens is not None and repetition_penalty != 1.0:
  38. previous_tokens = previous_tokens.long()
  39. score = torch.gather(logits, dim=0, index=previous_tokens)
  40. score = torch.where(
  41. score < 0, score * repetition_penalty, score / repetition_penalty
  42. )
  43. logits.scatter_(dim=0, index=previous_tokens, src=score)
  44. if top_p is not None and top_p < 1.0:
  45. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  46. cum_probs = torch.cumsum(
  47. torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
  48. )
  49. sorted_indices_to_remove = cum_probs > top_p
  50. sorted_indices_to_remove[0] = False # keep at least one option
  51. indices_to_remove = sorted_indices_to_remove.scatter(
  52. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  53. )
  54. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  55. logits = logits / max(temperature, 1e-5)
  56. if top_k is not None:
  57. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  58. pivot = v.select(-1, -1).unsqueeze(-1)
  59. logits = torch.where(logits < pivot, -float("Inf"), logits)
  60. probs = torch.nn.functional.softmax(logits, dim=-1)
  61. return probs
  62. def sample(
  63. logits,
  64. previous_tokens: Optional[torch.Tensor] = None,
  65. **sampling_kwargs,
  66. ) -> Tuple[torch.Tensor, torch.Tensor]:
  67. probs = logits_to_probs(
  68. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  69. )
  70. idx_next = multinomial_sample_one_no_sync(probs)
  71. return idx_next, probs
  72. def decode_one_token(
  73. model: Transformer,
  74. x: torch.Tensor,
  75. input_pos: torch.Tensor,
  76. previous_tokens: torch.Tensor = None,
  77. **sampling_kwargs,
  78. ) -> torch.Tensor:
  79. assert input_pos.shape[-1] == 1
  80. logits = model.forward_generate(x, input_pos)
  81. codebooks = [
  82. sample(
  83. logits.token_logits,
  84. previous_tokens=None, # Disable repetition penalty for the token codebook
  85. **sampling_kwargs,
  86. )[0]
  87. ]
  88. # Disable <s> and </s> tokens for codebooks
  89. if model.config.num_codebooks != 0:
  90. logits.codebook_logits[:, :, :, :2] = -float("Inf")
  91. for i in range(model.config.num_codebooks):
  92. codebooks.append(
  93. sample(
  94. logits.codebook_logits[:, :, i],
  95. previous_tokens=previous_tokens[i + 1]
  96. if previous_tokens is not None
  97. else None,
  98. **sampling_kwargs,
  99. )[0]
  100. )
  101. return torch.stack(codebooks, dim=0)
  102. def prefill(
  103. model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
  104. ) -> torch.Tensor:
  105. # input_pos: [B, S]
  106. logits = model.forward_generate(x, input_pos)
  107. codebooks = [
  108. sample(
  109. logits.token_logits,
  110. previous_tokens=None,
  111. **sampling_kwargs,
  112. )[0]
  113. ]
  114. # Disable <s> and </s> tokens for codebooks
  115. if model.config.num_codebooks != 0:
  116. logits.codebook_logits[:, :, :, :2] = -float("Inf")
  117. for i in range(model.config.num_codebooks):
  118. codebooks.append(
  119. sample(
  120. logits.codebook_logits[:, :, i],
  121. previous_tokens=None,
  122. **sampling_kwargs,
  123. )[0]
  124. )
  125. return torch.stack(codebooks, dim=0)
  126. def decode_n_tokens(
  127. model: Transformer,
  128. cur_token: torch.Tensor,
  129. input_pos: torch.Tensor,
  130. num_new_tokens: int,
  131. eos_token_id: int = 2,
  132. **sampling_kwargs,
  133. ):
  134. previous_tokens = torch.zeros(
  135. (model.config.num_codebooks + 1, num_new_tokens),
  136. dtype=torch.int,
  137. device=cur_token.device,
  138. )
  139. for i in tqdm(range(num_new_tokens)):
  140. # We need to get windowed repeat penalty
  141. win_size = 16
  142. if i < win_size:
  143. window = previous_tokens[:, :win_size]
  144. else:
  145. window = previous_tokens[:, i - win_size : i]
  146. with torch.backends.cuda.sdp_kernel(
  147. enable_flash=False, enable_mem_efficient=False, enable_math=True
  148. ): # Actually better for Inductor to codegen attention here
  149. next_token = decode_one_token(
  150. model,
  151. cur_token,
  152. input_pos,
  153. window,
  154. **sampling_kwargs,
  155. )
  156. input_pos += 1
  157. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  158. previous_tokens[:, i : i + 1] = next_token.view(
  159. model.config.num_codebooks + 1, -1
  160. )
  161. # TODO: use tokenizer's eos
  162. if (cur_token[0, 0, -1] == eos_token_id).any():
  163. break
  164. return previous_tokens[:, : i + 1]
  165. @torch.no_grad()
  166. def generate(
  167. *,
  168. model: Transformer,
  169. prompt: torch.Tensor,
  170. max_new_tokens: int,
  171. eos_token_id: int = 2,
  172. **sampling_kwargs,
  173. ) -> torch.Tensor:
  174. """
  175. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  176. """
  177. # create an empty tensor of the expected final shape and fill in the current tokens
  178. T = prompt.size(1)
  179. if max_new_tokens:
  180. if T + max_new_tokens > model.config.max_seq_len:
  181. max_new_tokens = model.config.max_seq_len - T
  182. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  183. T_new = T + max_new_tokens
  184. else:
  185. T_new = model.config.max_seq_len
  186. max_new_tokens = T_new - T
  187. device, dtype = prompt.device, prompt.dtype
  188. with torch.device(device):
  189. model.setup_caches(max_batch_size=1, max_seq_len=T_new)
  190. codebook_dim = 1 + model.config.num_codebooks
  191. # create an empty tensor of the expected final shape and fill in the current tokens
  192. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  193. empty[:, :T] = prompt
  194. seq = empty
  195. input_pos = torch.arange(0, T, device=device)
  196. next_token = prefill(
  197. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  198. )
  199. seq[:, T : T + 1] = next_token
  200. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  201. x = decode_n_tokens(
  202. model,
  203. next_token.view(1, codebook_dim, -1),
  204. input_pos,
  205. max_new_tokens - 1,
  206. eos_token_id=eos_token_id,
  207. **sampling_kwargs,
  208. )
  209. # x = torch.cat(generated_tokens, dim=1)
  210. seq = seq[:, : T + 1 + x.size(1)]
  211. seq[:, T + 1 :] = x
  212. return seq
  213. def encode_tokens(
  214. tokenizer,
  215. string,
  216. bos=True,
  217. device="cuda",
  218. prompt_text=None,
  219. prompt_tokens=None,
  220. use_g2p=False,
  221. ):
  222. if prompt_text is not None:
  223. string = prompt_text + " " + string
  224. if use_g2p:
  225. prompt = g2p(string)
  226. prompt = [
  227. (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
  228. for _, i in prompt
  229. ]
  230. string = " ".join(prompt)
  231. else:
  232. string = clean_text(string)
  233. string = f"[INST] {string} [/INST]"
  234. tokens = tokenizer.encode(
  235. string,
  236. max_length=10**6,
  237. add_special_tokens=bos,
  238. truncation=False,
  239. )
  240. tokens = torch.tensor([tokens], dtype=torch.int, device=device)
  241. # Codebooks
  242. zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
  243. prompt = torch.cat((tokens, zeros), dim=0)
  244. if prompt_tokens is None:
  245. return prompt
  246. # Get prompt tokens
  247. assert prompt_tokens.ndim == 2
  248. data = prompt_tokens + 2
  249. zeros = (
  250. torch.zeros((1, data.size(1)), dtype=torch.int, device=device)
  251. + tokenizer.pad_token_id
  252. ) # 32311 is the <pad> token
  253. data = torch.cat((zeros, data), dim=0)
  254. prompt = torch.cat((prompt, data), dim=1)
  255. return prompt
  256. def load_model(config_name, checkpoint_path, device, precision):
  257. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  258. cfg = compose(config_name=config_name)
  259. with torch.device("meta"):
  260. model: Transformer = instantiate(cfg.model.model)
  261. if "int8" in str(checkpoint_path):
  262. logger.info("Using int8 weight-only quantization!")
  263. from quantize import WeightOnlyInt8QuantHandler
  264. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  265. model = simple_quantizer.convert_for_runtime()
  266. if "int4" in str(checkpoint_path):
  267. logger.info("Using int4 quantization!")
  268. path_comps = checkpoint_path.name.split(".")
  269. assert path_comps[-2].startswith("g")
  270. groupsize = int(path_comps[-2][1:])
  271. from quantize import WeightOnlyInt4QuantHandler
  272. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  273. model = simple_quantizer.convert_for_runtime()
  274. checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
  275. if "state_dict" in checkpoint:
  276. checkpoint = checkpoint["state_dict"]
  277. if any(k.startswith("model.") for k in checkpoint):
  278. checkpoint = {
  279. k.replace("model.", ""): v
  280. for k, v in checkpoint.items()
  281. if k.startswith("model.")
  282. }
  283. model.load_state_dict(checkpoint, assign=True)
  284. model = model.to(device=device, dtype=precision)
  285. logger.info("Restored model from checkpoint")
  286. return model.eval()
  287. @click.command()
  288. @click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
  289. @click.option("--prompt-text", type=str, default=None)
  290. @click.option(
  291. "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
  292. )
  293. @click.option("--num-samples", type=int, default=1)
  294. @click.option("--max_new_tokens", type=int, default=0)
  295. @click.option("--top-k", type=int, default=None)
  296. @click.option("--top-p", type=float, default=0.5)
  297. @click.option("--repetition-penalty", type=float, default=1.5)
  298. @click.option("--temperature", type=float, default=0.7)
  299. @click.option(
  300. "--checkpoint-path",
  301. type=click.Path(path_type=Path, exists=True),
  302. default="results/text2semantic_400m_finetune/step_000002000.pth",
  303. )
  304. @click.option("--config-name", type=str, default="text2semantic_finetune")
  305. @click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
  306. @click.option("--compile/--no-compile", default=False)
  307. @click.option("--use-g2p/--no-g2p", default=True)
  308. @click.option("--seed", type=int, default=42)
  309. def main(
  310. text: str,
  311. prompt_text: Optional[str],
  312. prompt_tokens: Optional[Path],
  313. num_samples: int,
  314. max_new_tokens: int,
  315. top_k: int,
  316. top_p: int,
  317. repetition_penalty: float,
  318. temperature: float,
  319. checkpoint_path: Path,
  320. config_name: str,
  321. tokenizer: str,
  322. compile: bool,
  323. use_g2p: bool,
  324. seed: int,
  325. ) -> None:
  326. device = "cuda"
  327. precision = torch.bfloat16
  328. logger.info("Loading model ...")
  329. t0 = time.time()
  330. model = load_model(config_name, checkpoint_path, device, precision)
  331. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  332. torch.cuda.synchronize()
  333. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  334. tokenizer = AutoTokenizer.from_pretrained(tokenizer)
  335. prompt_tokens = (
  336. torch.from_numpy(np.load(prompt_tokens)).to(device)
  337. if prompt_tokens is not None
  338. else None
  339. )
  340. encoded = encode_tokens(
  341. tokenizer,
  342. text,
  343. prompt_text=prompt_text,
  344. prompt_tokens=prompt_tokens,
  345. bos=True,
  346. device=device,
  347. use_g2p=use_g2p,
  348. )
  349. prompt_length = encoded.size(1)
  350. logger.info(f"Encoded prompt shape: {encoded.shape}")
  351. torch.manual_seed(seed)
  352. if compile:
  353. global decode_one_token
  354. decode_one_token = torch.compile(
  355. decode_one_token, mode="reduce-overhead", fullgraph=True
  356. )
  357. for i in range(num_samples):
  358. torch.cuda.synchronize()
  359. t0 = time.perf_counter()
  360. y = generate(
  361. model=model,
  362. prompt=encoded,
  363. max_new_tokens=max_new_tokens,
  364. eos_token_id=tokenizer.eos_token_id,
  365. temperature=temperature,
  366. top_k=top_k,
  367. top_p=top_p,
  368. repetition_penalty=repetition_penalty,
  369. )
  370. if i == 0 and compile:
  371. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  372. torch.cuda.synchronize()
  373. t = time.perf_counter() - t0
  374. tokens_generated = y.size(1) - prompt_length
  375. tokens_sec = tokens_generated / t
  376. logger.info(
  377. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  378. )
  379. logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
  380. logger.info(
  381. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  382. )
  383. codes = y[1:, prompt_length:-1]
  384. codes = codes - 2
  385. assert (codes >= 0).all(), "Codes should be >= 0"
  386. np.save(f"codes_{i}.npy", codes.cpu().numpy())
  387. logger.info(f"Saved codes to codes_{i}.npy")
  388. if __name__ == "__main__":
  389. main()