inference.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import numpy as np
  11. import torch
  12. import torch._dynamo.config
  13. import torch._inductor.config
  14. from loguru import logger
  15. from torch.nn.attention import SDPBackend, sdpa_kernel
  16. from tqdm import tqdm
  17. from transformers import AutoTokenizer
  18. from fish_speech.conversation import (
  19. CODEBOOK_PAD_TOKEN_ID,
  20. Conversation,
  21. Message,
  22. TextPart,
  23. VQPart,
  24. )
  25. from fish_speech.models.text2semantic.llama import (
  26. BaseModelArgs,
  27. BaseTransformer,
  28. DualARTransformer,
  29. NaiveTransformer,
  30. )
  31. from fish_speech.text import clean_text, split_text
  32. from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
# Disable HF tokenizers' thread parallelism (avoids fork warnings/deadlocks).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Inductor tuning knobs used when decode functions are torch.compile'd.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
  39. def multinomial_sample_one_no_sync(
  40. probs_sort,
  41. ): # Does multinomial sampling without a cuda synchronization
  42. q = torch.empty_like(probs_sort).exponential_(1)
  43. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  44. def logits_to_probs(
  45. logits,
  46. previous_tokens: Optional[torch.Tensor] = None,
  47. temperature: torch.Tensor = 1.0,
  48. top_p: torch.Tensor = 1.0,
  49. repetition_penalty: torch.Tensor = 1.0,
  50. ) -> torch.Tensor:
  51. # Apply repetition penalty
  52. if previous_tokens is not None:
  53. previous_tokens = previous_tokens.long()
  54. score = torch.gather(logits, dim=0, index=previous_tokens)
  55. score = torch.where(
  56. score < 0, score * repetition_penalty, score / repetition_penalty
  57. )
  58. logits.scatter_(dim=0, index=previous_tokens, src=score)
  59. # Apply top-p sampling
  60. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  61. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  62. sorted_indices_to_remove = cum_probs > top_p
  63. sorted_indices_to_remove[0] = False # keep at least one option
  64. indices_to_remove = sorted_indices_to_remove.scatter(
  65. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  66. )
  67. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  68. logits = logits / max(temperature, 1e-5)
  69. probs = torch.nn.functional.softmax(logits, dim=-1)
  70. return probs
  71. def multinomial_sample_one_no_sync_agent(
  72. probs_sort,
  73. ): # Does multinomial sampling without a cuda synchronization
  74. q = torch.empty_like(probs_sort).exponential_(1)
  75. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  76. def logits_to_probs_agent(
  77. logits,
  78. previous_tokens: Optional[torch.Tensor] = None,
  79. temperature: torch.Tensor = 1.0,
  80. top_p: torch.Tensor = 1.0,
  81. repetition_penalty: torch.Tensor = 1.0,
  82. ) -> torch.Tensor:
  83. # Apply repetition penalty
  84. if previous_tokens is not None:
  85. previous_tokens = previous_tokens.long()
  86. score = torch.gather(logits, dim=-1, index=previous_tokens)
  87. score = torch.where(
  88. score < 0, score * repetition_penalty, score / repetition_penalty
  89. )
  90. logits.scatter_(dim=-1, index=previous_tokens, src=score)
  91. # Apply top-p sampling
  92. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  93. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  94. sorted_indices_to_remove = cum_probs > top_p
  95. sorted_indices_to_remove[..., 0] = False # keep at least one option
  96. indices_to_remove = sorted_indices_to_remove.scatter(
  97. dim=-1, index=sorted_indices, src=sorted_indices_to_remove
  98. )
  99. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  100. logits = logits / max(temperature, 1e-5)
  101. probs = torch.nn.functional.softmax(logits, dim=-1)
  102. return probs
  103. def sample(
  104. logits,
  105. previous_tokens: Optional[torch.Tensor] = None,
  106. **sampling_kwargs,
  107. ) -> Tuple[torch.Tensor, torch.Tensor]:
  108. probs = logits_to_probs(
  109. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  110. )
  111. idx_next = multinomial_sample_one_no_sync(probs)
  112. return idx_next, probs
  113. def sample_agent(
  114. logits,
  115. previous_tokens: Optional[torch.Tensor] = None,
  116. **sampling_kwargs,
  117. ) -> Tuple[torch.Tensor, torch.Tensor]:
  118. probs = logits_to_probs_agent(
  119. logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
  120. )
  121. idx_next = multinomial_sample_one_no_sync_agent(probs)
  122. return idx_next, probs
