generate.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575
  1. import os
  2. import time
  3. from pathlib import Path
  4. from typing import Optional, Tuple
  5. import click
  6. import numpy as np
  7. import torch
  8. import torch._dynamo.config
  9. import torch._inductor.config
  10. from hydra import compose, initialize
  11. from hydra.utils import instantiate
  12. from loguru import logger
  13. from tqdm import tqdm
  14. from transformers import AutoTokenizer
  15. from fish_speech.text.parser import clean_text
  16. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  17. torch._inductor.config.coordinate_descent_tuning = True
  18. torch._inductor.config.triton.unique_kernel_names = True
  19. if hasattr(torch._inductor.config, "fx_graph_cache"):
  20. # Experimental feature to reduce compilation times, will be on by default in future
  21. torch._inductor.config.fx_graph_cache = True
  22. from fish_speech.models.text2semantic.llama import Transformer
  23. from fish_speech.text import g2p
  24. from fish_speech.text.symbols import pad as pad_symbol
  25. from fish_speech.text.symbols import pu_symbols
  26. def multinomial_sample_one_no_sync(
  27. probs_sort,
  28. ): # Does multinomial sampling without a cuda synchronization
  29. q = torch.empty_like(probs_sort).exponential_(1)
  30. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  31. def logits_to_probs(
  32. logits,
  33. previous_tokens: Optional[torch.Tensor] = None,
  34. temperature: float = 1.0,
  35. top_k: Optional[int] = None,
  36. top_p: Optional[int] = None,
  37. repetition_penalty: float = 1.0,
  38. ):
  39. if previous_tokens is not None and repetition_penalty != 1.0:
  40. previous_tokens = previous_tokens.long()
  41. score = torch.gather(logits, dim=0, index=previous_tokens)
  42. score = torch.where(
  43. score < 0, score * repetition_penalty, score / repetition_penalty
  44. )
  45. logits.scatter_(dim=0, index=previous_tokens, src=score)
  46. if top_p is not None and top_p < 1.0:
  47. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  48. cum_probs = torch.cumsum(
  49. torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
  50. )
  51. sorted_indices_to_remove = cum_probs > top_p
  52. sorted_indices_to_remove[0] = False # keep at least one option
  53. indices_to_remove = sorted_indices_to_remove.scatter(
  54. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  55. )
  56. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  57. logits = logits / max(temperature, 1e-5)
  58. if top_k is not None:
  59. v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
  60. pivot = v.select(-1, -1).unsqueeze(-1)
  61. logits = torch.where(logits < pivot, -float("Inf"), logits)
  62. probs = torch.nn.functional.softmax(logits, dim=-1)
  63. return probs
  64. def sample(
  65. logits,
  66. previous_tokens: Optional[torch.Tensor] = None,
  67. **sampling_kwargs,
  68. ) -> Tuple[torch.Tensor, torch.Tensor]:
  69. probs = logits_to_probs(
  70. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  71. )
  72. idx_next = multinomial_sample_one_no_sync(probs)
  73. return idx_next, probs
  74. def decode_one_token(
  75. model: Transformer,
  76. x: torch.Tensor,
  77. input_pos: torch.Tensor,
  78. previous_tokens: torch.Tensor = None,
  79. **sampling_kwargs,
  80. ) -> torch.Tensor:
  81. assert input_pos.shape[-1] == 1
  82. logits = model.forward_generate(x, input_pos)
  83. codebooks = [
  84. sample(
  85. logits.token_logits,
  86. previous_tokens=None, # Disable repetition penalty for the token codebook
  87. **sampling_kwargs,
  88. )[0]
  89. ]
  90. # Disable <s> and </s> tokens for codebooks
  91. if model.config.num_codebooks != 0:
  92. for i in range(model.config.num_codebooks):
  93. codebooks.append(
  94. torch.argmax(logits.codebook_logits[:, :, i], dim=-1).view(1)
  95. )
  96. return torch.stack(codebooks, dim=0)
  97. def prefill(
  98. model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
  99. ) -> torch.Tensor:
  100. # input_pos: [B, S]
  101. logits = model.forward_generate(x, input_pos)
  102. codebooks = [
  103. sample(
  104. logits.token_logits,
  105. previous_tokens=None,
  106. **sampling_kwargs,
  107. )[0]
  108. ]
  109. if model.config.num_codebooks != 0:
  110. for i in range(model.config.num_codebooks):
  111. codebooks.append(
  112. torch.argmax(logits.codebook_logits[:, :, i], dim=-1).view(1)
  113. )
  114. return torch.stack(codebooks, dim=0)
  115. def decode_n_tokens(
  116. model: Transformer,
  117. cur_token: torch.Tensor,
  118. input_pos: torch.Tensor,
  119. num_new_tokens: int,
  120. eos_token_id: int = 2,
  121. **sampling_kwargs,
  122. ):
  123. previous_tokens = torch.zeros(
  124. (model.config.num_codebooks + 1, model.config.max_seq_len),
  125. dtype=torch.int,
  126. device=cur_token.device,
  127. )
  128. for i in tqdm(range(num_new_tokens)):
  129. # We need to get windowed repeat penalty
  130. win_size = 16
  131. if i < win_size:
  132. window = previous_tokens[:, :win_size]
  133. else:
  134. window = previous_tokens[:, i - win_size : i]
  135. with torch.backends.cuda.sdp_kernel(
  136. enable_flash=False, enable_mem_efficient=False, enable_math=True
  137. ): # Actually better for Inductor to codegen attention here
  138. next_token = decode_one_token(
  139. model=model,
  140. x=cur_token,
  141. input_pos=input_pos,
  142. previous_tokens=window,
  143. **sampling_kwargs,
  144. )
  145. input_pos += 1
  146. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  147. previous_tokens[:, i : i + 1] = next_token.view(
  148. model.config.num_codebooks + 1, -1
  149. )
  150. # TODO: use tokenizer's eos
  151. if cur_token[0, 0, -1] == eos_token_id or (cur_token[0, 1:, -1] == 1).any():
  152. break
  153. return previous_tokens[:, : i + 1]
  154. @torch.no_grad()
  155. def generate(
  156. *,
  157. model: Transformer,
  158. prompt: torch.Tensor,
  159. max_new_tokens: int,
  160. eos_token_id: int = 2,
  161. precision: torch.dtype = torch.bfloat16,
  162. **sampling_kwargs,
  163. ) -> torch.Tensor:
  164. """
  165. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  166. """
  167. # create an empty tensor of the expected final shape and fill in the current tokens
  168. T = prompt.size(1)
  169. if max_new_tokens:
  170. if T + max_new_tokens > model.config.max_seq_len:
  171. max_new_tokens = model.config.max_seq_len - T
  172. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  173. T_new = T + max_new_tokens
  174. else:
  175. T_new = model.config.max_seq_len
  176. max_new_tokens = T_new - T
  177. device, dtype = prompt.device, prompt.dtype
  178. with torch.device(device):
  179. model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)
  180. codebook_dim = 1 + model.config.num_codebooks
  181. # create an empty tensor of the expected final shape and fill in the current tokens
  182. empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
  183. empty[:, :T] = prompt
  184. seq = empty
  185. input_pos = torch.arange(0, T, device=device)
  186. next_token = prefill(
  187. model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
  188. )
  189. seq[:, T : T + 1] = next_token
  190. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  191. x = decode_n_tokens(
  192. model,
  193. next_token.view(1, codebook_dim, -1),
  194. input_pos,
  195. max_new_tokens - 1,
  196. eos_token_id=eos_token_id,
  197. **sampling_kwargs,
  198. )
  199. # x = torch.cat(generated_tokens, dim=1)
  200. seq = seq[:, : T + 1 + x.size(1)]
  201. seq[:, T + 1 :] = x
  202. return seq
  203. def encode_tokens(
  204. tokenizer,
  205. string,
  206. bos=True,
  207. device="cuda",
  208. prompt_tokens=None,
  209. use_g2p=False,
  210. speaker=None,
  211. order="zh,jp,en",
  212. num_codebooks=4,
  213. ):
  214. if use_g2p:
  215. order = order.split(",")
  216. prompt = g2p(string, order=order)
  217. prompt = [
  218. (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
  219. for _, i in prompt
  220. ]
  221. string = " ".join(prompt)
  222. else:
  223. string = clean_text(string)
  224. if speaker is not None:
  225. string = f"[SPK: {speaker}] {string}"
  226. string = f"[INST] {string} [/INST]"
  227. # Handle English less frequent words
  228. # TODO: update tokenizer to handle this
  229. # sub_strings = string.split(" ")
  230. # new_tokens = []
  231. # for i, string in enumerate(sub_strings):
  232. # tokens = tokenizer.encode(
  233. # string,
  234. # add_special_tokens=i == 0 and bos,
  235. # max_length=10**6,
  236. # truncation=False,
  237. # )
  238. # new_tokens.extend(tokens)
  239. new_tokens = tokenizer.encode(
  240. string,
  241. add_special_tokens=bos,
  242. max_length=10**6,
  243. truncation=False,
  244. )
  245. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  246. # Codebooks
  247. zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  248. prompt = torch.cat((tokens, zeros), dim=0)
  249. if prompt_tokens is None:
  250. return prompt
  251. # Get prompt tokens
  252. if prompt_tokens.ndim == 3:
  253. assert (
  254. prompt_tokens.shape[0] == 1
  255. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  256. prompt_tokens = prompt_tokens[0]
  257. assert prompt_tokens.ndim == 2
  258. data = prompt_tokens + 2
  259. if prompt_tokens.shape[0] > num_codebooks:
  260. logger.warning(
  261. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  262. )
  263. data = data[:num_codebooks]
  264. # Since 1.0, we use <s:xxx> to replace <semantic>
  265. main_token_ids = torch.tensor(
  266. [[tokenizer.pad_token_id] * data.size(1)], dtype=torch.int, device=device
  267. )
  268. data = torch.cat((main_token_ids, data), dim=0)
  269. prompt = torch.cat((prompt, data), dim=1)
  270. return prompt
  271. def load_model(config_name, checkpoint_path, device, precision):
  272. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  273. cfg = compose(config_name=config_name)
  274. model: Transformer = instantiate(cfg.model).model
  275. if "int8" in str(checkpoint_path):
  276. logger.info("Using int8 weight-only quantization!")
  277. from quantize import WeightOnlyInt8QuantHandler
  278. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  279. model = simple_quantizer.convert_for_runtime()
  280. if "int4" in str(checkpoint_path):
  281. logger.info("Using int4 quantization!")
  282. path_comps = checkpoint_path.name.split(".")
  283. assert path_comps[-2].startswith("g")
  284. groupsize = int(path_comps[-2][1:])
  285. from quantize import WeightOnlyInt4QuantHandler
  286. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  287. model = simple_quantizer.convert_for_runtime()
  288. checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
  289. if "state_dict" in checkpoint:
  290. checkpoint = checkpoint["state_dict"]
  291. if any(k.startswith("model.") for k in checkpoint):
  292. checkpoint = {
  293. k.replace("model.", ""): v
  294. for k, v in checkpoint.items()
  295. if k.startswith("model.")
  296. }
  297. model.load_state_dict(checkpoint, assign=True)
  298. model = model.to(device=device, dtype=precision)
  299. logger.info("Restored model from checkpoint")
  300. return model.eval(), cfg
  301. def split_text(text, min_length):
  302. text = clean_text(text)
  303. segments = []
  304. curr = ""
  305. for char in text:
  306. curr += char
  307. if char not in [".", ",", "!", "?"]:
  308. continue
  309. if len(curr) >= min_length:
  310. segments.append(curr)
  311. curr = ""
  312. if curr:
  313. segments.append(curr)
  314. return segments
  315. @click.command()
  316. @click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
  317. @click.option("--prompt-text", type=str, default=None)
  318. @click.option(
  319. "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
  320. )
  321. @click.option("--num-samples", type=int, default=1)
  322. @click.option("--max-new-tokens", type=int, default=0)
  323. @click.option("--top-k", type=int, default=None)
  324. @click.option("--top-p", type=float, default=0.5)
  325. @click.option("--repetition-penalty", type=float, default=1.5)
  326. @click.option("--temperature", type=float, default=0.7)
  327. @click.option(
  328. "--checkpoint-path",
  329. type=click.Path(path_type=Path, exists=True),
  330. default="results/text2semantic_400m_finetune/step_000002000.pth",
  331. )
  332. @click.option("--config-name", type=str, default="text2semantic_finetune")
  333. @click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
  334. @click.option("--compile/--no-compile", default=False)
  335. @click.option("--use-g2p/--no-g2p", default=True)
  336. @click.option("--seed", type=int, default=42)
  337. @click.option("--speaker", type=str, default=None)
  338. @click.option("--order", type=str, default="zh,jp,en")
  339. @click.option("--half/--no-half", default=False)
  340. @click.option("--iterative-prompt/--no-iterative-prompt", default=False)
  341. def main(
  342. text: str,
  343. prompt_text: Optional[str],
  344. prompt_tokens: Optional[Path],
  345. num_samples: int,
  346. max_new_tokens: int,
  347. top_k: int,
  348. top_p: int,
  349. repetition_penalty: float,
  350. temperature: float,
  351. checkpoint_path: Path,
  352. config_name: str,
  353. tokenizer: str,
  354. compile: bool,
  355. use_g2p: bool,
  356. seed: int,
  357. speaker: Optional[str],
  358. order: str,
  359. half: bool,
  360. iterative_prompt: bool,
  361. ) -> None:
  362. device = "cuda"
  363. precision = torch.half if half else torch.bfloat16
  364. logger.info("Loading model ...")
  365. t0 = time.time()
  366. model, cfg = load_model(config_name, checkpoint_path, device, precision)
  367. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  368. torch.cuda.synchronize()
  369. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  370. tokenizer = AutoTokenizer.from_pretrained(tokenizer)
  371. prompt_tokens = (
  372. torch.from_numpy(np.load(prompt_tokens)).to(device)
  373. if prompt_tokens is not None
  374. else None
  375. )
  376. use_prompt = prompt_text is not None and prompt_tokens is not None
  377. encoded = []
  378. texts = split_text(text, 20) if iterative_prompt else [text]
  379. for idx, text in enumerate(texts):
  380. encoded.append(
  381. encode_tokens(
  382. tokenizer,
  383. string=text,
  384. bos=idx == 0 and not use_prompt,
  385. device=device,
  386. use_g2p=use_g2p,
  387. speaker=None,
  388. order=order,
  389. num_codebooks=model.config.num_codebooks,
  390. )
  391. )
  392. print(f"Encoded text: {text}")
  393. if use_prompt and iterative_prompt:
  394. encoded_prompt = encode_tokens(
  395. tokenizer,
  396. prompt_text,
  397. prompt_tokens=prompt_tokens,
  398. bos=True,
  399. device=device,
  400. use_g2p=use_g2p,
  401. speaker=speaker,
  402. order=order,
  403. num_codebooks=model.config.num_codebooks,
  404. )
  405. encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
  406. # prompt_length = encoded.size(1)
  407. # logger.info(f"Encoded prompt shape: {encoded.shape}")
  408. torch.manual_seed(seed)
  409. torch.cuda.manual_seed(seed)
  410. if compile:
  411. global decode_one_token
  412. decode_one_token = torch.compile(
  413. decode_one_token, mode="reduce-overhead", fullgraph=True
  414. )
  415. for idx in range(num_samples):
  416. torch.cuda.synchronize()
  417. global_encoded = []
  418. all_codes = []
  419. seg_idx = 0
  420. while seg_idx < len(encoded):
  421. seg = encoded[seg_idx]
  422. global_encoded.append(seg)
  423. cat_encoded = torch.cat(global_encoded, dim=1)
  424. prompt_length = cat_encoded.size(1)
  425. t0 = time.perf_counter()
  426. y = generate(
  427. model=model,
  428. prompt=cat_encoded,
  429. max_new_tokens=max_new_tokens,
  430. eos_token_id=tokenizer.eos_token_id,
  431. precision=precision,
  432. temperature=temperature,
  433. top_k=top_k,
  434. top_p=top_p,
  435. repetition_penalty=repetition_penalty,
  436. )
  437. if idx == 0 and compile:
  438. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  439. torch.cuda.synchronize()
  440. t = time.perf_counter() - t0
  441. tokens_generated = y.size(1) - prompt_length
  442. tokens_sec = tokens_generated / t
  443. logger.info(
  444. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  445. )
  446. logger.info(
  447. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  448. )
  449. logger.info(
  450. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  451. )
  452. # Put the generated tokens
  453. codes = y[1:, prompt_length:-1].clone()
  454. if getattr(cfg, "use_delay_pattern", True):
  455. new_codes = []
  456. for j, code in enumerate(codes):
  457. new_codes.append(
  458. code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
  459. )
  460. codes = torch.stack(new_codes, dim=0)
  461. codes = codes - 2
  462. if not (codes >= 0).all():
  463. global_encoded.pop()
  464. logger.warning(f"Negative code found: {codes}, retrying ...")
  465. continue
  466. global_encoded.append(y[:, prompt_length:-1].clone())
  467. all_codes.append(codes)
  468. seg_idx += 1
  469. codes = torch.cat(all_codes, dim=1)
  470. assert (codes >= 0).all(), f"Negative code found: {codes}"
  471. print(codes)
  472. np.save(f"codes_{idx}.npy", codes.cpu().numpy())
  473. logger.info(f"Saved codes to codes_{idx}.npy")
  474. if __name__ == "__main__":
  475. main()