# inference.py — text-to-semantic token generation utilities for fish-speech.
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import numpy as np
  11. import torch
  12. import torch._dynamo.config
  13. import torch._inductor.config
  14. from loguru import logger
  15. from tqdm import tqdm
  16. from transformers import AutoTokenizer
  17. from fish_speech.conversation import (
  18. CODEBOOK_PAD_TOKEN_ID,
  19. Conversation,
  20. Message,
  21. TextPart,
  22. VQPart,
  23. )
  24. from fish_speech.models.text2semantic.llama import BaseModelArgs
  25. from fish_speech.text import clean_text, split_text
  26. from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
  27. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  28. torch._inductor.config.coordinate_descent_tuning = True
  29. torch._inductor.config.triton.unique_kernel_names = True
  30. if hasattr(torch._inductor.config, "fx_graph_cache"):
  31. # Experimental feature to reduce compilation times, will be on by default in future
  32. torch._inductor.config.fx_graph_cache = True
  33. from torch.nn.attention import SDPBackend, sdpa_kernel
  34. from fish_speech.models.text2semantic.llama import (
  35. BaseTransformer,
  36. DualARTransformer,
  37. NaiveTransformer,
  38. )
  39. def multinomial_sample_one_no_sync(
  40. probs_sort,
  41. ): # Does multinomial sampling without a cuda synchronization
  42. q = torch.empty_like(probs_sort).exponential_(1)
  43. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  44. def logits_to_probs(
  45. logits,
  46. previous_tokens: Optional[torch.Tensor] = None,
  47. temperature: torch.Tensor = 1.0,
  48. top_p: torch.Tensor = 1.0,
  49. repetition_penalty: torch.Tensor = 1.0,
  50. ) -> torch.Tensor:
  51. # Apply repetition penalty
  52. if previous_tokens is not None:
  53. previous_tokens = previous_tokens.long()
  54. score = torch.gather(logits, dim=0, index=previous_tokens)
  55. score = torch.where(
  56. score < 0, score * repetition_penalty, score / repetition_penalty
  57. )
  58. logits.scatter_(dim=0, index=previous_tokens, src=score)
  59. # Apply top-p sampling
  60. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  61. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  62. sorted_indices_to_remove = cum_probs > top_p
  63. sorted_indices_to_remove[0] = False # keep at least one option
  64. indices_to_remove = sorted_indices_to_remove.scatter(
  65. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  66. )
  67. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  68. logits = logits / max(temperature, 1e-5)
  69. probs = torch.nn.functional.softmax(logits, dim=-1)
  70. return probs
  71. def multinomial_sample_one_no_sync_agent(
  72. probs_sort,
  73. ): # Does multinomial sampling without a cuda synchronization
  74. q = torch.empty_like(probs_sort).exponential_(1)
  75. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  76. def logits_to_probs_agent(
  77. logits,
  78. previous_tokens: Optional[torch.Tensor] = None,
  79. temperature: torch.Tensor = 1.0,
  80. top_p: torch.Tensor = 1.0,
  81. repetition_penalty: torch.Tensor = 1.0,
  82. ) -> torch.Tensor:
  83. # Apply repetition penalty
  84. if previous_tokens is not None:
  85. previous_tokens = previous_tokens.long()
  86. score = torch.gather(logits, dim=-1, index=previous_tokens)
  87. score = torch.where(
  88. score < 0, score * repetition_penalty, score / repetition_penalty
  89. )
  90. logits.scatter_(dim=-1, index=previous_tokens, src=score)
  91. # Apply top-p sampling
  92. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  93. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  94. sorted_indices_to_remove = cum_probs > top_p
  95. sorted_indices_to_remove[..., 0] = False # keep at least one option
  96. indices_to_remove = sorted_indices_to_remove.scatter(
  97. dim=-1, index=sorted_indices, src=sorted_indices_to_remove
  98. )
  99. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  100. logits = logits / max(temperature, 1e-5)
  101. probs = torch.nn.functional.softmax(logits, dim=-1)
  102. return probs
  103. def sample(
  104. logits,
  105. previous_tokens: Optional[torch.Tensor] = None,
  106. **sampling_kwargs,
  107. ) -> Tuple[torch.Tensor, torch.Tensor]:
  108. probs = logits_to_probs(
  109. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  110. )
  111. idx_next = multinomial_sample_one_no_sync(probs)
  112. return idx_next, probs
  113. def sample_agent(
  114. logits,
  115. previous_tokens: Optional[torch.Tensor] = None,
  116. **sampling_kwargs,
  117. ) -> Tuple[torch.Tensor, torch.Tensor]:
  118. probs = logits_to_probs_agent(
  119. logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
  120. )
  121. idx_next = multinomial_sample_one_no_sync_agent(probs)
  122. return idx_next, probs
  123. def decode_one_token_ar_agent(
  124. model: DualARTransformer,
  125. x: torch.Tensor,
  126. input_pos: torch.Tensor,
  127. semantic_ids: list,
  128. previous_tokens: torch.Tensor = None,
  129. **sampling_kwargs,
  130. ) -> torch.Tensor:
  131. # print(x, input_pos)
  132. x = model.forward_generate(x, input_pos)
  133. logits = x.logits # [:, -1:]
  134. hidden_states = x.hidden_states # [:, -1:]
  135. sampling_kwargs_main = sampling_kwargs.copy()
  136. sampling_kwargs_main["temperature"] = 0.1
  137. sampling_kwargs_main["top_p"] = 0.1
  138. sampling_kwargs_main["repetition_penalty"] = 1.0
  139. codebooks = [
  140. sample_agent(
  141. logits,
  142. previous_tokens=None, # Disable repetition penalty for the token codebook
  143. **sampling_kwargs_main,
  144. )[0]
  145. ]
  146. # Cleanup the cache
  147. for layer in model.fast_layers:
  148. layer.attention.kv_cache.k_cache.fill_(0)
  149. layer.attention.kv_cache.v_cache.fill_(0)
  150. for codebook_idx in range(model.config.num_codebooks):
  151. input_pos = torch.tensor(
  152. [codebook_idx], device=hidden_states.device, dtype=torch.long
  153. )
  154. logits = model.forward_generate_fast(hidden_states, input_pos)
  155. a = sample_agent(
  156. logits,
  157. previous_tokens=(
  158. previous_tokens[:, codebook_idx + 1]
  159. if previous_tokens is not None
  160. else None
  161. ),
  162. **sampling_kwargs,
  163. )[0]
  164. hidden_states = model.fast_embeddings(a)
  165. codebooks.append(a)
  166. codebooks = torch.stack(codebooks, dim=1)
  167. semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
  168. codebooks[:, 1:, :] = torch.masked_fill(
  169. codebooks[:, 1:, :],
  170. ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
  171. CODEBOOK_PAD_TOKEN_ID,
  172. )
  173. return codebooks
  174. def decode_one_token_naive_agent(
  175. model: NaiveTransformer,
  176. x: torch.Tensor,
  177. input_pos: torch.Tensor,
  178. semantic_ids: list,
  179. previous_tokens: torch.Tensor = None,
  180. **sampling_kwargs,
  181. ) -> torch.Tensor:
  182. x = model.forward_generate(x, input_pos)
  183. codebooks = [
  184. sample(
  185. x.token_logits,
  186. previous_tokens=None, # Disable repetition penalty for the token codebook
  187. **sampling_kwargs,
  188. )[0]
  189. ]
  190. for i in range(model.config.num_codebooks):
  191. codebooks.append(
  192. sample_agent(
  193. x.codebook_logits[:, :, i],
  194. previous_tokens=(
  195. previous_tokens[:, i + 1] if previous_tokens is not None else None
  196. ),
  197. **sampling_kwargs,
  198. )[0]
  199. )
  200. codebooks = torch.stack(codebooks, dim=1)
  201. semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
  202. codebooks[:, 1:, :] = torch.masked_fill(
  203. codebooks[:, 1:, :],
  204. ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
  205. CODEBOOK_PAD_TOKEN_ID,
  206. )
  207. return codebooks
  208. def decode_one_token_ar(
  209. model: DualARTransformer,
  210. x: torch.Tensor,
  211. input_pos: torch.Tensor,
  212. semantic_ids: list,
  213. previous_tokens: torch.Tensor = None,
  214. **sampling_kwargs,
  215. ) -> torch.Tensor:
  216. x = model.forward_generate(x, input_pos)
  217. sampling_kwargs_main = sampling_kwargs.copy()
  218. # sampling_kwargs_main["temperature"] = 0.1
  219. # sampling_kwargs_main["top_p"] = 0.1
  220. # sampling_kwargs_main["repetition_penalty"] = 1.0
  221. codebooks = [
  222. sample(
  223. x.logits,
  224. previous_tokens=(
  225. previous_tokens[0] if previous_tokens is not None else None
  226. ), # Disable repetition penalty for the token codebook
  227. **sampling_kwargs_main,
  228. )[0]
  229. ]
  230. hidden_states = x.hidden_states
  231. # Cleanup the cache
  232. for layer in model.fast_layers:
  233. layer.attention.kv_cache.k_cache.fill_(0)
  234. layer.attention.kv_cache.v_cache.fill_(0)
  235. input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
  236. model.forward_generate_fast(hidden_states, input_pos)
  237. a = codebooks[0] - model.tokenizer.semantic_begin_id
  238. a[a < 0] = 0
  239. hidden_states = model.fast_embeddings(a)
  240. codebooks.append(a)
  241. for codebook_idx in range(1, model.config.num_codebooks):
  242. input_pos = torch.tensor(
  243. [codebook_idx], device=hidden_states.device, dtype=torch.long
  244. )
  245. logits = model.forward_generate_fast(hidden_states, input_pos)
  246. a = sample(
  247. logits,
  248. previous_tokens=(
  249. previous_tokens[codebook_idx + 1]
  250. if previous_tokens is not None
  251. else None
  252. ),
  253. **sampling_kwargs,
  254. )[0]
  255. hidden_states = model.fast_embeddings(a)
  256. codebooks.append(a)
  257. codebooks = torch.stack(codebooks, dim=0)
  258. # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
  259. # codebooks[1:, :] = torch.masked_fill(
  260. # codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
  261. # )
  262. # print(codebooks)
  263. return codebooks
  264. def decode_one_token_naive(
  265. model: NaiveTransformer,
  266. x: torch.Tensor,
  267. input_pos: torch.Tensor,
  268. previous_tokens: torch.Tensor = None,
  269. **sampling_kwargs,
  270. ) -> torch.Tensor:
  271. x = model.forward_generate(x, input_pos)
  272. sampling_kwargs_main = sampling_kwargs.copy()
  273. sampling_kwargs_main["temperature"] = 0.1
  274. sampling_kwargs_main["top_p"] = 0.1
  275. sampling_kwargs_main["repetition_penalty"] = 1.0
  276. codebooks = [
  277. sample(
  278. x.logits,
  279. previous_tokens=None, # Disable repetition penalty for the token codebook
  280. **sampling_kwargs_main,
  281. )[0]
  282. ]
  283. for i in range(model.config.num_codebooks):
  284. codebooks.append(
  285. sample(
  286. x.codebook_logits[:, :, i],
  287. previous_tokens=(
  288. previous_tokens[i + 1] if previous_tokens is not None else None
  289. ),
  290. **sampling_kwargs,
  291. )[0]
  292. )
  293. return torch.stack(codebooks, dim=0)
  294. def decode_n_tokens(
  295. model: NaiveTransformer,
  296. cur_token: torch.Tensor,
  297. input_pos: torch.Tensor,
  298. num_new_tokens: int,
  299. semantic_ids: list,
  300. decode_one_token=decode_one_token_naive,
  301. **sampling_kwargs,
  302. ):
  303. previous_tokens = torch.zeros(
  304. (model.config.num_codebooks + 1, model.config.max_seq_len),
  305. dtype=torch.int,
  306. device=cur_token.device,
  307. )
  308. for i in tqdm(range(num_new_tokens)):
  309. # We need to get windowed repeat penalty
  310. win_size = 16
  311. if i < win_size:
  312. window = previous_tokens[:, :win_size]
  313. else:
  314. window = previous_tokens[:, i - win_size : i]
  315. with (
  316. torch.backends.cuda.sdp_kernel(
  317. enable_flash=False, enable_mem_efficient=False, enable_math=True
  318. )
  319. if torch.cuda.is_available()
  320. else nullcontext()
  321. ): # Actually better for Inductor to codegen attention here
  322. next_token = decode_one_token(
  323. model=model,
  324. x=cur_token,
  325. input_pos=input_pos,
  326. previous_tokens=window,
  327. semantic_ids=semantic_ids,
  328. **sampling_kwargs,
  329. )
  330. input_pos += 1
  331. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  332. previous_tokens[:, i : i + 1] = next_token.view(
  333. model.config.num_codebooks + 1, -1
  334. )
  335. if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
  336. break
  337. return previous_tokens[:, : i + 1]
  338. @torch.no_grad()
  339. @torch.inference_mode()
  340. def generate(
  341. *,
  342. model: NaiveTransformer,
  343. prompt: torch.Tensor,
  344. max_new_tokens: int,
  345. decode_one_token=decode_one_token_naive,
  346. **sampling_kwargs,
  347. ) -> torch.Tensor:
  348. """
  349. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  350. """
  351. # create an empty tensor of the expected final shape and fill in the current tokens
  352. T = prompt.size(1)
  353. # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
  354. semantic_ids = [
  355. model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
  356. ]
  357. if max_new_tokens:
  358. if T + max_new_tokens > model.config.max_seq_len:
  359. max_new_tokens = model.config.max_seq_len - T
  360. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  361. T_new = T + max_new_tokens
  362. else:
  363. T_new = model.config.max_seq_len
  364. max_new_tokens = T_new - T
  365. device, dtype = prompt.device, prompt.dtype
  366. codebook_dim = 1 + model.config.num_codebooks
  367. # create an empty tensor of the expected final shape and fill in the current tokens
  368. empty = torch.empty(
  369. (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
  370. )
  371. empty[:, :T] = prompt
  372. seq = empty
  373. input_pos = torch.arange(0, T, device=device)
  374. # Use non-accelerated version for now, to avoid compilation overhead
  375. prefill_decode = (
  376. decode_one_token_naive
  377. if isinstance(model, NaiveTransformer)
  378. else decode_one_token_ar
  379. )
  380. next_token = prefill_decode(
  381. model,
  382. prompt.view(1, codebook_dim, -1),
  383. input_pos,
  384. semantic_ids=semantic_ids,
  385. **sampling_kwargs,
  386. )
  387. seq[:, T : T + 1] = next_token
  388. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  389. x = decode_n_tokens(
  390. model,
  391. next_token.view(1, codebook_dim, -1),
  392. input_pos,
  393. max_new_tokens - 1,
  394. decode_one_token=decode_one_token,
  395. semantic_ids=semantic_ids,
  396. **sampling_kwargs,
  397. )
  398. # x = torch.cat(generated_tokens, dim=1)
  399. seq = seq[:, : T + 1 + x.size(1)]
  400. seq[:, T + 1 :] = x
  401. return seq
  402. def decode_n_tokens_agent(
  403. model: NaiveTransformer,
  404. cur_token: torch.Tensor,
  405. input_pos: torch.Tensor,
  406. num_new_tokens: int,
  407. semantic_ids: list,
  408. im_end_id: int = 4,
  409. decode_one_token=decode_one_token_naive_agent,
  410. early_stop_threshold: float = 0.6,
  411. **sampling_kwargs,
  412. ):
  413. batch_size = cur_token.size(0)
  414. previous_tokens = torch.zeros(
  415. (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
  416. dtype=torch.int,
  417. device=cur_token.device,
  418. )
  419. finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
  420. finished = finished | (cur_token[:, 0, -1] == im_end_id)
  421. start_time = time.time()
  422. for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
  423. # We need to get windowed repeat penalty
  424. win_size = 16
  425. if i < win_size:
  426. window = previous_tokens[:, :, :win_size]
  427. else:
  428. window = previous_tokens[:, :, i - win_size : i]
  429. with sdpa_kernel(
  430. SDPBackend.MATH
  431. ): # Actually better for Inductor to codegen attention here
  432. next_token = decode_one_token(
  433. model=model,
  434. x=cur_token,
  435. input_pos=input_pos,
  436. previous_tokens=window,
  437. semantic_ids=semantic_ids,
  438. **sampling_kwargs,
  439. )
  440. input_pos += 1
  441. cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
  442. previous_tokens[:, :, i : i + 1] = next_token.view(
  443. batch_size, model.config.num_codebooks + 1, -1
  444. )
  445. yield cur_token.cpu()
  446. finished = finished | (cur_token[:, 0, -1] == im_end_id)
  447. if finished.all() or (
  448. 0 < early_stop_threshold < 1
  449. and finished.sum() >= round(batch_size * early_stop_threshold)
  450. ):
  451. break
  452. total_time = time.time() - start_time
  453. generated_tokens = i + 1
  454. tokens_per_second = (generated_tokens / total_time) * batch_size
  455. logger.info(
  456. f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
  457. )
  458. @torch.no_grad()
  459. @torch.inference_mode()
  460. def generate_agent(
  461. *,
  462. model: BaseTransformer,
  463. prompt: torch.Tensor,
  464. max_new_tokens: int,
  465. semantic_ids: list,
  466. im_end_id: int = 4,
  467. decode_one_token=decode_one_token_naive_agent,
  468. num_samples: int = 1,
  469. early_stop_threshold: float = 0.6,
  470. **sampling_kwargs,
  471. ):
  472. """
  473. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  474. """
  475. # create an empty tensor of the expected final shape and fill in the current tokens
  476. T = prompt.size(1)
  477. prompt = prompt[None].repeat(num_samples, 1, 1)
  478. if T >= model.config.max_seq_len:
  479. raise ValueError(
  480. f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
  481. )
  482. if max_new_tokens:
  483. if T + max_new_tokens > model.config.max_seq_len:
  484. max_new_tokens = model.config.max_seq_len - T
  485. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  486. T_new = T + max_new_tokens
  487. else:
  488. T_new = model.config.max_seq_len
  489. max_new_tokens = T_new - T
  490. device, dtype = prompt.device, prompt.dtype
  491. codebook_dim = 1 + model.config.num_codebooks
  492. input_pos = torch.arange(0, T, device=device)
  493. # Use non-accelerated version for now, to avoid compilation overhead
  494. prefill_decode = (
  495. decode_one_token_naive_agent
  496. if isinstance(model, NaiveTransformer)
  497. else decode_one_token_ar_agent
  498. )
  499. next_token = prefill_decode(
  500. model,
  501. prompt,
  502. input_pos,
  503. semantic_ids=semantic_ids,
  504. **sampling_kwargs,
  505. ).view(num_samples, codebook_dim, -1)
  506. yield next_token.cpu()
  507. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  508. yield from decode_n_tokens_agent(
  509. model,
  510. next_token,
  511. input_pos,
  512. max_new_tokens - 1,
  513. im_end_id=im_end_id,
  514. semantic_ids=semantic_ids,
  515. decode_one_token=decode_one_token,
  516. early_stop_threshold=early_stop_threshold,
  517. **sampling_kwargs,
  518. )
  519. def encode_tokens(
  520. tokenizer,
  521. string,
  522. device="cuda",
  523. prompt_tokens=None,
  524. num_codebooks=4,
  525. ):
  526. string = clean_text(string)
  527. messages = []
  528. messages.append(
  529. Message(
  530. role="user",
  531. parts=[TextPart(text=string)],
  532. cal_loss=False,
  533. )
  534. )
  535. if prompt_tokens is not None:
  536. if prompt_tokens.ndim == 3:
  537. assert (
  538. prompt_tokens.shape[0] == 1
  539. ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
  540. prompt_tokens = prompt_tokens[0]
  541. assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"
  542. if prompt_tokens.shape[0] > num_codebooks:
  543. logger.warning(
  544. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  545. )
  546. prompt_tokens = prompt_tokens[:num_codebooks]
  547. vq_part = VQPart(codes=prompt_tokens.to(device))
  548. messages.append(
  549. Message(
  550. role="assistant",
  551. parts=[TextPart(text="<|voice|>"), vq_part],
  552. cal_loss=False,
  553. )
  554. )
  555. else:
  556. messages.append(
  557. Message(
  558. role="assistant",
  559. parts=[TextPart(text="<|voice|>")],
  560. cal_loss=False,
  561. add_im_end=False,
  562. )
  563. )
  564. conversation = Conversation(messages=messages)
  565. # conversation.visualize(tokenizer)
  566. encoded = conversation.encode_for_inference(
  567. tokenizer=tokenizer,
  568. num_codebooks=num_codebooks,
  569. )
  570. return encoded.to(device)
  571. def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
  572. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  573. checkpoint_path, load_weights=True, is_agent=is_agent
  574. )
  575. model = model.to(device=device, dtype=precision)
  576. logger.info(f"Restored model from checkpoint")
  577. if isinstance(model, DualARTransformer):
  578. decode_one_token = (
  579. decode_one_token_ar_agent if is_agent else decode_one_token_ar
  580. )
  581. logger.info("Using DualARTransformer")
  582. else:
  583. decode_one_token = (
  584. decode_one_token_naive_agent if is_agent else decode_one_token_naive
  585. )
  586. logger.info("Using NaiveTransformer")
  587. if compile:
  588. logger.info("Compiling function...")
  589. decode_one_token = torch.compile(
  590. decode_one_token,
  591. fullgraph=True,
  592. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  593. mode="reduce-overhead" if torch.cuda.is_available() else None,
  594. )
  595. return model.eval(), decode_one_token
  596. @dataclass
  597. class GenerateResponse:
  598. action: Literal["sample", "next"]
  599. codes: Optional[torch.Tensor] = None
  600. text: Optional[str] = None
  601. def generate_long(
  602. *,
  603. model,
  604. device: str | torch.device,
  605. decode_one_token: callable,
  606. text: str,
  607. num_samples: int = 1,
  608. max_new_tokens: int = 0,
  609. top_p: int = 0.7,
  610. repetition_penalty: float = 1.5,
  611. temperature: float = 0.7,
  612. compile: bool = False,
  613. iterative_prompt: bool = True,
  614. max_length: int = 2048,
  615. chunk_length: int = 150,
  616. prompt_text: Optional[str | list[str]] = None,
  617. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  618. ):
  619. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  620. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  621. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  622. use_prompt = prompt_text is not None and prompt_tokens is not None
  623. if use_prompt and isinstance(prompt_text, str):
  624. prompt_text = [prompt_text]
  625. prompt_tokens = [prompt_tokens]
  626. assert use_prompt is False or len(prompt_text) == len(
  627. prompt_tokens
  628. ), "Prompt text and tokens must have the same length"
  629. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  630. tokenizer = model.tokenizer
  631. im_end_id = tokenizer.get_token_id("<|im_end|>")
  632. encoded = []
  633. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  634. encoded_prompts = [
  635. Conversation(
  636. messages=[
  637. Message(
  638. role="system",
  639. parts=[TextPart(text="Speak out the provided text.")],
  640. cal_loss=False,
  641. )
  642. ]
  643. )
  644. .encode_for_inference(
  645. tokenizer=tokenizer,
  646. num_codebooks=model.config.num_codebooks,
  647. )
  648. .to(device)
  649. ]
  650. if use_prompt:
  651. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  652. encoded_prompts.append(
  653. encode_tokens(
  654. tokenizer,
  655. string=t,
  656. device=device,
  657. prompt_tokens=c,
  658. num_codebooks=model.config.num_codebooks,
  659. )
  660. )
  661. for idx, text in enumerate(texts):
  662. encoded.append(
  663. encode_tokens(
  664. tokenizer,
  665. string=text,
  666. device=device,
  667. num_codebooks=model.config.num_codebooks,
  668. )
  669. )
  670. logger.info(f"Encoded text: {text}")
  671. # Move temperature, top_p, repetition_penalty to device
  672. # This is important so that changing params doesn't trigger recompile
  673. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  674. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  675. repetition_penalty = torch.tensor(
  676. repetition_penalty, device=device, dtype=torch.float
  677. )
  678. for sample_idx in range(num_samples):
  679. if torch.cuda.is_available():
  680. torch.cuda.synchronize()
  681. global_encoded = []
  682. seg_idx = 0
  683. while seg_idx < len(encoded):
  684. logger.info(
  685. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  686. )
  687. seg = encoded[seg_idx]
  688. global_encoded.append(seg)
  689. lengths = reversed([seg.size(1) for seg in global_encoded])
  690. # Pick last 2000 tokens
  691. count = 0
  692. for i, length in enumerate(lengths):
  693. count += length
  694. if count + length > max_length - 1024 - sum(
  695. t.shape[1] for t in encoded_prompts
  696. ):
  697. break
  698. if i != 0 and i % 2 == 0:
  699. i -= 1
  700. # Rotate the list, always make sure first segment is included to avoid drift
  701. if i < len(global_encoded) - 2:
  702. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  703. else:
  704. partial_encoded = global_encoded
  705. if use_prompt:
  706. partial_encoded = encoded_prompts + partial_encoded
  707. cat_encoded = torch.cat(partial_encoded, dim=1)
  708. prompt_length = cat_encoded.size(1)
  709. t0 = time.perf_counter()
  710. y = generate(
  711. model=model,
  712. prompt=cat_encoded,
  713. max_new_tokens=max_new_tokens,
  714. decode_one_token=decode_one_token,
  715. temperature=temperature,
  716. top_p=top_p,
  717. repetition_penalty=repetition_penalty,
  718. )
  719. if sample_idx == 0 and seg_idx == 0 and compile:
  720. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  721. if torch.cuda.is_available():
  722. torch.cuda.synchronize()
  723. t = time.perf_counter() - t0
  724. tokens_generated = y.size(1) - prompt_length
  725. tokens_sec = tokens_generated / t
  726. logger.info(
  727. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  728. )
  729. logger.info(
  730. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  731. )
  732. if torch.cuda.is_available():
  733. logger.info(
  734. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  735. )
  736. # Put the generated tokens
  737. # since there is <im_end>, we remove last token
  738. codes = y[1:, prompt_length + 1 :].clone()
  739. assert (codes >= 0).all(), f"Negative code found"
  740. decoded = y[:, prompt_length:].clone()
  741. # But for global encoding, we should keep the <im_end> token
  742. global_encoded.append(decoded)
  743. assert (codes >= 0).all(), f"Negative code found: {codes}"
  744. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  745. seg_idx += 1
  746. # This indicates the end of the current sample
  747. yield GenerateResponse(action="next")
  748. @dataclass
  749. class WrappedGenerateResponse:
  750. status: Literal["success", "error"]
  751. response: Optional[GenerateResponse | Exception] = None
  752. @dataclass
  753. class GenerateRequest:
  754. request: dict
  755. response_queue: queue.Queue
  756. def launch_thread_safe_queue(
  757. checkpoint_path,
  758. device,
  759. precision,
  760. compile: bool = False,
  761. ):
  762. input_queue = queue.Queue()
  763. init_event = threading.Event()
  764. def worker():
  765. model, decode_one_token = load_model(
  766. checkpoint_path, device, precision, compile=compile
  767. )
  768. with torch.device(device):
  769. model.setup_caches(
  770. max_batch_size=1,
  771. max_seq_len=model.config.max_seq_len,
  772. dtype=next(model.parameters()).dtype,
  773. )
  774. init_event.set()
  775. while True:
  776. item: GenerateRequest | None = input_queue.get()
  777. if item is None:
  778. break
  779. kwargs = item.request
  780. response_queue = item.response_queue
  781. try:
  782. for chunk in generate_long(
  783. model=model, decode_one_token=decode_one_token, **kwargs
  784. ):
  785. response_queue.put(
  786. WrappedGenerateResponse(status="success", response=chunk)
  787. )
  788. except Exception as e:
  789. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  790. threading.Thread(target=worker, daemon=True).start()
  791. init_event.wait()
  792. return input_queue
  793. def launch_thread_safe_queue_agent(
  794. checkpoint_path,
  795. device,
  796. precision,
  797. compile: bool = False,
  798. ):
  799. input_queue = queue.Queue()
  800. init_event = threading.Event()
  801. tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
  802. config = BaseModelArgs.from_pretrained(checkpoint_path)
  803. def worker():
  804. model, decode_one_token = load_model(
  805. checkpoint_path, device, precision, compile=compile, is_agent=True
  806. )
  807. with torch.device(device):
  808. model.setup_caches(
  809. max_batch_size=1,
  810. max_seq_len=model.config.max_seq_len,
  811. dtype=next(model.parameters()).dtype,
  812. )
  813. init_event.set()
  814. while True:
  815. item: GenerateRequest | None = input_queue.get()
  816. if item is None:
  817. break
  818. kwargs = item.request
  819. response_queue = item.response_queue
  820. try:
  821. for token in generate_agent(
  822. model=model,
  823. decode_one_token=decode_one_token,
  824. **kwargs,
  825. ):
  826. response_queue.put(token)
  827. response_queue.put("stop")
  828. except Exception as e:
  829. import traceback
  830. logger.exception(f"Error in worker: {traceback.format_exc()}")
  831. response_queue.put("error")
  832. threading.Thread(target=worker, daemon=True).start()
  833. init_event.wait()
  834. return input_queue, tokenizer, config
  835. @click.command()
  836. @click.option(
  837. "--text",
  838. type=str,
  839. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  840. )
  841. @click.option("--prompt-text", type=str, default=None, multiple=True)
  842. @click.option(
  843. "--prompt-tokens",
  844. type=click.Path(path_type=Path, exists=True),
  845. default=None,
  846. multiple=True,
  847. )
  848. @click.option("--num-samples", type=int, default=1)
  849. @click.option("--max-new-tokens", type=int, default=0)
  850. @click.option("--top-p", type=float, default=0.7)
  851. @click.option("--repetition-penalty", type=float, default=1.2)
  852. @click.option("--temperature", type=float, default=0.7)
  853. @click.option(
  854. "--checkpoint-path",
  855. type=click.Path(path_type=Path, exists=True),
  856. default="checkpoints/fish-speech-1.5",
  857. )
  858. @click.option("--device", type=str, default="cuda")
  859. @click.option("--compile/--no-compile", default=False)
  860. @click.option("--seed", type=int, default=42)
  861. @click.option("--half/--no-half", default=False)
  862. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  863. @click.option("--chunk-length", type=int, default=100)
  864. @click.option("--output-dir", type=Path, default="temp")
  865. def main(
  866. text: str,
  867. prompt_text: Optional[list[str]],
  868. prompt_tokens: Optional[list[Path]],
  869. num_samples: int,
  870. max_new_tokens: int,
  871. top_p: int,
  872. repetition_penalty: float,
  873. temperature: float,
  874. checkpoint_path: Path,
  875. device: str,
  876. compile: bool,
  877. seed: int,
  878. half: bool,
  879. iterative_prompt: bool,
  880. chunk_length: int,
  881. output_dir: Path,
  882. ) -> None:
  883. os.makedirs(output_dir, exist_ok=True)
  884. precision = torch.half if half else torch.bfloat16
  885. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  886. raise ValueError(
  887. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  888. )
  889. logger.info("Loading model ...")
  890. t0 = time.time()
  891. model, decode_one_token = load_model(
  892. checkpoint_path, device, precision, compile=compile
  893. )
  894. with torch.device(device):
  895. model.setup_caches(
  896. max_batch_size=1,
  897. max_seq_len=model.config.max_seq_len,
  898. dtype=next(model.parameters()).dtype,
  899. )
  900. if torch.cuda.is_available():
  901. torch.cuda.synchronize()
  902. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  903. if prompt_tokens is not None:
  904. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  905. torch.manual_seed(seed)
  906. if torch.cuda.is_available():
  907. torch.cuda.manual_seed(seed)
  908. generator = generate_long(
  909. model=model,
  910. device=device,
  911. decode_one_token=decode_one_token,
  912. text=text,
  913. num_samples=num_samples,
  914. max_new_tokens=max_new_tokens,
  915. top_p=top_p,
  916. repetition_penalty=repetition_penalty,
  917. temperature=temperature,
  918. compile=compile,
  919. iterative_prompt=iterative_prompt,
  920. chunk_length=chunk_length,
  921. prompt_text=prompt_text,
  922. prompt_tokens=prompt_tokens,
  923. )
  924. idx = 0
  925. codes = []
  926. for response in generator:
  927. if response.action == "sample":
  928. codes.append(response.codes)
  929. logger.info(f"Sampled text: {response.text}")
  930. elif response.action == "next":
  931. if codes:
  932. codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
  933. np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
  934. logger.info(f"Saved codes to {codes_npy_path}")
  935. logger.info(f"Next sample")
  936. codes = []
  937. idx += 1
  938. else:
  939. logger.error(f"Error: {response}")
# Run the click CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()