def decode_one_token_ar_agent(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Single batched decoding step for a DualAR model (agent mode).

    Runs the slow transformer once for the main token, then autoregressively
    samples each codebook with the fast transformer. Returns a tensor of
    shape (batch, num_codebooks + 1, 1).
    """
    x = model.forward_generate(x, input_pos)
    logits = x.logits  # [:, -1:]
    hidden_states = x.hidden_states  # [:, -1:]

    # Main token is sampled near-greedily: fixed low temperature/top_p and
    # no repetition penalty, regardless of the caller's sampling kwargs.
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample_agent(
            logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    # Cleanup the cache: the fast transformer restarts from position 0 each step.
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    # Sample codebooks one at a time, feeding each result back in.
    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample_agent(
            logits,
            previous_tokens=(
                previous_tokens[:, codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=1)
    # Pad codebook rows wherever the sampled main token is not a semantic id.
    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :],
        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
        CODEBOOK_PAD_TOKEN_ID,
    )
    return codebooks
def decode_one_token_naive_agent(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Single batched decoding step for a NaiveTransformer (agent mode).

    One forward pass produces the main-token logits and all codebook logits.
    Returns a tensor of shape (batch, num_codebooks + 1, 1).
    """
    x = model.forward_generate(x, input_pos)

    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]

    # Each codebook is sampled independently from its own logits slice.
    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample_agent(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[:, i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    codebooks = torch.stack(codebooks, dim=1)
    # Pad codebook rows wherever the sampled main token is not a semantic id.
    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :],
        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
        CODEBOOK_PAD_TOKEN_ID,
    )
    return codebooks
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Single (unbatched) decoding step for a DualAR model.

    Samples the main token with the slow transformer, then fills the
    codebooks autoregressively with the fast transformer. Returns a tensor of
    shape (num_codebooks + 1, 1). ``semantic_ids`` is currently unused here
    (the masking below is commented out).
    """
    x = model.forward_generate(x, input_pos)

    sampling_kwargs_main = sampling_kwargs.copy()
    # sampling_kwargs_main["temperature"] = 0.1
    # sampling_kwargs_main["top_p"] = 0.1
    # sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.logits,
            previous_tokens=(
                previous_tokens[0] if previous_tokens is not None else None
            ),  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    hidden_states = x.hidden_states

    # Cleanup the cache: the fast transformer restarts from position 0 each step.
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    # Prime the fast transformer at position 0 with the slow hidden state.
    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
    model.forward_generate_fast(hidden_states, input_pos)

    # First codebook entry: the main token offset by semantic_begin_id;
    # non-semantic tokens go negative and are clamped to 0.
    a = codebooks[0] - model.tokenizer.semantic_begin_id
    a[a < 0] = 0
    hidden_states = model.fast_embeddings(a)
    codebooks.append(a)

    # Remaining codebooks are sampled one at a time.
    for codebook_idx in range(1, model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=0)
    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    # codebooks[1:, :] = torch.masked_fill(
    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
    # )
    return codebooks
  264. def decode_one_token_naive(
  265. model: NaiveTransformer,
  266. x: torch.Tensor,
  267. input_pos: torch.Tensor,
  268. previous_tokens: torch.Tensor = None,
  269. **sampling_kwargs,
  270. ) -> torch.Tensor:
  271. x = model.forward_generate(x, input_pos)
  272. sampling_kwargs_main = sampling_kwargs.copy()
  273. sampling_kwargs_main["temperature"] = 0.1
  274. sampling_kwargs_main["top_p"] = 0.1
  275. sampling_kwargs_main["repetition_penalty"] = 1.0
  276. codebooks = [
  277. sample(
  278. x.logits,
  279. previous_tokens=None, # Disable repetition penalty for the token codebook
  280. **sampling_kwargs_main,
  281. )[0]
  282. ]
  283. for i in range(model.config.num_codebooks):
  284. codebooks.append(
  285. sample(
  286. x.codebook_logits[:, :, i],
  287. previous_tokens=(
  288. previous_tokens[i + 1] if previous_tokens is not None else None
  289. ),
  290. **sampling_kwargs,
  291. )[0]
  292. )
  293. return torch.stack(codebooks, dim=0)
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    semantic_ids: list,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    """Autoregressive decoding loop: up to ``num_new_tokens`` steps.

    Stops early when the main stream emits IM_END_TOKEN. Returns the
    generated tokens as a (num_codebooks + 1, steps) tensor.
    """
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            # NOTE(review): for the first steps this window still contains
            # zero-initialized slots beyond position i — confirm intended.
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with (
            sdpa_kernel(
                [
                    SDPBackend.FLASH_ATTENTION,
                    SDPBackend.EFFICIENT_ATTENTION,
                    SDPBackend.MATH,
                ]
            )
            if torch.cuda.is_available()
            else nullcontext()
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_ids=semantic_ids,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        # Record this step so future windows can penalize repeats.
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        # End-of-message token on the main stream terminates decoding.
        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
            break

    # Crop to the steps actually generated (including the stop token).
    return previous_tokens[:, : i + 1]
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate
    as many tokens as requested.

    ``prompt`` is (num_codebooks + 1, T); the return value is the prompt plus
    the generated tokens in the same layout. ``max_new_tokens`` falsy means
    "fill the remaining context".
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
    # NOTE(review): assumes exactly 1024 <|semantic:i|> tokens exist — confirm.
    semantic_ids = [
        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
    ]

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
    )
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    # NOTE(review): the naive prefill path forwards semantic_ids= as a keyword;
    # verify decode_one_token_naive accepts it rather than leaking it into
    # its sampling kwargs.
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )

    # Prefill: consume the whole prompt and sample the first new token.
    next_token = prefill_decode(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    )
    seq[:, T : T + 1] = next_token

    # Decode the remaining tokens one position at a time.
    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        decode_one_token=decode_one_token,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    # Crop the preallocated buffer to what was actually generated.
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
def decode_n_tokens_agent(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    semantic_ids: list,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive_agent,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """Batched decoding loop (agent mode); a generator.

    Yields each step's tokens on CPU as (batch, num_codebooks + 1, 1).
    Stops when all sequences emitted ``im_end_id``, or — when
    ``early_stop_threshold`` is in (0, 1) — once that fraction of the batch
    has finished.
    """
    batch_size = cur_token.size(0)
    previous_tokens = torch.zeros(
        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )
    # Track which batch rows already hit <im_end> (the prefill may have).
    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
    finished = finished | (cur_token[:, 0, -1] == im_end_id)
    start_time = time.time()

    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            # NOTE(review): for the first steps this window still contains
            # zero-initialized slots beyond position i — confirm intended.
            window = previous_tokens[:, :, :win_size]
        else:
            window = previous_tokens[:, :, i - win_size : i]

        with sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_ids=semantic_ids,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
        # Record this step so future windows can penalize repeats.
        previous_tokens[:, :, i : i + 1] = next_token.view(
            batch_size, model.config.num_codebooks + 1, -1
        )

        # Stream the step to the caller (moved off-device).
        yield cur_token.cpu()

        finished = finished | (cur_token[:, 0, -1] == im_end_id)
        if finished.all() or (
            0 < early_stop_threshold < 1
            and finished.sum() >= round(batch_size * early_stop_threshold)
        ):
            break

    total_time = time.time() - start_time
    generated_tokens = i + 1
    # Throughput counts every batch row as its own token stream.
    tokens_per_second = (generated_tokens / total_time) * batch_size
    logger.info(
        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
    )
@torch.no_grad()
@torch.inference_mode()
def generate_agent(
    *,
    model: BaseTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    semantic_ids: list,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive_agent,
    num_samples: int = 1,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate
    as many tokens as requested.

    Generator: yields the prefill step's tokens first, then one CPU tensor of
    shape (num_samples, num_codebooks + 1, 1) per decoding step (delegated to
    ``decode_n_tokens_agent``).
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    # Replicate the prompt so the whole batch draws ``num_samples`` sequences.
    prompt = prompt[None].repeat(num_samples, 1, 1)

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        # Falsy max_new_tokens means "fill the remaining context".
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive_agent
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar_agent
    )

    # Prefill: consume the whole prompt and sample the first new token batch.
    next_token = prefill_decode(
        model,
        prompt,
        input_pos,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    ).view(num_samples, codebook_dim, -1)
    yield next_token.cpu()

    input_pos = torch.tensor([T], device=device, dtype=torch.int)

    yield from decode_n_tokens_agent(
        model,
        next_token,
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        semantic_ids=semantic_ids,
        decode_one_token=decode_one_token,
        early_stop_threshold=early_stop_threshold,
        **sampling_kwargs,
    )
def encode_tokens(
    tokenizer,
    string,
    device="cuda",
    prompt_tokens=None,
    num_codebooks=4,
):
    """Encode one user text (optionally with reference VQ codes) for inference.

    Builds a user message followed by an assistant "<|voice|>" turn. When
    ``prompt_tokens`` is given it is attached to the assistant turn as a VQ
    part; otherwise the assistant turn is left open (``add_im_end=False``) so
    the model continues it. Returns the encoded tensor on ``device``.
    """
    string = clean_text(string)

    messages = []
    messages.append(
        Message(
            role="user",
            parts=[TextPart(text=string)],
            cal_loss=False,
        )
    )

    if prompt_tokens is not None:
        # Accept (1, num_codebooks, seq_len) and squeeze the batch dim.
        if prompt_tokens.ndim == 3:
            assert (
                prompt_tokens.shape[0] == 1
            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
            prompt_tokens = prompt_tokens[0]

        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"

        # Drop extra codebooks rather than failing.
        if prompt_tokens.shape[0] > num_codebooks:
            logger.warning(
                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
            )
            prompt_tokens = prompt_tokens[:num_codebooks]

        vq_part = VQPart(codes=prompt_tokens.to(device))

        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>"), vq_part],
                cal_loss=False,
            )
        )
    else:
        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>")],
                cal_loss=False,
                add_im_end=False,
            )
        )

    conversation = Conversation(messages=messages)
    # conversation.visualize(tokenizer)
    encoded = conversation.encode_for_inference(
        tokenizer=tokenizer,
        num_codebooks=num_codebooks,
    )

    return encoded.to(device)
  575. def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
  576. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  577. checkpoint_path, load_weights=True, is_agent=is_agent
  578. )
  579. model = model.to(device=device, dtype=precision)
  580. logger.info(f"Restored model from checkpoint")
  581. if isinstance(model, DualARTransformer):
  582. decode_one_token = (
  583. decode_one_token_ar_agent if is_agent else decode_one_token_ar
  584. )
  585. logger.info("Using DualARTransformer")
  586. else:
  587. decode_one_token = (
  588. decode_one_token_naive_agent if is_agent else decode_one_token_naive
  589. )
  590. logger.info("Using NaiveTransformer")
  591. if compile:
  592. logger.info("Compiling function...")
  593. decode_one_token = torch.compile(
  594. decode_one_token,
  595. fullgraph=True,
  596. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  597. mode="reduce-overhead" if torch.cuda.is_available() else None,
  598. )
  599. return model.eval(), decode_one_token
@dataclass
class GenerateResponse:
    """One item streamed from ``generate_long``."""

    # "sample": `codes`/`text` carry one generated segment;
    # "next": marks the end of the current sample (no payload).
    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None
    text: Optional[str] = None
  605. def generate_long(
  606. *,
  607. model,
  608. device: str | torch.device,
  609. decode_one_token: callable,
  610. text: str,
  611. num_samples: int = 1,
  612. max_new_tokens: int = 0,
  613. top_p: int = 0.7,
  614. repetition_penalty: float = 1.5,
  615. temperature: float = 0.7,
  616. compile: bool = False,
  617. iterative_prompt: bool = True,
  618. max_length: int = 2048,
  619. chunk_length: int = 150,
  620. prompt_text: Optional[str | list[str]] = None,
  621. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  622. ):
  623. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  624. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  625. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  626. use_prompt = prompt_text is not None and prompt_tokens is not None
  627. if use_prompt and isinstance(prompt_text, str):
  628. prompt_text = [prompt_text]
  629. prompt_tokens = [prompt_tokens]
  630. assert use_prompt is False or len(prompt_text) == len(
  631. prompt_tokens
  632. ), "Prompt text and tokens must have the same length"
  633. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  634. tokenizer = model.tokenizer
  635. im_end_id = tokenizer.get_token_id("<|im_end|>")
  636. encoded = []
  637. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  638. encoded_prompts = [
  639. Conversation(
  640. messages=[
  641. Message(
  642. role="system",
  643. parts=[TextPart(text="Speak out the provided text.")],
  644. cal_loss=False,
  645. )
  646. ]
  647. )
  648. .encode_for_inference(
  649. tokenizer=tokenizer,
  650. num_codebooks=model.config.num_codebooks,
  651. )
  652. .to(device)
  653. ]
  654. if use_prompt:
  655. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  656. encoded_prompts.append(
  657. encode_tokens(
  658. tokenizer,
  659. string=t,
  660. device=device,
  661. prompt_tokens=c,
  662. num_codebooks=model.config.num_codebooks,
  663. )
  664. )
  665. for idx, text in enumerate(texts):
  666. encoded.append(
  667. encode_tokens(
  668. tokenizer,
  669. string=text,
  670. device=device,
  671. num_codebooks=model.config.num_codebooks,
  672. )
  673. )
  674. logger.info(f"Encoded text: {text}")
  675. # Move temperature, top_p, repetition_penalty to device
  676. # This is important so that changing params doesn't trigger recompile
  677. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  678. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  679. repetition_penalty = torch.tensor(
  680. repetition_penalty, device=device, dtype=torch.float
  681. )
  682. for sample_idx in range(num_samples):
  683. if torch.cuda.is_available():
  684. torch.cuda.synchronize()
  685. global_encoded = []
  686. seg_idx = 0
  687. while seg_idx < len(encoded):
  688. logger.info(
  689. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  690. )
  691. seg = encoded[seg_idx]
  692. global_encoded.append(seg)
  693. lengths = reversed([seg.size(1) for seg in global_encoded])
  694. # Pick last 2000 tokens
  695. count = 0
  696. for i, length in enumerate(lengths):
  697. count += length
  698. if count + length > max_length - 1024 - sum(
  699. t.shape[1] for t in encoded_prompts
  700. ):
  701. break
  702. if i != 0 and i % 2 == 0:
  703. i -= 1
  704. # Rotate the list, always make sure first segment is included to avoid drift
  705. if i < len(global_encoded) - 2:
  706. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  707. else:
  708. partial_encoded = global_encoded
  709. if use_prompt:
  710. partial_encoded = encoded_prompts + partial_encoded
  711. cat_encoded = torch.cat(partial_encoded, dim=1)
  712. prompt_length = cat_encoded.size(1)
  713. t0 = time.perf_counter()
  714. y = generate(
  715. model=model,
  716. prompt=cat_encoded,
  717. max_new_tokens=max_new_tokens,
  718. decode_one_token=decode_one_token,
  719. temperature=temperature,
  720. top_p=top_p,
  721. repetition_penalty=repetition_penalty,
  722. )
  723. if sample_idx == 0 and seg_idx == 0 and compile:
  724. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  725. if torch.cuda.is_available():
  726. torch.cuda.synchronize()
  727. t = time.perf_counter() - t0
  728. tokens_generated = y.size(1) - prompt_length
  729. tokens_sec = tokens_generated / t
  730. logger.info(
  731. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  732. )
  733. logger.info(
  734. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  735. )
  736. if torch.cuda.is_available():
  737. logger.info(
  738. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  739. )
  740. # Put the generated tokens
  741. # since there is <im_end>, we remove last token
  742. codes = y[1:, prompt_length + 1 :].clone()
  743. assert (codes >= 0).all(), f"Negative code found"
  744. decoded = y[:, prompt_length:].clone()
  745. # But for global encoding, we should keep the <im_end> token
  746. global_encoded.append(decoded)
  747. assert (codes >= 0).all(), f"Negative code found: {codes}"
  748. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  749. seg_idx += 1
  750. # This indicates the end of the current sample
  751. yield GenerateResponse(action="next")
@dataclass
class WrappedGenerateResponse:
    """Envelope placed on a request's response queue by the worker thread."""

    # "success" pairs with a GenerateResponse; "error" pairs with the Exception.
    status: Literal["success", "error"]
    response: Optional[GenerateResponse | Exception] = None
