generate.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import itertools
  6. import sys
  7. import time
  8. from pathlib import Path
  9. from typing import Optional, Tuple
  10. import numpy as np
  11. import torch
  12. import torch._dynamo.config
  13. import torch._inductor.config
  14. from transformers import AutoTokenizer
  15. torch._inductor.config.coordinate_descent_tuning = True
  16. torch._inductor.config.triton.unique_kernel_names = True
  17. torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
  18. from fish_speech.models.text2semantic.llama import ModelArgs, Transformer
  19. from fish_speech.text import g2p
  20. from fish_speech.text.symbols import pad as pad_symbol
  21. from fish_speech.text.symbols import pu_symbols
  22. from tools.llama.tp import maybe_init_dist
  23. def multinomial_sample_one_no_sync(
  24. probs_sort,
  25. ): # Does multinomial sampling without a cuda synchronization
  26. q = torch.empty_like(probs_sort).exponential_(1)
  27. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  28. def logits_to_probs(
  29. logits,
  30. previous_tokens: Optional[torch.Tensor] = None,
  31. temperature: float = 1.0,
  32. top_k: Optional[int] = None,
  33. top_p: Optional[int] = None,
  34. repetition_penalty: float = 1.0,
  35. ):
  36. if previous_tokens is not None and repetition_penalty != 1.0:
  37. previous_tokens = previous_tokens.long()
  38. score = torch.gather(logits, dim=-1, index=previous_tokens)
  39. score = torch.where(
  40. score < 0, score * repetition_penalty, score / repetition_penalty
  41. )
  42. logits.scatter_(dim=-1, index=previous_tokens, src=score)
  43. if top_p is not None and top_p < 1.0:
  44. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  45. cum_probs = torch.cumsum(
  46. torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
  47. )
  48. sorted_indices_to_remove = cum_probs > top_p
  49. sorted_indices_to_remove[0] = False # keep at least one option
  50. indices_to_remove = sorted_indices_to_remove.scatter(
  51. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  52. )
  53. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  54. logits = logits / max(temperature, 1e-5)
  55. if top_k is not None:
  56. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  57. pivot = v.select(-1, -1).unsqueeze(-1)
  58. logits = torch.where(logits < pivot, -float("Inf"), logits)
  59. probs = torch.nn.functional.softmax(logits, dim=-1)
  60. return probs
  61. def sample(
  62. logits,
  63. previous_tokens: Optional[torch.Tensor] = None,
  64. temperature: float = 1.0,
  65. top_k: Optional[int] = None,
  66. top_p: Optional[int] = None,
  67. repetition_penalty: float = 1.0,
  68. ) -> Tuple[torch.Tensor, torch.Tensor]:
  69. probs = logits_to_probs(
  70. logits[0, -1], previous_tokens, temperature, top_k, top_p, repetition_penalty
  71. )
  72. idx_next = multinomial_sample_one_no_sync(probs)
  73. return idx_next, probs
  74. def decode_token(
  75. model: Transformer,
  76. x: torch.Tensor,
  77. input_pos: torch.Tensor,
  78. previous_tokens: torch.Tensor = None,
  79. **sampling_kwargs,
  80. ) -> torch.Tensor:
  81. # input_pos: [B, S]
  82. logits = model.forward_generate(x, input_pos)
  83. codebooks = [
  84. sample(
  85. logits.token_logits,
  86. previous_tokens=previous_tokens[0] if previous_tokens is not None else None,
  87. **sampling_kwargs,
  88. )[0]
  89. ]
  90. # Disable <s> and </s> tokens for codebooks
  91. if model.config.num_codebooks != 0:
  92. logits.codebook_logits[:, :, :, :2] = -float("Inf")
  93. for i in range(model.config.num_codebooks):
  94. codebooks.append(
  95. sample(
  96. logits.codebook_logits[:, :, i],
  97. previous_tokens=previous_tokens[i]
  98. if previous_tokens is not None
  99. else None,
  100. **sampling_kwargs,
  101. )[0]
  102. )
  103. return torch.stack(codebooks, dim=0)
  104. def decode_n_tokens(
  105. model: Transformer,
  106. cur_token: torch.Tensor,
  107. input_pos: torch.Tensor,
  108. num_new_tokens: int,
  109. callback=lambda _: _,
  110. **sampling_kwargs,
  111. ):
  112. new_tokens = []
  113. for i in range(num_new_tokens):
  114. with torch.backends.cuda.sdp_kernel(
  115. enable_flash=False, enable_mem_efficient=False, enable_math=True
  116. ): # Actually better for Inductor to codegen attention here
  117. next_token = decode_token(
  118. model,
  119. cur_token,
  120. input_pos,
  121. torch.concat(new_tokens, dim=1) if len(new_tokens) > 0 else None,
  122. **sampling_kwargs,
  123. )
  124. input_pos += 1
  125. new_tokens.append(next_token.clone())
  126. callback(new_tokens[-1])
  127. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  128. # TODO: use tokenizer's eos
  129. if (cur_token[0, 0, -1] == 2).any():
  130. print("EOS detected, stopping generation")
  131. break
  132. return new_tokens
  133. def model_forward(model, x, input_pos):
  134. return model(x, input_pos)
  135. @torch.no_grad()
  136. def generate(
  137. model: Transformer,
  138. prompt: torch.Tensor,
  139. max_new_tokens: int,
  140. *,
  141. interactive: bool,
  142. callback=lambda x: x,
  143. **sampling_kwargs,
  144. ) -> torch.Tensor:
  145. """
  146. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  147. """
  148. # create an empty tensor of the expected final shape and fill in the current tokens
  149. T = prompt.size(1)
  150. if T + max_new_tokens > model.config.max_seq_len:
  151. max_new_tokens = model.config.max_seq_len - T
  152. print(f"Truncating max_new_tokens to {max_new_tokens}")
  153. T_new = T + max_new_tokens
  154. if interactive:
  155. max_seq_length = 350
  156. else:
  157. max_seq_length = min(T_new, model.config.max_seq_len)
  158. device, dtype = prompt.device, prompt.dtype
  159. max_seq_length = max_seq_length
  160. with torch.device(device):
  161. model.setup_caches(max_batch_size=1, max_seq_len=max_seq_length)
  162. codebook_dim = 1 + model.config.num_codebooks
  163. # create an empty tensor of the expected final shape and fill in the current tokens
  164. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  165. empty[:, :T] = prompt
  166. seq = empty
  167. input_pos = torch.arange(0, T, device=device)
  168. next_token = decode_token(
  169. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  170. )
  171. seq[:, T : T + 1] = next_token
  172. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  173. generated_tokens = decode_n_tokens(
  174. model,
  175. next_token.view(1, codebook_dim, -1),
  176. input_pos,
  177. max_new_tokens - 1,
  178. callback=callback,
  179. **sampling_kwargs,
  180. )
  181. x = torch.cat(generated_tokens, dim=1)
  182. seq = seq[:, : T + 1 + x.size(1)]
  183. seq[:, T + 1 :] = x
  184. return seq
  185. def encode_tokens(tokenizer, string, bos=True, device="cuda"):
  186. # data/Genshin/Chinese/神里绫华/vo_ayaka_character_idle_04.npy
  187. prompt = g2p("<zh>算啦,虽然他罪无可恕,但也有可怜的地方嘛。</zh> {string}")
  188. prompt = [
  189. (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
  190. for _, i in prompt
  191. ]
  192. prompt = " ".join(prompt)
  193. string = f"[INST] {prompt} [/INST]"
  194. print("Encoding string:", string)
  195. data = np.load("data/Genshin/Chinese/派蒙/vo_WYLQ103_10_paimon_03.npy")
  196. codes = [f"<s:{i}>" for i in data[0]]
  197. tokens = tokenizer.encode(
  198. string + " ".join(codes),
  199. max_length=10**6,
  200. add_special_tokens=bos,
  201. truncation=False,
  202. )
  203. tokens = torch.tensor([tokens], dtype=torch.int, device=device)
  204. # Codebooks
  205. # zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
  206. # prompt = torch.cat((tokens, zeros), dim=0)
  207. # # Get prompt tokens
  208. # data = np.load("data/Genshin/Chinese/神里绫华/vo_ayaka_character_idle_02.npy")
  209. # data = torch.from_numpy(data).to(device=device, dtype=torch.int) + 2
  210. # zeros = torch.zeros((1, data.size(1)), dtype=torch.int, device=device) + 32311 # 32311 is the <pad> token
  211. # data = torch.cat((zeros, data), dim=0)
  212. # prompt = torch.cat((prompt, data), dim=1)
  213. # print(prompt)
  214. return tokens
def _load_model(checkpoint_path, device, precision, use_tp):
    """Build the Transformer, load weights, optionally quantize / apply TP.

    Args:
        checkpoint_path: Path to the checkpoint file.  Quantization is chosen
            purely by filename convention: "int8"/"int4" substrings select
            weight-only quantization.
        device: target device for the loaded model.
        precision: target dtype for the weights.
        use_tp: when True, apply tensor parallelism via `tp.apply_tp`.

    Returns:
        The model in eval mode, on `device` in `precision`.
    """
    # Instantiate on the meta device so no real memory is allocated until the
    # checkpoint is loaded below with assign=True.
    with torch.device("meta"):
        # TODO: support different model archs
        model = Transformer(
            ModelArgs(
                max_seq_len=4096,
                vocab_size=36408,
                n_layer=24,
                n_head=16,
                dim=1024,
                rope_base=10000,
                norm_eps=1e-5,
                codebook_size=168,
                num_codebooks=0,
            )
        )

    if "int8" in str(checkpoint_path):
        print("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        print("Using int4 quantization!")
        # Group size is parsed from the second-to-last filename component,
        # which must look like "g<N>" (e.g. "model.g32.int4.ckpt").
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    # mmap + assign=True materializes the meta-device parameters directly
    # from the checkpoint without an extra copy.
    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.load_state_dict(checkpoint, assign=True)

    if use_tp:
        from tp import apply_tp

        print("Applying tensor parallel to model ...")
        apply_tp(model)

    model = model.to(device=device, dtype=precision)
    return model.eval()
  252. B_INST, E_INST = "[INST]", "[/INST]"
def main(
    prompt: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
    interactive: bool = False,
    num_samples: int = 5,
    max_new_tokens: int = 100,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    temperature: float = 0.8,
    checkpoint_path: Path = Path(
        "results/text2semantic_400m/checkpoints/step_000025000.ckpt"
    ),
    compile: bool = True,
    profile: Optional[Path] = None,
    tokenizer: str = "fishaudio/speech-lm-v1",
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.

    Loads the model, encodes the prompt, generates `num_samples` sequences,
    saves the semantic codes of each to `codes_<i>.npy` (non-interactive
    mode), and prints throughput metrics.  Requires CUDA.
    """
    assert checkpoint_path.is_file(), checkpoint_path

    # `print` is rebound to a no-op on non-zero ranks under tensor parallel.
    global print
    rank = maybe_init_dist()
    use_tp = rank is not None
    if use_tp:
        torch.cuda.set_device(rank)
        if rank != 0:
            # only print on rank 0
            print = lambda *args, **kwargs: None

    device = "cuda"
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision, use_tp)

    torch.cuda.synchronize()
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    # `tokenizer` is rebound from model name (str) to the tokenizer object.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    print(prompt)
    encoded = encode_tokens(tokenizer, f"{prompt}", bos=True, device=device)
    prompt_length = encoded.size(1)

    torch.manual_seed(1234)
    # Model size in bytes, used for the bandwidth metric below.
    model_size = sum(
        [
            p.numel() * p.dtype.itemsize
            for p in itertools.chain(model.parameters(), model.buffers())
        ]
    )

    if compile:
        # Rebind the module-level decode_token with a compiled version.
        global decode_token
        decode_token = torch.compile(
            decode_token, mode="reduce-overhead", fullgraph=True
        )

    aggregate_metrics = {
        "tokens_per_sec": [],
    }
    # When compiling, iteration i == -1 is a warm-up run whose metrics are
    # discarded (it pays the compilation cost).
    start = -1 if compile else 0

    for i in range(start, num_samples):
        torch.cuda.synchronize()
        if i >= 0 and interactive:
            prompt = input("What is your prompt? ")
            encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)

        if interactive and i >= 0:
            # Stream decoded text to stdout in 4-token chunks.
            buffer = []
            period_id = tokenizer.encode(".")[0]
            done_generating = False

            def callback(x):
                nonlocal done_generating
                if done_generating:
                    return
                # Decode with a leading "." then strip it, so leading-space
                # handling of the tokenizer stays consistent.
                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
                # NOTE(review): HF tokenizers usually expose `eos_token_id`,
                # not an `eos_id()` method — confirm this tokenizer class
                # actually provides eos_id() before relying on interactive mode.
                if x.item() == tokenizer.eos_id():
                    done_generating = True
                if len(buffer) == 4 or done_generating:
                    print("".join(buffer), end="", flush=True)
                    buffer.clear()
                # print(, end='', flush=True)

        else:
            callback = lambda x: x

        t0 = time.perf_counter()
        import contextlib

        # Profile only the last sample, and only on rank 0 under TP.
        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()

        with prof:
            y = generate(
                model,
                encoded,
                max_new_tokens,
                interactive=interactive,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

        if i == -1:
            # Warm-up iteration: report compile time and skip metrics.
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue

        if hasattr(prof, "export_chrome_trace"):
            if use_tp:
                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
            else:
                prof.export_chrome_trace(f"{profile}.json")

        torch.cuda.synchronize()
        t = time.perf_counter() - t0

        if not interactive:
            # NOTE(review): this decodes the PROMPT slice (y[0, :prompt_length]),
            # not the generated tokens — confirm this is the intended output.
            print(tokenizer.decode(y[0, :prompt_length:].tolist()))
            print(f"Generated {y.size(1) - prompt_length} tokens")

            # Find all <s:2769>
            # Map generated token ids back to raw semantic codes; presumably
            # 32311 is the id of the first semantic-code token — verify
            # against the tokenizer vocabulary.
            codes = y[0, prompt_length:-1]
            codes = codes - 32311
            # print(codes)

            assert (codes >= 0).all()

            import numpy as np

            np.save(f"codes_{i}.npy", codes[None].cpu().numpy())
        else:
            print()

        tokens_generated = y.size(1) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
        print(
            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
        )
        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
        print("==========")

    print(
        f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
    )
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
  381. if __name__ == "__main__":
  382. import argparse
  383. parser = argparse.ArgumentParser(description="Your CLI description.")
  384. parser.add_argument(
  385. "--prompt",
  386. type=str,
  387. default="感情分析関数では、大規模言語モデルに古典的な散文を分析させます。 分析の観点は比較的単純ですが、論理的な誤りはなく、依然として自己一貫性があることがわかります。",
  388. help="Input prompt.",
  389. )
  390. parser.add_argument(
  391. "--interactive",
  392. action="store_true",
  393. help="Whether to launch in interactive mode",
  394. )
  395. parser.add_argument("--num_samples", type=int, default=1, help="Number of samples.")
  396. parser.add_argument(
  397. "--max_new_tokens", type=int, default=4096, help="Maximum number of new tokens."
  398. )
  399. parser.add_argument("--top_k", type=int, default=50, help="Top-k for sampling.")
  400. parser.add_argument("--top_p", type=int, default=0.95, help="Top-k for sampling.")
  401. parser.add_argument("--repetition_penalty", type=float, default=1.0)
  402. parser.add_argument(
  403. "--temperature", type=float, default=0.8, help="Temperature for sampling."
  404. )
  405. parser.add_argument(
  406. "--checkpoint_path",
  407. type=Path,
  408. default=Path("results/text2semantic_400m/step_000095000_weights.ckpt"),
  409. help="Model checkpoint path.",
  410. )
  411. parser.add_argument(
  412. "--compile", action="store_true", help="Whether to compile the model."
  413. )
  414. parser.add_argument("--profile", type=Path, default=None, help="Profile path.")
  415. args = parser.parse_args()
  416. main(
  417. args.prompt,
  418. args.interactive,
  419. args.num_samples,
  420. args.max_new_tokens,
  421. args.top_k,
  422. args.top_p,
  423. args.repetition_penalty,
  424. args.temperature,
  425. args.checkpoint_path,
  426. args.compile,
  427. args.profile,
  428. )