generate.py

import os
import time
from pathlib import Path
from typing import Optional, Tuple

import click
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future

from fish_speech.models.text2semantic.llama import Transformer
from fish_speech.text import g2p
from fish_speech.text.symbols import pad as pad_symbol
from fish_speech.text.symbols import pu_symbols

def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
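    # Exponential-race trick: with q ~ Exp(1) drawn i.i.d. per category,
    # argmax(probs / q) selects index i with probability probs[i], matching
    # torch.multinomial but without forcing a device-to-host sync.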
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
):
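    # Applied in order: repetition penalty -> top-p (nucleus) filtering ->
    # temperature scaling -> top-k filtering -> softmax.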
    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        logits.scatter_(dim=0, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[0] = False  # keep at least one option
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=0, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs

def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
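    # logits is (batch, seq, vocab); only the last position of the first
    # (and only) batch element is sampled.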
    probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs

def decode_one_token(
    model: Transformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> torch.Tensor:
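    # Single decoding step against the KV cache: sample the main token stream
    # first, then one token per codebook. Returns a stack of shape
    # (num_codebooks + 1, 1).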
    assert input_pos.shape[-1] == 1

    logits = model.forward_generate(x, input_pos)
    codebooks = [
        sample(
            logits.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]

    # Disable <s> and </s> tokens for codebooks
    if model.config.num_codebooks != 0:
        logits.codebook_logits[:, :, :, :2] = -float("Inf")

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                logits.codebook_logits[:, :, i],
                previous_tokens=previous_tokens[i + 1]
                if previous_tokens is not None
                else None,
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)

def prefill(
    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
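    # Processes the whole prompt in one forward pass and samples the first new
    # token. Mirrors decode_one_token, but with no repetition-penalty history.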
    # input_pos: [B, S]
    logits = model.forward_generate(x, input_pos)
    codebooks = [
        sample(
            logits.token_logits,
            previous_tokens=None,
            **sampling_kwargs,
        )[0]
    ]

    # Disable <s> and </s> tokens for codebooks
    if model.config.num_codebooks != 0:
        logits.codebook_logits[:, :, :, :2] = -float("Inf")

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                logits.codebook_logits[:, :, i],
                previous_tokens=None,
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)

def decode_n_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    eos_token_id: int = 2,
    **sampling_kwargs,
):
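    # Autoregressive loop: decodes up to num_new_tokens steps, applying the
    # repetition penalty over a sliding window of recent tokens, and stops
    # early once the main token stream emits EOS.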
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, num_new_tokens),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            # For the first win_size steps the window still contains trailing
            # zeros; codebook ids 0/1 are masked out in decode_one_token anyway.
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model,
                cur_token,
                input_pos,
                window,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        # TODO: use tokenizer's eos
        if (cur_token[0, 0, -1] == eos_token_id).any():
            break

    return previous_tokens[:, : i + 1]

@torch.no_grad()
def generate(
    *,
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int = 2,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """
    T = prompt.size(1)

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_len=T_new)

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        eos_token_id=eos_token_id,
        **sampling_kwargs,
    )
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq

def encode_tokens(
    tokenizer, string, bos=True, device="cuda", prompt_string=None, prompt_tokens=None
):
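    # Builds the model input: phonemizes the text, wraps it in the [INST]
    # template, and appends zero-filled codebook rows. If an acoustic prompt is
    # given, its codes (shifted by +2, matching the "codes - 2" in main) are
    # concatenated after the text, with <pad> filling the text-token row.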
    if prompt_string is not None:
        string = prompt_string + " " + string

    prompt = g2p(string)
    prompt = [
        (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
        for _, i in prompt
    ]
    prompt = " ".join(prompt)
    string = f"[INST] {prompt} [/INST]"

    tokens = tokenizer.encode(
        string,
        max_length=10**6,
        add_special_tokens=bos,
        truncation=False,
    )
    tokens = torch.tensor([tokens], dtype=torch.int, device=device)

    # Codebooks
    zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
    prompt = torch.cat((tokens, zeros), dim=0)

    if prompt_tokens is None:
        return prompt

    # Get prompt tokens
    assert prompt_tokens.ndim == 2
    data = prompt_tokens + 2

    zeros = (
        torch.zeros((1, data.size(1)), dtype=torch.int, device=device)
        + tokenizer.pad_token_id
    )  # 32311 is the <pad> token
    data = torch.cat((zeros, data), dim=0)
    prompt = torch.cat((prompt, data), dim=1)

    return prompt

def load_model(config_name, checkpoint_path, device, precision):
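    # Instantiates the Transformer from the Hydra config on the meta device,
    # optionally wraps it for int8/int4 weight-only quantization (inferred
    # from the checkpoint filename), then loads the real weights.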
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        cfg = compose(config_name=config_name)

    with torch.device("meta"):
        model: Transformer = instantiate(cfg.model.model)

    if "int8" in str(checkpoint_path):
        logger.info("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        logger.info("Using int4 quantization!")
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if any(k.startswith("model.") for k in checkpoint):
        checkpoint = {
            k.replace("model.", ""): v
            for k, v in checkpoint.items()
            if k.startswith("model.")
        }

    model.load_state_dict(checkpoint, assign=True)
    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    return model.eval()

@click.command()
@click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
@click.option("--prompt-string", type=str, default=None)
@click.option(
    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-k", type=int, default=None)
@click.option("--top-p", type=float, default=0.5)
@click.option("--repetition-penalty", type=float, default=1.5)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="results/text2semantic_400m_finetune/step_000002000.pth",
)
@click.option("--config-name", type=str, default="text2semantic_finetune")
@click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
def main(
    text: str,
    prompt_string: Optional[str],
    prompt_tokens: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_k: Optional[int],
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    config_name: str,
    tokenizer: str,
    compile: bool,
    seed: int,
) -> None:
    device = "cuda"
    precision = torch.bfloat16

    logger.info("Loading model ...")
    t0 = time.time()
    model = load_model(config_name, checkpoint_path, device, precision)
    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)

    torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    prompt_tokens = (
        torch.from_numpy(np.load(prompt_tokens)).to(device)
        if prompt_tokens is not None
        else None
    )
    encoded = encode_tokens(
        tokenizer,
        text,
        prompt_string=prompt_string,
        prompt_tokens=prompt_tokens,
        bos=True,
        device=device,
    )
    prompt_length = encoded.size(1)
    logger.info(f"Encoded prompt shape: {encoded.shape}")

    torch.manual_seed(seed)

    if compile:
        global decode_one_token
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )

    for i in range(num_samples):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        y = generate(
            model=model,
            prompt=encoded,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        if i == 0 and compile:
            logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

        torch.cuda.synchronize()
        t = time.perf_counter() - t0
        tokens_generated = y.size(1) - prompt_length
        tokens_sec = tokens_generated / t
        logger.info(
            f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
        )
        logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
        logger.info(
            f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
        )

        # Drop the text-token row, the prompt columns, and the trailing EOS,
        # then undo the +2 offset applied in encode_tokens.
        codes = y[1:, prompt_length:-1]
        codes = codes - 2
        assert (codes >= 0).all(), "Codes should be >= 0"

        np.save(f"codes_{i}.npy", codes.cpu().numpy())
        logger.info(f"Saved codes to codes_{i}.npy")

if __name__ == "__main__":
    main()
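
# Example invocation (a sketch; the checkpoint path is illustrative and must
# point at an existing file, per the click.Path(exists=True) option above):
#
#   python generate.py \
#       --text "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游." \
#       --checkpoint-path results/text2semantic_400m_finetune/step_000002000.pth \
#       --compile
#
# Each sample's semantic codes are written to codes_{i}.npy.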