generate.py

import os
import time
from pathlib import Path
from typing import Optional, Tuple

import click
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future

from fish_speech.models.text2semantic.llama import Transformer
from fish_speech.text import g2p
from fish_speech.text.symbols import pad as pad_symbol
from fish_speech.text.symbols import pu_symbols


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    # Dividing the probabilities by i.i.d. Exponential(1) noise and taking the
    # argmax draws from the same categorical distribution as torch.multinomial,
    # but avoids the host-device synchronization that torch.multinomial incurs.
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
):
    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=-1, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=-1, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=0, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
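

# A minimal, self-contained sketch (not part of the original script) showing what
# sample() returns; the vocabulary size and logits shape below are illustrative
# assumptions only, chosen to match the [batch, seq, vocab] layout sample() indexes.
def _example_sample_usage() -> None:
    vocab_size = 32
    logits = torch.randn(1, 1, vocab_size)  # [batch, seq, vocab]
    token, probs = sample(logits, temperature=0.8, top_k=5, top_p=0.9)
    # token holds a single sampled id; probs is the filtered distribution over the vocab.
    assert token.shape == (1,)
    assert probs.shape == (vocab_size,)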


def decode_one_token(
    model: Transformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> torch.Tensor:
    assert input_pos.shape[-1] == 1

    logits = model.forward_generate(x, input_pos)
    codebooks = [
        sample(
            logits.token_logits,
            previous_tokens=previous_tokens[0],
            **sampling_kwargs,
        )[0]
    ]

    # Disable <s> and </s> tokens for codebooks
    if model.config.num_codebooks != 0:
        logits.codebook_logits[:, :, :, :2] = -float("Inf")

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                logits.codebook_logits[:, :, i],
                previous_tokens=previous_tokens[i + 1],
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)


def prefill(
    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model.forward_generate(x, input_pos)
    codebooks = [
        sample(
            logits.token_logits,
            previous_tokens=None,
            **sampling_kwargs,
        )[0]
    ]

    # Disable <s> and </s> tokens for codebooks
    if model.config.num_codebooks != 0:
        logits.codebook_logits[:, :, :, :2] = -float("Inf")

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                logits.codebook_logits[:, :, i],
                previous_tokens=None,
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)


def decode_n_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    eos_token_id: int = 2,
    **sampling_kwargs,
):
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, num_new_tokens),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model,
                cur_token,
                input_pos,
                previous_tokens,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        # TODO: use tokenizer's eos
        if (cur_token[0, 0, -1] == eos_token_id).any():
            break

    return previous_tokens[:, : i + 1]
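

# Note: decode_n_tokens returns an int tensor of shape [num_codebooks + 1, n]
# with n <= num_new_tokens; the loop stops early once the text row emits
# eos_token_id, and that final step is included in the returned slice.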


@torch.no_grad()
def generate(
    *,
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int = 2,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """
    T = prompt.size(1)

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_len=T_new)

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        eos_token_id=eos_token_id,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
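

# Note: generate() returns an int tensor of shape [1 + num_codebooks, seq_len]:
# row 0 is the text-token stream (prompt plus generated tokens), and rows 1..
# are the per-codebook codes for the same positions. main() below strips the
# prompt and the final (EOS) position and removes the +2 offset before saving.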


def encode_tokens(
    tokenizer, string, bos=True, device="cuda", prompt_string=None, prompt_tokens=None
):
    if prompt_string is not None:
        string = prompt_string + " " + string

    prompt = g2p(string)
    prompt = [
        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
        for _, i in prompt
    ]
    prompt = " ".join(prompt)
    string = f"[INST] {prompt} [/INST]"

    tokens = tokenizer.encode(
        string,
        max_length=10**6,
        add_special_tokens=bos,
        truncation=False,
    )
    tokens = torch.tensor([tokens], dtype=torch.int, device=device)

    # Codebooks
    zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
    prompt = torch.cat((tokens, zeros), dim=0)

    if prompt_tokens is None:
        return prompt

    # Get prompt tokens
    assert prompt_tokens.ndim == 2
    data = prompt_tokens + 2

    zeros = (
        torch.zeros((1, data.size(1)), dtype=torch.int, device=device)
        + tokenizer.pad_token_id
    )  # 32311 is the <pad> token
    data = torch.cat((zeros, data), dim=0)
    prompt = torch.cat((prompt, data), dim=1)

    return prompt
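

# Note: the returned prompt is a [1 + num_codebooks, T] int tensor. In the text
# region, row 0 holds the tokenized "[INST] ... [/INST]" phoneme string and the
# codebook rows are zero; in the optional prompt_tokens region, row 0 is filled
# with the tokenizer's pad id and the codebook rows hold prompt_tokens shifted
# by +2, so ids 0 and 1 stay reserved for <s>/</s>. The codebook count is
# hard-coded to 4 here, which assumes model.config.num_codebooks == 4.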


def load_model(config_name, checkpoint_path, device, precision):
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        cfg = compose(config_name=config_name)

    with torch.device("meta"):
        model: Transformer = instantiate(cfg.model.model)

    if "int8" in str(checkpoint_path):
        logger.info("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        logger.info("Using int4 quantization!")
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if any(k.startswith("model.") for k in checkpoint):
        checkpoint = {
            k.replace("model.", ""): v
            for k, v in checkpoint.items()
            if k.startswith("model.")
        }

    model.load_state_dict(checkpoint, assign=True)

    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    return model.eval()
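

# A minimal sketch of calling load_model() outside the CLI, using the same
# defaults as main() below; the checkpoint path is an assumption and will differ
# per setup.
#
#   model = load_model(
#       config_name="text2semantic_finetune",
#       checkpoint_path=Path("results/text2semantic_400m_finetune/step_000002000.pth"),
#       device="cuda",
#       precision=torch.bfloat16,
#   )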


@click.command()
@click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
@click.option("--prompt-string", type=str, default=None)
@click.option(
    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max_new_tokens", type=int, default=0)
@click.option("--top_k", type=int, default=50)
@click.option("--top_p", type=float, default=0.95)
@click.option("--repetition-penalty", type=float, default=1.05)
@click.option("--temperature", type=float, default=0.8)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="results/text2semantic_400m_finetune/step_000002000.pth",
)
@click.option("--config-name", type=str, default="text2semantic_finetune")
@click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
def main(
    text: str,
    prompt_string: Optional[str],
    prompt_tokens: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    config_name: str,
    tokenizer: str,
    compile: bool,
    seed: int,
) -> None:
    device = "cuda"
    precision = torch.bfloat16

    logger.info("Loading model ...")
    t0 = time.time()
    model = load_model(config_name, checkpoint_path, device, precision)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    prompt_tokens = (
        torch.from_numpy(np.load(prompt_tokens)).to(device)
        if prompt_tokens is not None
        else None
    )
    encoded = encode_tokens(
        tokenizer,
        text,
        prompt_string=prompt_string,
        prompt_tokens=prompt_tokens,
        bos=True,
        device=device,
    )
    prompt_length = encoded.size(1)
    logger.info(f"Encoded prompt shape: {encoded.shape}")

    torch.manual_seed(seed)

    if compile:
        global decode_one_token
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )

    for i in range(num_samples):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        y = generate(
            model=model,
            prompt=encoded,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        if i == 0 and compile:
            logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

        torch.cuda.synchronize()
        t = time.perf_counter() - t0

        tokens_generated = y.size(1) - prompt_length
        tokens_sec = tokens_generated / t
        logger.info(
            f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
        )
        logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
        logger.info(
            f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
        )

        codes = y[1:, prompt_length:-1]
        codes = codes - 2
        assert (codes >= 0).all(), "Codes should be >= 0"

        np.save(f"codes_{i}.npy", codes.cpu().numpy())
        logger.info(f"Saved codes to codes_{i}.npy")


if __name__ == "__main__":
    main()
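

# Example invocation (a sketch; the checkpoint path matches the CLI default above
# and must point to a real file in your setup):
#
#   python generate.py \
#       --text "Hello world" \
#       --checkpoint-path results/text2semantic_400m_finetune/step_000002000.pth \
#       --config-name text2semantic_finetune \
#       --num-samples 2 \
#       --compile
#
# Each sample writes its semantic codes to codes_<i>.npy.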