generate.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import hydra
  11. import numpy as np
  12. import torch
  13. import torch._dynamo.config
  14. import torch._inductor.config
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from transformers import AutoTokenizer
  18. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  19. from fish_speech.models.text2semantic.llama import BaseModelArgs
  20. from fish_speech.text import clean_text, split_text
# Turn off the HuggingFace `tokenizers` internal thread pool; presumably to
# avoid its fork/thread warnings when this module is used from worker threads
# (see the launch_* helpers) — TODO confirm the original motivation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# TorchInductor settings; they only influence functions compiled with torch.compile.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
# `fx_graph_cache` is not present on every torch version, so only set it when available.
if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
  27. from torch.nn.attention import SDPBackend, sdpa_kernel
  28. from fish_speech.models.text2semantic.llama import (
  29. BaseTransformer,
  30. DualARTransformer,
  31. NaiveTransformer,
  32. )
  33. def multinomial_sample_one_no_sync(
  34. probs_sort,
  35. ): # Does multinomial sampling without a cuda synchronization
  36. q = torch.empty_like(probs_sort).exponential_(1)
  37. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  38. def logits_to_probs(
  39. logits,
  40. previous_tokens: Optional[torch.Tensor] = None,
  41. temperature: torch.Tensor = 1.0,
  42. top_p: torch.Tensor = 1.0,
  43. repetition_penalty: torch.Tensor = 1.0,
  44. ) -> torch.Tensor:
  45. # Apply repetition penalty
  46. if previous_tokens is not None:
  47. previous_tokens = previous_tokens.long()
  48. score = torch.gather(logits, dim=0, index=previous_tokens)
  49. score = torch.where(
  50. score < 0, score * repetition_penalty, score / repetition_penalty
  51. )
  52. logits.scatter_(dim=0, index=previous_tokens, src=score)
  53. # Apply top-p sampling
  54. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  55. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  56. sorted_indices_to_remove = cum_probs > top_p
  57. sorted_indices_to_remove[0] = False # keep at least one option
  58. indices_to_remove = sorted_indices_to_remove.scatter(
  59. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  60. )
  61. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  62. logits = logits / max(temperature, 1e-5)
  63. probs = torch.nn.functional.softmax(logits, dim=-1)
  64. return probs
  65. def multinomial_sample_one_no_sync_agent(
  66. probs_sort,
  67. ): # Does multinomial sampling without a cuda synchronization
  68. q = torch.empty_like(probs_sort).exponential_(1)
  69. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  70. def logits_to_probs_agent(
  71. logits,
  72. previous_tokens: Optional[torch.Tensor] = None,
  73. temperature: torch.Tensor = 1.0,
  74. top_p: torch.Tensor = 1.0,
  75. repetition_penalty: torch.Tensor = 1.0,
  76. ) -> torch.Tensor:
  77. # Apply repetition penalty
  78. if previous_tokens is not None:
  79. previous_tokens = previous_tokens.long()
  80. score = torch.gather(logits, dim=-1, index=previous_tokens)
  81. score = torch.where(
  82. score < 0, score * repetition_penalty, score / repetition_penalty
  83. )
  84. logits.scatter_(dim=-1, index=previous_tokens, src=score)
  85. # Apply top-p sampling
  86. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  87. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  88. sorted_indices_to_remove = cum_probs > top_p
  89. sorted_indices_to_remove[..., 0] = False # keep at least one option
  90. indices_to_remove = sorted_indices_to_remove.scatter(
  91. dim=-1, index=sorted_indices, src=sorted_indices_to_remove
  92. )
  93. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  94. logits = logits / max(temperature, 1e-5)
  95. probs = torch.nn.functional.softmax(logits, dim=-1)
  96. return probs
  97. def sample(
  98. logits,
  99. previous_tokens: Optional[torch.Tensor] = None,
  100. **sampling_kwargs,
  101. ) -> Tuple[torch.Tensor, torch.Tensor]:
  102. probs = logits_to_probs(
  103. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  104. )
  105. idx_next = multinomial_sample_one_no_sync(probs)
  106. return idx_next, probs
  107. def sample_agent(
  108. logits,
  109. previous_tokens: Optional[torch.Tensor] = None,
  110. **sampling_kwargs,
  111. ) -> Tuple[torch.Tensor, torch.Tensor]:
  112. probs = logits_to_probs_agent(
  113. logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
  114. )
  115. idx_next = multinomial_sample_one_no_sync_agent(probs)
  116. return idx_next, probs
  117. def decode_one_token_ar_agent(
  118. model: DualARTransformer,
  119. x: torch.Tensor,
  120. input_pos: torch.Tensor,
  121. previous_tokens: torch.Tensor = None,
  122. semantic_id: int = 32003,
  123. **sampling_kwargs,
  124. ) -> torch.Tensor:
  125. # print(x, input_pos)
  126. x = model.forward_generate(x, input_pos)
  127. logits = x.logits # [:, -1:]
  128. hidden_states = x.hidden_states # [:, -1:]
  129. sampling_kwargs_main = sampling_kwargs.copy()
  130. sampling_kwargs_main["temperature"] = 0.1
  131. sampling_kwargs_main["top_p"] = 0.1
  132. sampling_kwargs_main["repetition_penalty"] = 1.0
  133. codebooks = [
  134. sample_agent(
  135. logits,
  136. previous_tokens=None, # Disable repetition penalty for the token codebook
  137. **sampling_kwargs_main,
  138. )[0]
  139. ]
  140. # Cleanup the cache
  141. for layer in model.fast_layers:
  142. layer.attention.kv_cache.k_cache.fill_(0)
  143. layer.attention.kv_cache.v_cache.fill_(0)
  144. for codebook_idx in range(model.config.num_codebooks):
  145. input_pos = torch.tensor(
  146. [codebook_idx], device=hidden_states.device, dtype=torch.long
  147. )
  148. logits = model.forward_generate_fast(hidden_states, input_pos)
  149. a = sample_agent(
  150. logits,
  151. previous_tokens=(
  152. previous_tokens[:, codebook_idx + 1]
  153. if previous_tokens is not None
  154. else None
  155. ),
  156. **sampling_kwargs,
  157. )[0]
  158. hidden_states = model.fast_embeddings(a)
  159. codebooks.append(a)
  160. codebooks = torch.stack(codebooks, dim=1)
  161. codebooks[:, 1:, :] = torch.masked_fill(
  162. codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
  163. )
  164. # for i in range(codebooks.size(1) - 1):
  165. # codebooks[:, i + 1, :] = torch.masked_fill(
  166. # codebooks[:, i + 1, :],
  167. # codebooks[:, :1, :] != semantic_id,
  168. # CODEBOOK_PAD_TOKEN_ID + i * 1024,
  169. # )
  170. # print(codebooks)
  171. return codebooks
  172. def decode_one_token_naive_agent(
  173. model: NaiveTransformer,
  174. x: torch.Tensor,
  175. input_pos: torch.Tensor,
  176. previous_tokens: torch.Tensor = None,
  177. semantic_id: int = 32003,
  178. **sampling_kwargs,
  179. ) -> torch.Tensor:
  180. x = model.forward_generate(x, input_pos)
  181. codebooks = [
  182. sample(
  183. x.token_logits,
  184. previous_tokens=None, # Disable repetition penalty for the token codebook
  185. **sampling_kwargs,
  186. )[0]
  187. ]
  188. for i in range(model.config.num_codebooks):
  189. codebooks.append(
  190. sample_agent(
  191. x.codebook_logits[:, :, i],
  192. previous_tokens=(
  193. previous_tokens[:, i + 1] if previous_tokens is not None else None
  194. ),
  195. **sampling_kwargs,
  196. )[0]
  197. )
  198. codebooks = torch.stack(codebooks, dim=1)
  199. codebooks[:, 1:, :] = torch.masked_fill(
  200. codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
  201. )
  202. return codebooks
  203. def decode_one_token_ar(
  204. model: DualARTransformer,
  205. x: torch.Tensor,
  206. input_pos: torch.Tensor,
  207. previous_tokens: torch.Tensor = None,
  208. semantic_id: int = 0,
  209. **sampling_kwargs,
  210. ) -> torch.Tensor:
  211. x = model.forward_generate(x, input_pos)
  212. sampling_kwargs_main = sampling_kwargs.copy()
  213. # sampling_kwargs_main["temperature"] = 0.1
  214. # sampling_kwargs_main["top_p"] = 0.1
  215. # sampling_kwargs_main["repetition_penalty"] = 1.0
  216. codebooks = [
  217. sample(
  218. x.logits,
  219. previous_tokens=None, # Disable repetition penalty for the token codebook
  220. **sampling_kwargs_main,
  221. )[0]
  222. ]
  223. x = x.hidden_states
  224. # Cleanup the cache
  225. for layer in model.fast_layers:
  226. layer.attention.kv_cache.k_cache.fill_(0)
  227. layer.attention.kv_cache.v_cache.fill_(0)
  228. for codebook_idx in range(model.config.num_codebooks):
  229. input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
  230. logits = model.forward_generate_fast(x, input_pos)
  231. a = sample(
  232. logits,
  233. previous_tokens=(
  234. previous_tokens[codebook_idx + 1]
  235. if previous_tokens is not None
  236. else None
  237. ),
  238. **sampling_kwargs,
  239. )[0]
  240. x = model.fast_embeddings(a)
  241. codebooks.append(a)
  242. codebooks = torch.stack(codebooks, dim=0)
  243. codebooks[1:, :] = torch.masked_fill(
  244. codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
  245. )
  246. return codebooks
  247. def decode_one_token_naive(
  248. model: NaiveTransformer,
  249. x: torch.Tensor,
  250. input_pos: torch.Tensor,
  251. previous_tokens: torch.Tensor = None,
  252. **sampling_kwargs,
  253. ) -> torch.Tensor:
  254. x = model.forward_generate(x, input_pos)
  255. sampling_kwargs_main = sampling_kwargs.copy()
  256. sampling_kwargs_main["temperature"] = 0.1
  257. sampling_kwargs_main["top_p"] = 0.1
  258. sampling_kwargs_main["repetition_penalty"] = 1.0
  259. codebooks = [
  260. sample(
  261. x.logits,
  262. previous_tokens=None, # Disable repetition penalty for the token codebook
  263. **sampling_kwargs_main,
  264. )[0]
  265. ]
  266. for i in range(model.config.num_codebooks):
  267. codebooks.append(
  268. sample(
  269. x.codebook_logits[:, :, i],
  270. previous_tokens=(
  271. previous_tokens[i + 1] if previous_tokens is not None else None
  272. ),
  273. **sampling_kwargs,
  274. )[0]
  275. )
  276. return torch.stack(codebooks, dim=0)
  277. def decode_n_tokens(
  278. model: NaiveTransformer,
  279. cur_token: torch.Tensor,
  280. input_pos: torch.Tensor,
  281. num_new_tokens: int,
  282. im_end_id: int = 4,
  283. decode_one_token=decode_one_token_naive,
  284. semantic_id: int = 0,
  285. **sampling_kwargs,
  286. ):
  287. previous_tokens = torch.zeros(
  288. (model.config.num_codebooks + 1, model.config.max_seq_len),
  289. dtype=torch.int,
  290. device=cur_token.device,
  291. )
  292. for i in tqdm(range(num_new_tokens)):
  293. # We need to get windowed repeat penalty
  294. win_size = 16
  295. if i < win_size:
  296. window = previous_tokens[:, :win_size]
  297. else:
  298. window = previous_tokens[:, i - win_size : i]
  299. with (
  300. torch.backends.cuda.sdp_kernel(
  301. enable_flash=False, enable_mem_efficient=False, enable_math=True
  302. )
  303. if torch.cuda.is_available()
  304. else nullcontext()
  305. ): # Actually better for Inductor to codegen attention here
  306. next_token = decode_one_token(
  307. model=model,
  308. x=cur_token,
  309. input_pos=input_pos,
  310. previous_tokens=window,
  311. semantic_id=semantic_id,
  312. **sampling_kwargs,
  313. )
  314. input_pos += 1
  315. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  316. previous_tokens[:, i : i + 1] = next_token.view(
  317. model.config.num_codebooks + 1, -1
  318. )
  319. if cur_token[0, 0, -1] == im_end_id:
  320. break
  321. return previous_tokens[:, : i + 1]
  322. @torch.no_grad()
  323. @torch.inference_mode()
  324. def generate(
  325. *,
  326. model: NaiveTransformer,
  327. prompt: torch.Tensor,
  328. max_new_tokens: int,
  329. im_end_id: int = 4,
  330. decode_one_token=decode_one_token_naive,
  331. **sampling_kwargs,
  332. ) -> torch.Tensor:
  333. """
  334. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  335. """
  336. # create an empty tensor of the expected final shape and fill in the current tokens
  337. T = prompt.size(1)
  338. semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
  339. if max_new_tokens:
  340. if T + max_new_tokens > model.config.max_seq_len:
  341. max_new_tokens = model.config.max_seq_len - T
  342. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  343. T_new = T + max_new_tokens
  344. else:
  345. T_new = model.config.max_seq_len
  346. max_new_tokens = T_new - T
  347. device, dtype = prompt.device, prompt.dtype
  348. codebook_dim = 1 + model.config.num_codebooks
  349. # create an empty tensor of the expected final shape and fill in the current tokens
  350. empty = torch.empty(
  351. (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
  352. )
  353. empty[:, :T] = prompt
  354. seq = empty
  355. input_pos = torch.arange(0, T, device=device)
  356. # Use non-accelerated version for now, to avoid compilation overhead
  357. prefill_decode = (
  358. decode_one_token_naive
  359. if isinstance(model, NaiveTransformer)
  360. else decode_one_token_ar
  361. )
  362. next_token = prefill_decode(
  363. model,
  364. prompt.view(1, codebook_dim, -1),
  365. input_pos,
  366. semantic_id=semantic_id,
  367. **sampling_kwargs,
  368. )
  369. seq[:, T : T + 1] = next_token
  370. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  371. x = decode_n_tokens(
  372. model,
  373. next_token.view(1, codebook_dim, -1),
  374. input_pos,
  375. max_new_tokens - 1,
  376. im_end_id=im_end_id,
  377. decode_one_token=decode_one_token,
  378. semantic_id=semantic_id,
  379. **sampling_kwargs,
  380. )
  381. # x = torch.cat(generated_tokens, dim=1)
  382. seq = seq[:, : T + 1 + x.size(1)]
  383. seq[:, T + 1 :] = x
  384. return seq
  385. def decode_n_tokens_agent(
  386. model: NaiveTransformer,
  387. cur_token: torch.Tensor,
  388. input_pos: torch.Tensor,
  389. num_new_tokens: int,
  390. im_end_id: int = 4,
  391. semantic_id: int = 32003,
  392. decode_one_token=decode_one_token_naive_agent,
  393. early_stop_threshold: float = 0.6,
  394. **sampling_kwargs,
  395. ):
  396. batch_size = cur_token.size(0)
  397. previous_tokens = torch.zeros(
  398. (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
  399. dtype=torch.int,
  400. device=cur_token.device,
  401. )
  402. finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
  403. finished = finished | (cur_token[:, 0, -1] == im_end_id)
  404. start_time = time.time()
  405. for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
  406. # We need to get windowed repeat penalty
  407. win_size = 16
  408. if i < win_size:
  409. window = previous_tokens[:, :, :win_size]
  410. else:
  411. window = previous_tokens[:, :, i - win_size : i]
  412. with sdpa_kernel(
  413. SDPBackend.MATH
  414. ): # Actually better for Inductor to codegen attention here
  415. next_token = decode_one_token(
  416. model=model,
  417. x=cur_token,
  418. input_pos=input_pos,
  419. previous_tokens=window,
  420. semantic_id=semantic_id,
  421. **sampling_kwargs,
  422. )
  423. input_pos += 1
  424. cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
  425. previous_tokens[:, :, i : i + 1] = next_token.view(
  426. batch_size, model.config.num_codebooks + 1, -1
  427. )
  428. yield cur_token.cpu()
  429. finished = finished | (cur_token[:, 0, -1] == im_end_id)
  430. if finished.all() or (
  431. 0 < early_stop_threshold < 1
  432. and finished.sum() >= round(batch_size * early_stop_threshold)
  433. ):
  434. break
  435. total_time = time.time() - start_time
  436. generated_tokens = i + 1
  437. tokens_per_second = (generated_tokens / total_time) * batch_size
  438. logger.info(
  439. f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
  440. )
  441. @torch.no_grad()
  442. @torch.inference_mode()
  443. def generate_agent(
  444. *,
  445. model: BaseTransformer,
  446. prompt: torch.Tensor,
  447. max_new_tokens: int,
  448. im_end_id: int = 4,
  449. semantic_id: int = 32003,
  450. decode_one_token=decode_one_token_naive_agent,
  451. num_samples: int = 1,
  452. early_stop_threshold: float = 0.6,
  453. **sampling_kwargs,
  454. ):
  455. """
  456. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  457. """
  458. # create an empty tensor of the expected final shape and fill in the current tokens
  459. T = prompt.size(1)
  460. prompt = prompt[None].repeat(num_samples, 1, 1)
  461. if T >= model.config.max_seq_len:
  462. raise ValueError(
  463. f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
  464. )
  465. if max_new_tokens:
  466. if T + max_new_tokens > model.config.max_seq_len:
  467. max_new_tokens = model.config.max_seq_len - T
  468. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  469. T_new = T + max_new_tokens
  470. else:
  471. T_new = model.config.max_seq_len
  472. max_new_tokens = T_new - T
  473. device, dtype = prompt.device, prompt.dtype
  474. codebook_dim = 1 + model.config.num_codebooks
  475. input_pos = torch.arange(0, T, device=device)
  476. # Use non-accelerated version for now, to avoid compilation overhead
  477. prefill_decode = (
  478. decode_one_token_naive_agent
  479. if isinstance(model, NaiveTransformer)
  480. else decode_one_token_ar_agent
  481. )
  482. next_token = prefill_decode(
  483. model,
  484. prompt,
  485. input_pos,
  486. semantic_id=semantic_id,
  487. **sampling_kwargs,
  488. ).view(num_samples, codebook_dim, -1)
  489. yield next_token.cpu()
  490. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  491. yield from decode_n_tokens_agent(
  492. model,
  493. next_token,
  494. input_pos,
  495. max_new_tokens - 1,
  496. im_end_id=im_end_id,
  497. semantic_id=semantic_id,
  498. decode_one_token=decode_one_token,
  499. early_stop_threshold=early_stop_threshold,
  500. **sampling_kwargs,
  501. )
  502. def encode_tokens(
  503. tokenizer,
  504. string,
  505. device="cuda",
  506. prompt_tokens=None,
  507. num_codebooks=4,
  508. ):
  509. string = clean_text(string)
  510. string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
  511. new_tokens = tokenizer.encode(
  512. string,
  513. add_special_tokens=False,
  514. max_length=10**6,
  515. truncation=False,
  516. )
  517. tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
  518. # Codebooks
  519. zeros = (
  520. torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
  521. * CODEBOOK_PAD_TOKEN_ID
  522. )
  523. prompt = torch.cat((tokens, zeros), dim=0)
  524. if prompt_tokens is None:
  525. return prompt
  526. # Get prompt tokens
  527. if prompt_tokens.ndim == 3:
  528. assert (
  529. prompt_tokens.shape[0] == 1
  530. ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
  531. prompt_tokens = prompt_tokens[0]
  532. assert prompt_tokens.ndim == 2
  533. data = prompt_tokens + 1
  534. if prompt_tokens.shape[0] > num_codebooks:
  535. logger.warning(
  536. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  537. )
  538. data = data[:num_codebooks]
  539. # Add pad token for each codebook
  540. data = torch.cat(
  541. (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
  542. dim=1,
  543. )
  544. # Since 1.0, we use <|semantic|>
  545. s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
  546. end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  547. main_token_ids = (
  548. torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
  549. )
  550. main_token_ids[0, -1] = end_token_id
  551. data = torch.cat((main_token_ids, data), dim=0)
  552. prompt = torch.cat((prompt, data), dim=1)
  553. return prompt
  554. def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
  555. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  556. checkpoint_path, load_weights=True
  557. )
  558. model = model.to(device=device, dtype=precision)
  559. logger.info(f"Restored model from checkpoint")
  560. if isinstance(model, DualARTransformer):
  561. decode_one_token = (
  562. decode_one_token_ar_agent if is_agent else decode_one_token_ar
  563. )
  564. logger.info("Using DualARTransformer")
  565. else:
  566. decode_one_token = (
  567. decode_one_token_naive_agent if is_agent else decode_one_token_naive
  568. )
  569. logger.info("Using NaiveTransformer")
  570. if compile:
  571. logger.info("Compiling function...")
  572. decode_one_token = torch.compile(
  573. decode_one_token,
  574. fullgraph=True,
  575. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  576. mode="reduce-overhead" if torch.cuda.is_available() else None,
  577. )
  578. return model.eval(), decode_one_token
@dataclass
class GenerateResponse:
    """One item yielded by ``generate_long``.

    ``action="sample"`` carries the codes generated for one text chunk, plus
    that chunk's text. ``action="next"`` marks the end of one sample and leaves
    ``codes`` and ``text`` as ``None``.
    """

    # "sample" for a generated chunk, "next" for the end-of-sample marker.
    action: Literal["sample", "next"]
    # Codebook codes for the chunk (main-token row dropped, offset removed).
    codes: Optional[torch.Tensor] = None
    # The text chunk these codes were generated for.
    text: Optional[str] = None
  584. def generate_long(
  585. *,
  586. model,
  587. device: str | torch.device,
  588. decode_one_token: callable,
  589. text: str,
  590. num_samples: int = 1,
  591. max_new_tokens: int = 0,
  592. top_p: int = 0.7,
  593. repetition_penalty: float = 1.5,
  594. temperature: float = 0.7,
  595. compile: bool = False,
  596. iterative_prompt: bool = True,
  597. max_length: int = 2048,
  598. chunk_length: int = 150,
  599. prompt_text: Optional[str | list[str]] = None,
  600. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  601. ):
  602. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  603. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  604. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  605. use_prompt = prompt_text is not None and prompt_tokens is not None
  606. if use_prompt and isinstance(prompt_text, str):
  607. prompt_text = [prompt_text]
  608. prompt_tokens = [prompt_tokens]
  609. assert use_prompt is False or len(prompt_text) == len(
  610. prompt_tokens
  611. ), "Prompt text and tokens must have the same length"
  612. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  613. tokenizer = model.tokenizer
  614. im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
  615. encoded = []
  616. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  617. encoded_prompts = []
  618. if use_prompt:
  619. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  620. encoded_prompts.append(
  621. encode_tokens(
  622. tokenizer,
  623. string=t,
  624. device=device,
  625. prompt_tokens=c,
  626. num_codebooks=model.config.num_codebooks,
  627. )
  628. )
  629. for idx, text in enumerate(texts):
  630. encoded.append(
  631. encode_tokens(
  632. tokenizer,
  633. string=text,
  634. device=device,
  635. num_codebooks=model.config.num_codebooks,
  636. )
  637. )
  638. logger.info(f"Encoded text: {text}")
  639. # Move temperature, top_p, repetition_penalty to device
  640. # This is important so that changing params doesn't trigger recompile
  641. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  642. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  643. repetition_penalty = torch.tensor(
  644. repetition_penalty, device=device, dtype=torch.float
  645. )
  646. for sample_idx in range(num_samples):
  647. if torch.cuda.is_available():
  648. torch.cuda.synchronize()
  649. global_encoded = []
  650. seg_idx = 0
  651. while seg_idx < len(encoded):
  652. logger.info(
  653. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  654. )
  655. seg = encoded[seg_idx]
  656. global_encoded.append(seg)
  657. lengths = reversed([seg.size(1) for seg in global_encoded])
  658. # Pick last 2000 tokens
  659. count = 0
  660. for i, length in enumerate(lengths):
  661. count += length
  662. if count + length > max_length - 1024 - sum(
  663. t.shape[1] for t in encoded_prompts
  664. ):
  665. break
  666. if i != 0 and i % 2 == 0:
  667. i -= 1
  668. # Rotate the list, always make sure first segment is included to avoid drift
  669. if i < len(global_encoded) - 2:
  670. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  671. else:
  672. partial_encoded = global_encoded
  673. if use_prompt:
  674. partial_encoded = encoded_prompts + partial_encoded
  675. cat_encoded = torch.cat(partial_encoded, dim=1)
  676. prompt_length = cat_encoded.size(1)
  677. t0 = time.perf_counter()
  678. y = generate(
  679. model=model,
  680. prompt=cat_encoded,
  681. max_new_tokens=max_new_tokens,
  682. im_end_id=im_end_id,
  683. decode_one_token=decode_one_token,
  684. temperature=temperature,
  685. top_p=top_p,
  686. repetition_penalty=repetition_penalty,
  687. )
  688. if sample_idx == 0 and seg_idx == 0 and compile:
  689. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  690. if torch.cuda.is_available():
  691. torch.cuda.synchronize()
  692. t = time.perf_counter() - t0
  693. tokens_generated = y.size(1) - prompt_length
  694. tokens_sec = tokens_generated / t
  695. logger.info(
  696. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  697. )
  698. logger.info(
  699. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  700. )
  701. if torch.cuda.is_available():
  702. logger.info(
  703. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  704. )
  705. # Put the generated tokens
  706. # since there is <im_end> and <eos> tokens, we remove last 2 tokens
  707. codes = y[1:, prompt_length:-1].clone()
  708. codes = codes - 1
  709. assert (codes >= 0).all(), f"Negative code found"
  710. decoded = y[:, prompt_length:-1].clone()
  711. # But for global encoding, we should keep the <im_end> token
  712. global_encoded.append(decoded)
  713. assert (codes >= 0).all(), f"Negative code found: {codes}"
  714. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  715. seg_idx += 1
  716. # This indicates the end of the current sample
  717. yield GenerateResponse(action="next")
@dataclass
class WrappedGenerateResponse:
    """Envelope that the ``launch_thread_safe_queue`` worker puts on a response queue.

    ``status="success"`` wraps one ``GenerateResponse`` chunk.
    ``status="error"`` wraps the exception raised while generating.
    """

    # Outcome of this item.
    status: Literal["success", "error"]
    # The generated chunk on success, or the raised exception on error.
    response: Optional[GenerateResponse | Exception] = None
@dataclass
class GenerateRequest:
    """A unit of work for the background generation workers."""

    # Keyword arguments forwarded to ``generate_long`` (or ``generate_agent`` in
    # the agent worker); the worker supplies ``model`` and ``decode_one_token``.
    request: dict
    # Queue on which the worker puts this request's outputs.
    response_queue: queue.Queue
  726. def launch_thread_safe_queue(
  727. checkpoint_path,
  728. device,
  729. precision,
  730. compile: bool = False,
  731. ):
  732. input_queue = queue.Queue()
  733. init_event = threading.Event()
  734. def worker():
  735. model, decode_one_token = load_model(
  736. checkpoint_path, device, precision, compile=compile
  737. )
  738. with torch.device(device):
  739. model.setup_caches(
  740. max_batch_size=1,
  741. max_seq_len=model.config.max_seq_len,
  742. dtype=next(model.parameters()).dtype,
  743. )
  744. init_event.set()
  745. while True:
  746. item: GenerateRequest | None = input_queue.get()
  747. if item is None:
  748. break
  749. kwargs = item.request
  750. response_queue = item.response_queue
  751. try:
  752. for chunk in generate_long(
  753. model=model, decode_one_token=decode_one_token, **kwargs
  754. ):
  755. response_queue.put(
  756. WrappedGenerateResponse(status="success", response=chunk)
  757. )
  758. except Exception as e:
  759. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  760. threading.Thread(target=worker, daemon=True).start()
  761. init_event.wait()
  762. return input_queue
  763. def launch_thread_safe_queue_agent(
  764. checkpoint_path,
  765. device,
  766. precision,
  767. compile: bool = False,
  768. ):
  769. input_queue = queue.Queue()
  770. init_event = threading.Event()
  771. tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
  772. config = BaseModelArgs.from_pretrained(checkpoint_path)
  773. def worker():
  774. model, decode_one_token = load_model(
  775. checkpoint_path, device, precision, compile=compile, is_agent=True
  776. )
  777. with torch.device(device):
  778. model.setup_caches(
  779. max_batch_size=1,
  780. max_seq_len=model.config.max_seq_len,
  781. dtype=next(model.parameters()).dtype,
  782. )
  783. init_event.set()
  784. while True:
  785. item: GenerateRequest | None = input_queue.get()
  786. if item is None:
  787. break
  788. kwargs = item.request
  789. response_queue = item.response_queue
  790. try:
  791. for token in generate_agent(
  792. model=model,
  793. decode_one_token=decode_one_token,
  794. **kwargs,
  795. ):
  796. response_queue.put(token)
  797. response_queue.put("stop")
  798. except Exception as e:
  799. import traceback
  800. logger.exception(f"Error in worker: {traceback.format_exc()}")
  801. response_queue.put("error")
  802. threading.Thread(target=worker, daemon=True).start()
  803. init_event.wait()
  804. return input_queue, tokenizer, config
  805. @click.command()
  806. @click.option(
  807. "--text",
  808. type=str,
  809. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  810. )
  811. @click.option("--prompt-text", type=str, default=None, multiple=True)
  812. @click.option(
  813. "--prompt-tokens",
  814. type=click.Path(path_type=Path, exists=True),
  815. default=None,
  816. multiple=True,
  817. )
  818. @click.option("--num-samples", type=int, default=1)
  819. @click.option("--max-new-tokens", type=int, default=0)
  820. @click.option("--top-p", type=float, default=0.7)
  821. @click.option("--repetition-penalty", type=float, default=1.2)
  822. @click.option("--temperature", type=float, default=0.7)
  823. @click.option(
  824. "--checkpoint-path",
  825. type=click.Path(path_type=Path, exists=True),
  826. default="checkpoints/fish-speech-1.4",
  827. )
  828. @click.option("--device", type=str, default="cuda")
  829. @click.option("--compile/--no-compile", default=False)
  830. @click.option("--seed", type=int, default=42)
  831. @click.option("--half/--no-half", default=False)
  832. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  833. @click.option("--chunk-length", type=int, default=100)
  834. def main(
  835. text: str,
  836. prompt_text: Optional[list[str]],
  837. prompt_tokens: Optional[list[Path]],
  838. num_samples: int,
  839. max_new_tokens: int,
  840. top_p: int,
  841. repetition_penalty: float,
  842. temperature: float,
  843. checkpoint_path: Path,
  844. device: str,
  845. compile: bool,
  846. seed: int,
  847. half: bool,
  848. iterative_prompt: bool,
  849. chunk_length: int,
  850. ) -> None:
  851. precision = torch.half if half else torch.bfloat16
  852. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  853. raise ValueError(
  854. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  855. )
  856. logger.info("Loading model ...")
  857. t0 = time.time()
  858. model, decode_one_token = load_model(
  859. checkpoint_path, device, precision, compile=compile
  860. )
  861. with torch.device(device):
  862. model.setup_caches(
  863. max_batch_size=1,
  864. max_seq_len=model.config.max_seq_len,
  865. dtype=next(model.parameters()).dtype,
  866. )
  867. if torch.cuda.is_available():
  868. torch.cuda.synchronize()
  869. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  870. if prompt_tokens is not None:
  871. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  872. torch.manual_seed(seed)
  873. if torch.cuda.is_available():
  874. torch.cuda.manual_seed(seed)
  875. generator = generate_long(
  876. model=model,
  877. device=device,
  878. decode_one_token=decode_one_token,
  879. text=text,
  880. num_samples=num_samples,
  881. max_new_tokens=max_new_tokens,
  882. top_p=top_p,
  883. repetition_penalty=repetition_penalty,
  884. temperature=temperature,
  885. compile=compile,
  886. iterative_prompt=iterative_prompt,
  887. chunk_length=chunk_length,
  888. prompt_text=prompt_text,
  889. prompt_tokens=prompt_tokens,
  890. )
  891. idx = 0
  892. codes = []
  893. for response in generator:
  894. if response.action == "sample":
  895. codes.append(response.codes)
  896. logger.info(f"Sampled text: {response.text}")
  897. elif response.action == "next":
  898. if codes:
  899. np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
  900. logger.info(f"Saved codes to codes_{idx}.npy")
  901. logger.info(f"Next sample")
  902. codes = []
  903. idx += 1
  904. else:
  905. logger.error(f"Error: {response}")
  906. if __name__ == "__main__":
  907. main()