generate.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. import os
  2. import time
  3. from pathlib import Path
  4. from typing import Optional, Tuple
  5. import click
  6. import numpy as np
  7. import torch
  8. import torch._dynamo.config
  9. import torch._inductor.config
  10. from hydra import compose, initialize
  11. from hydra.utils import instantiate
  12. from loguru import logger
  13. from tqdm import tqdm
  14. from transformers import AutoTokenizer
  15. from fish_speech.text.parser import clean_text
  16. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  17. torch._inductor.config.coordinate_descent_tuning = True
  18. torch._inductor.config.triton.unique_kernel_names = True
  19. if hasattr(torch._inductor.config, "fx_graph_cache"):
  20. # Experimental feature to reduce compilation times, will be on by default in future
  21. torch._inductor.config.fx_graph_cache = True
  22. from fish_speech.models.text2semantic.llama import Transformer
  23. from fish_speech.text import g2p
  24. from fish_speech.text.symbols import pad as pad_symbol
  25. from fish_speech.text.symbols import pu_symbols
  26. def multinomial_sample_one_no_sync(
  27. probs_sort,
  28. ): # Does multinomial sampling without a cuda synchronization
  29. q = torch.empty_like(probs_sort).exponential_(1)
  30. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  31. def logits_to_probs(
  32. logits,
  33. previous_tokens: Optional[torch.Tensor] = None,
  34. temperature: float = 1.0,
  35. top_k: Optional[int] = None,
  36. top_p: Optional[int] = None,
  37. repetition_penalty: float = 1.0,
  38. ):
  39. if previous_tokens is not None and repetition_penalty != 1.0:
  40. previous_tokens = previous_tokens.long()
  41. score = torch.gather(logits, dim=0, index=previous_tokens)
  42. score = torch.where(
  43. score < 0, score * repetition_penalty, score / repetition_penalty
  44. )
  45. logits.scatter_(dim=0, index=previous_tokens, src=score)
  46. if top_p is not None and top_p < 1.0:
  47. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  48. cum_probs = torch.cumsum(
  49. torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
  50. )
  51. sorted_indices_to_remove = cum_probs > top_p
  52. sorted_indices_to_remove[0] = False # keep at least one option
  53. indices_to_remove = sorted_indices_to_remove.scatter(
  54. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  55. )
  56. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  57. logits = logits / max(temperature, 1e-5)
  58. if top_k is not None:
  59. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  60. pivot = v.select(-1, -1).unsqueeze(-1)
  61. logits = torch.where(logits < pivot, -float("Inf"), logits)
  62. probs = torch.nn.functional.softmax(logits, dim=-1)
  63. return probs
  64. def sample(
  65. logits,
  66. previous_tokens: Optional[torch.Tensor] = None,
  67. **sampling_kwargs,
  68. ) -> Tuple[torch.Tensor, torch.Tensor]:
  69. probs = logits_to_probs(
  70. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  71. )
  72. idx_next = multinomial_sample_one_no_sync(probs)
  73. return idx_next, probs
  74. def decode_one_token(
  75. model: Transformer,
  76. x: torch.Tensor,
  77. input_pos: torch.Tensor,
  78. previous_tokens: torch.Tensor = None,
  79. **sampling_kwargs,
  80. ) -> torch.Tensor:
  81. assert input_pos.shape[-1] == 1
  82. logits = model.forward_generate(x, input_pos)
  83. codebooks = [
  84. sample(
  85. logits.token_logits,
  86. previous_tokens=None, # Disable repetition penalty for the token codebook
  87. **sampling_kwargs,
  88. )[0]
  89. ]
  90. # Disable <s> and </s> tokens for codebooks
  91. if model.config.num_codebooks != 0:
  92. logits.codebook_logits[:, :, :, :2] = -float("Inf")
  93. for i in range(model.config.num_codebooks):
  94. codebooks.append(
  95. sample(
  96. logits.codebook_logits[:, :, i],
  97. previous_tokens=previous_tokens[i + 1]
  98. if previous_tokens is not None
  99. else None,
  100. **sampling_kwargs,
  101. )[0]
  102. )
  103. return torch.stack(codebooks, dim=0)
  104. def prefill(
  105. model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
  106. ) -> torch.Tensor:
  107. # input_pos: [B, S]
  108. logits = model.forward_generate(x, input_pos)
  109. codebooks = [
  110. sample(
  111. logits.token_logits,
  112. previous_tokens=None,
  113. **sampling_kwargs,
  114. )[0]
  115. ]
  116. # Disable <s> and </s> tokens for codebooks
  117. if model.config.num_codebooks != 0:
  118. logits.codebook_logits[:, :, :, :2] = -float("Inf")
  119. for i in range(model.config.num_codebooks):
  120. codebooks.append(
  121. sample(
  122. logits.codebook_logits[:, :, i],
  123. previous_tokens=None,
  124. **sampling_kwargs,
  125. )[0]
  126. )
  127. return torch.stack(codebooks, dim=0)
  128. def decode_n_tokens(
  129. model: Transformer,
  130. cur_token: torch.Tensor,
  131. input_pos: torch.Tensor,
  132. num_new_tokens: int,
  133. eos_token_id: int = 2,
  134. **sampling_kwargs,
  135. ):
  136. previous_tokens = torch.zeros(
  137. (model.config.num_codebooks + 1, num_new_tokens),
  138. dtype=torch.int,
  139. device=cur_token.device,
  140. )
  141. for i in tqdm(range(num_new_tokens)):
  142. # We need to get windowed repeat penalty
  143. win_size = 16
  144. if i < win_size:
  145. window = previous_tokens[:, :win_size]
  146. else:
  147. window = previous_tokens[:, i - win_size : i]
  148. with torch.backends.cuda.sdp_kernel(
  149. enable_flash=False, enable_mem_efficient=False, enable_math=True
  150. ): # Actually better for Inductor to codegen attention here
  151. next_token = decode_one_token(
  152. model=model,
  153. x=cur_token,
  154. input_pos=input_pos,
  155. previous_tokens=window,
  156. **sampling_kwargs,
  157. )
  158. input_pos += 1
  159. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  160. previous_tokens[:, i : i + 1] = next_token.view(
  161. model.config.num_codebooks + 1, -1
  162. )
  163. # TODO: use tokenizer's eos
  164. if (cur_token[0, 0, -1] == eos_token_id).any():
  165. break
  166. return previous_tokens[:, : i + 1]
  167. @torch.no_grad()
  168. def generate(
  169. *,
  170. model: Transformer,
  171. prompt: torch.Tensor,
  172. max_new_tokens: int,
  173. eos_token_id: int = 2,
  174. precision: torch.dtype = torch.bfloat16,
  175. **sampling_kwargs,
  176. ) -> torch.Tensor:
  177. """
  178. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  179. """
  180. # create an empty tensor of the expected final shape and fill in the current tokens
  181. T = prompt.size(1)
  182. if max_new_tokens:
  183. if T + max_new_tokens > model.config.max_seq_len:
  184. max_new_tokens = model.config.max_seq_len - T
  185. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  186. T_new = T + max_new_tokens
  187. else:
  188. T_new = model.config.max_seq_len
  189. max_new_tokens = T_new - T
  190. device, dtype = prompt.device, prompt.dtype
  191. with torch.device(device):
  192. model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)
  193. codebook_dim = 1 + model.config.num_codebooks
  194. # create an empty tensor of the expected final shape and fill in the current tokens
  195. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  196. empty[:, :T] = prompt
  197. seq = empty
  198. input_pos = torch.arange(0, T, device=device)
  199. next_token = prefill(
  200. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  201. )
  202. seq[:, T : T + 1] = next_token
  203. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  204. x = decode_n_tokens(
  205. model,
  206. next_token.view(1, codebook_dim, -1),
  207. input_pos,
  208. max_new_tokens - 1,
  209. eos_token_id=eos_token_id,
  210. **sampling_kwargs,
  211. )
  212. # x = torch.cat(generated_tokens, dim=1)
  213. seq = seq[:, : T + 1 + x.size(1)]
  214. seq[:, T + 1 :] = x
  215. return seq
  216. def encode_tokens(
  217. tokenizer,
  218. string,
  219. bos=True,
  220. device="cuda",
  221. prompt_text=None,
  222. prompt_tokens=None,
  223. use_g2p=False,
  224. speaker=None,
  225. ):
  226. if prompt_text is not None:
  227. string = prompt_text + " " + string
  228. if use_g2p:
  229. prompt = g2p(string)
  230. prompt = [
  231. (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
  232. for _, i in prompt
  233. ]
  234. string = " ".join(prompt)
  235. else:
  236. string = clean_text(string)
  237. if speaker is not None:
  238. string = f"[SPK: {speaker}] {string}"
  239. string = f"[INST] {string} [/INST]"
  240. tokens = tokenizer.encode(
  241. string,
  242. max_length=10**6,
  243. add_special_tokens=bos,
  244. truncation=False,
  245. )
  246. tokens = torch.tensor([tokens], dtype=torch.int, device=device)
  247. # Codebooks
  248. zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
  249. prompt = torch.cat((tokens, zeros), dim=0)
  250. if prompt_tokens is None:
  251. return prompt
  252. # Get prompt tokens
  253. assert prompt_tokens.ndim == 2
  254. data = prompt_tokens + 2
  255. zeros = (
  256. torch.zeros((1, data.size(1)), dtype=torch.int, device=device)
  257. + tokenizer.pad_token_id
  258. ) # 32311 is the <pad> token
  259. data = torch.cat((zeros, data), dim=0)
  260. prompt = torch.cat((prompt, data), dim=1)
  261. return prompt
  262. def load_model(config_name, checkpoint_path, device, precision):
  263. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  264. cfg = compose(config_name=config_name)
  265. with torch.device("meta"):
  266. model: Transformer = instantiate(cfg.model.model)
  267. if "int8" in str(checkpoint_path):
  268. logger.info("Using int8 weight-only quantization!")
  269. from quantize import WeightOnlyInt8QuantHandler
  270. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  271. model = simple_quantizer.convert_for_runtime()
  272. if "int4" in str(checkpoint_path):
  273. logger.info("Using int4 quantization!")
  274. path_comps = checkpoint_path.name.split(".")
  275. assert path_comps[-2].startswith("g")
  276. groupsize = int(path_comps[-2][1:])
  277. from quantize import WeightOnlyInt4QuantHandler
  278. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  279. model = simple_quantizer.convert_for_runtime()
  280. checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
  281. if "state_dict" in checkpoint:
  282. checkpoint = checkpoint["state_dict"]
  283. if any(k.startswith("model.") for k in checkpoint):
  284. checkpoint = {
  285. k.replace("model.", ""): v
  286. for k, v in checkpoint.items()
  287. if k.startswith("model.")
  288. }
  289. model.load_state_dict(checkpoint, assign=True)
  290. model = model.to(device=device, dtype=precision)
  291. logger.info("Restored model from checkpoint")
  292. return model.eval()
  293. @click.command()
  294. @click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
  295. @click.option("--prompt-text", type=str, default=None)
  296. @click.option(
  297. "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
  298. )
  299. @click.option("--num-samples", type=int, default=1)
  300. @click.option("--max_new_tokens", type=int, default=0)
  301. @click.option("--top-k", type=int, default=None)
  302. @click.option("--top-p", type=float, default=0.5)
  303. @click.option("--repetition-penalty", type=float, default=1.5)
  304. @click.option("--temperature", type=float, default=0.7)
  305. @click.option(
  306. "--checkpoint-path",
  307. type=click.Path(path_type=Path, exists=True),
  308. default="results/text2semantic_400m_finetune/step_000002000.pth",
  309. )
  310. @click.option("--config-name", type=str, default="text2semantic_finetune")
  311. @click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
  312. @click.option("--compile/--no-compile", default=False)
  313. @click.option("--use-g2p/--no-g2p", default=True)
  314. @click.option("--seed", type=int, default=42)
  315. @click.option("--speaker", type=str, default=None)
  316. @click.option("--half/--no-half", default=False)
  317. def main(
  318. text: str,
  319. prompt_text: Optional[str],
  320. prompt_tokens: Optional[Path],
  321. num_samples: int,
  322. max_new_tokens: int,
  323. top_k: int,
  324. top_p: int,
  325. repetition_penalty: float,
  326. temperature: float,
  327. checkpoint_path: Path,
  328. config_name: str,
  329. tokenizer: str,
  330. compile: bool,
  331. use_g2p: bool,
  332. seed: int,
  333. speaker: Optional[str],
  334. half: bool,
  335. ) -> None:
  336. device = "cuda"
  337. precision = torch.half if half else torch.bfloat16
  338. logger.info("Loading model ...")
  339. t0 = time.time()
  340. model = load_model(config_name, checkpoint_path, device, precision)
  341. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  342. torch.cuda.synchronize()
  343. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  344. tokenizer = AutoTokenizer.from_pretrained(tokenizer)
  345. prompt_tokens = (
  346. torch.from_numpy(np.load(prompt_tokens)).to(device)
  347. if prompt_tokens is not None
  348. else None
  349. )
  350. encoded = encode_tokens(
  351. tokenizer,
  352. text,
  353. prompt_text=prompt_text,
  354. prompt_tokens=prompt_tokens,
  355. bos=True,
  356. device=device,
  357. use_g2p=use_g2p,
  358. speaker=speaker,
  359. )
  360. prompt_length = encoded.size(1)
  361. logger.info(f"Encoded prompt shape: {encoded.shape}")
  362. torch.manual_seed(seed)
  363. torch.cuda.manual_seed(seed)
  364. if compile:
  365. global decode_one_token
  366. decode_one_token = torch.compile(
  367. decode_one_token, mode="reduce-overhead", fullgraph=True
  368. )
  369. for i in range(num_samples):
  370. torch.cuda.synchronize()
  371. t0 = time.perf_counter()
  372. y = generate(
  373. model=model,
  374. prompt=encoded,
  375. max_new_tokens=max_new_tokens,
  376. eos_token_id=tokenizer.eos_token_id,
  377. precision=precision,
  378. temperature=temperature,
  379. top_k=top_k,
  380. top_p=top_p,
  381. repetition_penalty=repetition_penalty,
  382. )
  383. if i == 0 and compile:
  384. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  385. torch.cuda.synchronize()
  386. t = time.perf_counter() - t0
  387. tokens_generated = y.size(1) - prompt_length
  388. tokens_sec = tokens_generated / t
  389. logger.info(
  390. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  391. )
  392. logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
  393. logger.info(
  394. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  395. )
  396. codes = y[1:, prompt_length:-1]
  397. codes = codes - 2
  398. assert (codes >= 0).all(), "Codes should be >= 0"
  399. np.save(f"codes_{i}.npy", codes.cpu().numpy())
  400. logger.info(f"Saved codes to codes_{i}.npy")
# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()