# generate.py — text2semantic inference script
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import itertools
  6. import sys
  7. import time
  8. from pathlib import Path
  9. from typing import Optional, Tuple
  10. import torch
  11. import torch._dynamo.config
  12. import torch._inductor.config
  13. from transformers import AutoTokenizer
  14. torch._inductor.config.coordinate_descent_tuning = True
  15. torch._inductor.config.triton.unique_kernel_names = True
  16. torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
  17. from fish_speech.models.text2semantic.llama import ModelArgs, Transformer
  18. from fish_speech.models.text2semantic.tp import maybe_init_dist
  19. from fish_speech.text import g2p
  20. from fish_speech.text.symbols import pad as pad_symbol
  21. from fish_speech.text.symbols import pu_symbols
  22. def multinomial_sample_one_no_sync(
  23. probs_sort,
  24. ): # Does multinomial sampling without a cuda synchronization
  25. q = torch.empty_like(probs_sort).exponential_(1)
  26. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  27. def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
  28. logits = logits / max(temperature, 1e-5)
  29. if top_k is not None:
  30. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  31. pivot = v.select(-1, -1).unsqueeze(-1)
  32. logits = torch.where(logits < pivot, -float("Inf"), logits)
  33. probs = torch.nn.functional.softmax(logits, dim=-1)
  34. return probs
  35. def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
  36. probs = logits_to_probs(logits[0, -1], temperature, top_k)
  37. idx_next = multinomial_sample_one_no_sync(probs)
  38. return idx_next, probs
  39. def decode_token(
  40. model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
  41. ) -> torch.Tensor:
  42. # input_pos: [B, S]
  43. logits = model.forward_generate(x, input_pos)
  44. codebooks = [sample(logits.token_logits, **sampling_kwargs)[0]]
  45. # Disable <s> and </s> tokens for 2-n codebooks
  46. logits.codebook_logits[:, :, 1:, :2] = -float("Inf")
  47. for i in range(model.config.num_codebooks):
  48. codebooks.append(sample(logits.codebook_logits[:, :, i], **sampling_kwargs)[0])
  49. return torch.stack(codebooks, dim=0)
def decode_n_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    callback=lambda _: _,
    **sampling_kwargs,
):
    """Autoregressively decode up to `num_new_tokens` tokens, one per step.

    Args:
        model: Transformer exposing forward_generate and config.num_codebooks.
        cur_token: (1, 1 + num_codebooks, 1) tensor with the previous token.
        input_pos: one-element int tensor holding the current position;
            advanced in place each iteration.
        num_new_tokens: upper bound on decode steps.
        callback: invoked with each freshly generated token tensor.
        **sampling_kwargs: forwarded to the samplers (e.g. temperature, top_k).

    Returns:
        List of generated token tensors; may be shorter than `num_new_tokens`
        if the EOS check fires.
    """
    new_tokens = []
    for i in range(num_new_tokens):
        # Restrict scaled-dot-product attention to the math backend.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_token(model, cur_token, input_pos, **sampling_kwargs)
        input_pos += 1
        # NOTE(review): the clone presumably guards against next_token's
        # buffer being reused across steps under torch.compile's
        # reduce-overhead (CUDA graph) mode — confirm before removing.
        new_tokens.append(next_token.clone())
        callback(new_tokens[-1])
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        # TODO: use tokenizer
        # Token id 1 appearing in any codebook row is treated as EOS.
        if (cur_token[0, 1:, 0] == 1).any():
            print("EOS detected, stopping generation")
            break
    return new_tokens
  73. def model_forward(model, x, input_pos):
  74. return model(x, input_pos)
  75. @torch.no_grad()
  76. def generate(
  77. model: Transformer,
  78. prompt: torch.Tensor,
  79. max_new_tokens: int,
  80. *,
  81. interactive: bool,
  82. callback=lambda x: x,
  83. **sampling_kwargs,
  84. ) -> torch.Tensor:
  85. """
  86. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  87. """
  88. # create an empty tensor of the expected final shape and fill in the current tokens
  89. T = prompt.size(1)
  90. T_new = T + max_new_tokens
  91. if interactive:
  92. max_seq_length = 350
  93. else:
  94. max_seq_length = min(T_new, model.config.max_seq_len)
  95. device, dtype = prompt.device, prompt.dtype
  96. max_seq_length = max_seq_length
  97. with torch.device(device):
  98. model.setup_caches(max_batch_size=1, max_seq_len=max_seq_length)
  99. codebook_dim = 1 + model.config.num_codebooks
  100. # create an empty tensor of the expected final shape and fill in the current tokens
  101. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  102. empty[:, :T] = prompt
  103. seq = empty
  104. input_pos = torch.arange(0, T, device=device)
  105. next_token = decode_token(
  106. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  107. )
  108. seq[:, T : T + 1] = next_token
  109. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  110. generated_tokens = decode_n_tokens(
  111. model,
  112. next_token.view(1, codebook_dim, -1),
  113. input_pos,
  114. max_new_tokens - 1,
  115. callback=callback,
  116. **sampling_kwargs,
  117. )
  118. x = torch.cat(generated_tokens, dim=1)
  119. seq = seq[:, : T + 1 + x.size(1)]
  120. seq[:, T + 1 :] = x
  121. return seq
  122. def encode_tokens(tokenizer, string, bos=True, device="cuda"):
  123. tokens = tokenizer.encode(
  124. string, max_length=10**6, add_special_tokens=bos, truncation=False
  125. )
  126. tokens = torch.tensor([tokens], dtype=torch.int, device=device)
  127. # Codebooks
  128. zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
  129. return torch.cat((tokens, zeros), dim=0)