@dataclass
class GenerateRequest:
    """One unit of work submitted to the worker's input queue."""

    # Keyword arguments forwarded to generate_long / generate_agent.
    request: dict
    # Queue on which the worker streams results back to this caller.
    response_queue: queue.Queue
def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Start a daemon worker thread that owns the model and serves requests.

    Returns a ``queue.Queue`` accepting ``GenerateRequest``s; results for each
    request are streamed back as ``WrappedGenerateResponse`` items on that
    request's own ``response_queue``. Blocks until the model is loaded.
    Pushing ``None`` onto the returned queue shuts the worker down.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()

    def worker():
        # The model lives entirely on this thread.
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:  # shutdown sentinel
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
            except Exception as e:
                # Surface the exception to the caller instead of killing the worker.
                response_queue.put(WrappedGenerateResponse(status="error", response=e))

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue
  797. def launch_thread_safe_queue_agent(
  798. checkpoint_path,
  799. device,
  800. precision,
  801. compile: bool = False,
  802. ):
  803. input_queue = queue.Queue()
  804. init_event = threading.Event()
  805. tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
  806. config = BaseModelArgs.from_pretrained(checkpoint_path)
  807. def worker():
  808. model, decode_one_token = load_model(
  809. checkpoint_path, device, precision, compile=compile, is_agent=True
  810. )
  811. with torch.device(device):
  812. model.setup_caches(
  813. max_batch_size=1,
  814. max_seq_len=model.config.max_seq_len,
  815. dtype=next(model.parameters()).dtype,
  816. )
  817. init_event.set()
  818. while True:
  819. item: GenerateRequest | None = input_queue.get()
  820. if item is None:
  821. break
  822. kwargs = item.request
  823. response_queue = item.response_queue
  824. try:
  825. for token in generate_agent(
  826. model=model,
  827. decode_one_token=decode_one_token,
  828. **kwargs,
  829. ):
  830. response_queue.put(token)
  831. response_queue.put("stop")
  832. except Exception as e:
  833. import traceback
  834. logger.exception(f"Error in worker: {traceback.format_exc()}")
  835. response_queue.put("error")
  836. threading.Thread(target=worker, daemon=True).start()
  837. init_event.wait()
  838. return input_queue, tokenizer, config
  839. @click.command()
  840. @click.option(
  841. "--text",
  842. type=str,
  843. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  844. )
  845. @click.option("--prompt-text", type=str, default=None, multiple=True)
  846. @click.option(
  847. "--prompt-tokens",
  848. type=click.Path(path_type=Path, exists=True),
  849. default=None,
  850. multiple=True,
  851. )
  852. @click.option("--num-samples", type=int, default=1)
  853. @click.option("--max-new-tokens", type=int, default=0)
  854. @click.option("--top-p", type=float, default=0.7)
  855. @click.option("--repetition-penalty", type=float, default=1.2)
  856. @click.option("--temperature", type=float, default=0.7)
  857. @click.option(
  858. "--checkpoint-path",
  859. type=click.Path(path_type=Path, exists=True),
  860. default="checkpoints/fish-speech-1.5",
  861. )
  862. @click.option("--device", type=str, default="cuda")
  863. @click.option("--compile/--no-compile", default=False)
  864. @click.option("--seed", type=int, default=42)
  865. @click.option("--half/--no-half", default=False)
  866. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  867. @click.option("--chunk-length", type=int, default=100)
  868. @click.option("--output-dir", type=Path, default="temp")
  869. def main(
  870. text: str,
  871. prompt_text: Optional[list[str]],
  872. prompt_tokens: Optional[list[Path]],
  873. num_samples: int,
  874. max_new_tokens: int,
  875. top_p: int,
  876. repetition_penalty: float,
  877. temperature: float,
  878. checkpoint_path: Path,
  879. device: str,
  880. compile: bool,
  881. seed: int,
  882. half: bool,
  883. iterative_prompt: bool,
  884. chunk_length: int,
  885. output_dir: Path,
  886. ) -> None:
  887. os.makedirs(output_dir, exist_ok=True)
  888. precision = torch.half if half else torch.bfloat16
  889. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  890. raise ValueError(
  891. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  892. )
  893. logger.info("Loading model ...")
  894. t0 = time.time()
  895. model, decode_one_token = load_model(
  896. checkpoint_path, device, precision, compile=compile
  897. )
  898. with torch.device(device):
  899. model.setup_caches(
  900. max_batch_size=1,
  901. max_seq_len=model.config.max_seq_len,
  902. dtype=next(model.parameters()).dtype,
  903. )
  904. if torch.cuda.is_available():
  905. torch.cuda.synchronize()
  906. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  907. if prompt_tokens is not None:
  908. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  909. torch.manual_seed(seed)
  910. if torch.cuda.is_available():
  911. torch.cuda.manual_seed(seed)
  912. generator = generate_long(
  913. model=model,
  914. device=device,
  915. decode_one_token=decode_one_token,
  916. text=text,
  917. num_samples=num_samples,
  918. max_new_tokens=max_new_tokens,
  919. top_p=top_p,
  920. repetition_penalty=repetition_penalty,
  921. temperature=temperature,
  922. compile=compile,
  923. iterative_prompt=iterative_prompt,
  924. chunk_length=chunk_length,
  925. prompt_text=prompt_text,
  926. prompt_tokens=prompt_tokens,
  927. )
  928. idx = 0
  929. codes = []
  930. for response in generator:
  931. if response.action == "sample":
  932. codes.append(response.codes)
  933. logger.info(f"Sampled text: {response.text}")
  934. elif response.action == "next":
  935. if codes:
  936. codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
  937. np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
  938. logger.info(f"Saved codes to {codes_npy_path}")
  939. logger.info(f"Next sample")
  940. codes = []
  941. idx += 1
  942. else:
  943. logger.error(f"Error: {response}")
# Script entry point: click parses the options declared on `main`.
if __name__ == "__main__":
    main()