def _load_model(checkpoint_path, device, precision, use_tp):
    """Build the Transformer, optionally quantize, load weights, move to device.

    Args:
        checkpoint_path: path to the checkpoint; "int8"/"int4" in the file
            name selects the matching weight-only quantized runtime model.
        device: target device for the final model.
        precision: target dtype (e.g. torch.bfloat16).
        use_tp: apply tensor parallelism when True.

    Returns:
        The model in eval mode on `device` with dtype `precision`.
    """
    # Construct on the meta device: parameters get shapes/dtypes but no
    # storage, so the (possibly quantized) module tree is cheap to build.
    with torch.device("meta"):
        # TODO: support different model archs
        model = Transformer(
            ModelArgs(
                max_seq_len=4096,
                vocab_size=32312,
                n_layer=24,
                n_head=16,
                dim=1024,
                rope_base=10000,
                norm_eps=1e-5,
                codebook_size=168,
                num_codebooks=4,
            )
        )
    # Quantization mode is inferred from the checkpoint file name; the module
    # tree must be converted BEFORE load_state_dict so keys match.
    if "int8" in str(checkpoint_path):
        print("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()
    if "int4" in str(checkpoint_path):
        print("Using int4 quantization!")
        # Group size is encoded in the second-to-last name component, e.g.
        # "model.g32.ckpt" -> groupsize 32.
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()
    # mmap + assign=True materializes the meta-device parameters directly
    # from the checkpoint tensors without an extra copy.
    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.load_state_dict(checkpoint, assign=True)
    if use_tp:
        from tp import apply_tp

        print("Applying tensor parallel to model ...")
        apply_tp(model)
    model = model.to(device=device, dtype=precision)
    return model.eval()
  167. B_INST, E_INST = "[INST]", "[/INST]"
def main(
    prompt: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
    interactive: bool = False,
    num_samples: int = 5,
    max_new_tokens: int = 100,
    top_k: int = 200,
    temperature: float = 0.8,
    checkpoint_path: Path = Path(
        "results/text2semantic_400m/checkpoints/step_000025000.ckpt"
    ),
    compile: bool = True,
    compile_prefill: bool = False,
    profile: Optional[Path] = None,
    tokenizer: str = "fishaudio/speech-lm-v1",
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.

    Loads the checkpoint, converts the prompt via g2p into phoneme tokens,
    then runs `num_samples` generation passes, reporting tokens/sec and
    (in non-interactive mode) saving the generated codebooks as .npy files.
    """
    assert checkpoint_path.is_file(), checkpoint_path
    # `print` is rebound below so only rank 0 produces output under TP.
    global print
    rank = maybe_init_dist()
    use_tp = rank is not None
    if use_tp:
        torch.cuda.set_device(rank)
        if rank != 0:
            # only print on rank 0
            print = lambda *args, **kwargs: None
    device = "cuda"
    precision = torch.bfloat16
    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision, use_tp)
    torch.cuda.synchronize()
    print(f"Time to load model: {time.time() - t0:.02f} seconds")
    # `tokenizer` switches meaning here: model-name string -> tokenizer object.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    # g2p yields (symbol, phoneme) pairs; non-punctuation, non-pad phonemes
    # are wrapped as "<p:...>" tokens before joining into the model prompt.
    prompt = g2p(prompt)
    prompt = [
        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
        for _, i in prompt
    ]
    prompt = " ".join(prompt)
    print(prompt)
    encoded = encode_tokens(
        tokenizer, f"[INST] {prompt} [/INST]", bos=True, device=device
    )
    print(encoded[0])
    prompt_length = encoded.size(1)
    torch.manual_seed(1234)
    # Total parameter+buffer bytes, used for the bandwidth estimate below.
    model_size = sum(
        [
            p.numel() * p.dtype.itemsize
            for p in itertools.chain(model.parameters(), model.buffers())
        ]
    )
    if compile:
        # Rebind the module-level decode_token with a compiled version so
        # decode_n_tokens picks it up too.
        global decode_token
        decode_token = torch.compile(
            decode_token, mode="reduce-overhead", fullgraph=True
        )
    aggregate_metrics = {
        "tokens_per_sec": [],
    }
    # When compiling, iteration -1 is a warmup pass excluded from metrics.
    start = -1 if compile else 0
    for i in range(start, num_samples):
        torch.cuda.synchronize()
        if i >= 0 and interactive:
            prompt = input("What is your prompt? ")
            encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
        if interactive and i >= 0:
            # Streaming display: buffer decoded pieces and flush every 4 tokens.
            buffer = []
            period_id = tokenizer.encode(".")[0]
            done_generating = False

            def callback(x):
                nonlocal done_generating
                if done_generating:
                    return
                # Prepend a period so the tokenizer decodes with correct
                # spacing, then strip it from the decoded text.
                buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
                if x.item() == tokenizer.eos_id():
                    done_generating = True
                if len(buffer) == 4 or done_generating:
                    print("".join(buffer), end="", flush=True)
                    buffer.clear()
                # print(, end='', flush=True)

        else:
            callback = lambda x: x
        t0 = time.perf_counter()
        import contextlib

        # Only profile the final sample, and only on rank 0 under TP.
        if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
            y = generate(
                model,
                encoded,
                max_new_tokens,
                interactive=interactive,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
            )
        if i == -1:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
            continue
        if hasattr(prof, "export_chrome_trace"):
            if use_tp:
                prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
            else:
                prof.export_chrome_trace(f"{profile}.json")
        torch.cuda.synchronize()
        t = time.perf_counter() - t0
        if not interactive:
            print(tokenizer.decode(y[0].tolist()))
            # Rows 1: are codebook ids; offset by -2 to remove special tokens,
            # and drop the prompt and the trailing EOS position.
            codes = y[1:, prompt_length:-1] - 2
            assert (codes >= 0).all()
            import numpy as np

            np.save(f"codes_{i}.npy", codes.cpu().numpy())
        else:
            print()
        tokens_generated = y.size(1) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
        print(
            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
        )
        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
    print("==========")
    print(
        f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
    )
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
  298. if __name__ == "__main__":
  299. import argparse
  300. parser = argparse.ArgumentParser(description="Your CLI description.")
  301. parser.add_argument(
  302. "--prompt",
  303. type=str,
  304. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  305. help="Input prompt.",
  306. )
  307. parser.add_argument(
  308. "--interactive",
  309. action="store_true",
  310. help="Whether to launch in interactive mode",
  311. )
  312. parser.add_argument("--num_samples", type=int, default=1, help="Number of samples.")
  313. parser.add_argument(
  314. "--max_new_tokens", type=int, default=768, help="Maximum number of new tokens."
  315. )
  316. parser.add_argument("--top_k", type=int, default=10, help="Top-k for sampling.")
  317. parser.add_argument(
  318. "--temperature", type=float, default=1.0, help="Temperature for sampling."
  319. )
  320. parser.add_argument(
  321. "--checkpoint_path",
  322. type=Path,
  323. default=Path("results/text2semantic_400m/step_000025000_weights.ckpt"),
  324. help="Model checkpoint path.",
  325. )
  326. parser.add_argument(
  327. "--compile", action="store_true", help="Whether to compile the model."
  328. )
  329. parser.add_argument("--profile", type=Path, default=None, help="Profile path.")
  330. args = parser.parse_args()
  331. main(
  332. args.prompt,
  333. args.interactive,
  334. args.num_samples,
  335. args.max_new_tokens,
  336. args.top_k,
  337. args.temperature,
  338. args.checkpoint_path,
  339. args.compile,
  340. args.profile,
  341. )