generate.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115
  1. import os
  2. import queue
  3. import threading
  4. import time
  5. from contextlib import nullcontext
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import Literal, Optional, Tuple, Union
  9. import click
  10. import hydra
  11. import numpy as np
  12. import torch
  13. import torch._dynamo.config
  14. import torch._inductor.config
  15. from loguru import logger
  16. from tqdm import tqdm
  17. from transformers import AutoTokenizer
  18. from fish_speech.conversation import (
  19. CODEBOOK_PAD_TOKEN_ID,
  20. Conversation,
  21. Message,
  22. TextPart,
  23. VQPart,
  24. )
  25. from fish_speech.models.text2semantic.llama import BaseModelArgs
  26. from fish_speech.text import clean_text, split_text
  27. from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
  28. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  29. torch._inductor.config.coordinate_descent_tuning = True
  30. torch._inductor.config.triton.unique_kernel_names = True
  31. if hasattr(torch._inductor.config, "fx_graph_cache"):
  32. # Experimental feature to reduce compilation times, will be on by default in future
  33. torch._inductor.config.fx_graph_cache = True
  34. from torch.nn.attention import SDPBackend, sdpa_kernel
  35. from fish_speech.models.text2semantic.llama import (
  36. BaseTransformer,
  37. DualARTransformer,
  38. NaiveTransformer,
  39. )
  40. def multinomial_sample_one_no_sync(
  41. probs_sort,
  42. ): # Does multinomial sampling without a cuda synchronization
  43. q = torch.empty_like(probs_sort).exponential_(1)
  44. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  45. def logits_to_probs(
  46. logits,
  47. previous_tokens: Optional[torch.Tensor] = None,
  48. temperature: torch.Tensor = 1.0,
  49. top_p: torch.Tensor = 1.0,
  50. repetition_penalty: torch.Tensor = 1.0,
  51. ) -> torch.Tensor:
  52. # Apply repetition penalty
  53. if previous_tokens is not None:
  54. previous_tokens = previous_tokens.long()
  55. score = torch.gather(logits, dim=0, index=previous_tokens)
  56. score = torch.where(
  57. score < 0, score * repetition_penalty, score / repetition_penalty
  58. )
  59. logits.scatter_(dim=0, index=previous_tokens, src=score)
  60. # Apply top-p sampling
  61. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  62. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  63. sorted_indices_to_remove = cum_probs > top_p
  64. sorted_indices_to_remove[0] = False # keep at least one option
  65. indices_to_remove = sorted_indices_to_remove.scatter(
  66. dim=0, index=sorted_indices, src=sorted_indices_to_remove
  67. )
  68. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  69. logits = logits / max(temperature, 1e-5)
  70. probs = torch.nn.functional.softmax(logits, dim=-1)
  71. return probs
  72. def multinomial_sample_one_no_sync_agent(
  73. probs_sort,
  74. ): # Does multinomial sampling without a cuda synchronization
  75. q = torch.empty_like(probs_sort).exponential_(1)
  76. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  77. def logits_to_probs_agent(
  78. logits,
  79. previous_tokens: Optional[torch.Tensor] = None,
  80. temperature: torch.Tensor = 1.0,
  81. top_p: torch.Tensor = 1.0,
  82. repetition_penalty: torch.Tensor = 1.0,
  83. ) -> torch.Tensor:
  84. # Apply repetition penalty
  85. if previous_tokens is not None:
  86. previous_tokens = previous_tokens.long()
  87. score = torch.gather(logits, dim=-1, index=previous_tokens)
  88. score = torch.where(
  89. score < 0, score * repetition_penalty, score / repetition_penalty
  90. )
  91. logits.scatter_(dim=-1, index=previous_tokens, src=score)
  92. # Apply top-p sampling
  93. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  94. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  95. sorted_indices_to_remove = cum_probs > top_p
  96. sorted_indices_to_remove[..., 0] = False # keep at least one option
  97. indices_to_remove = sorted_indices_to_remove.scatter(
  98. dim=-1, index=sorted_indices, src=sorted_indices_to_remove
  99. )
  100. logits = logits.masked_fill(indices_to_remove, -float("Inf"))
  101. logits = logits / max(temperature, 1e-5)
  102. probs = torch.nn.functional.softmax(logits, dim=-1)
  103. return probs
  104. def sample(
  105. logits,
  106. previous_tokens: Optional[torch.Tensor] = None,
  107. **sampling_kwargs,
  108. ) -> Tuple[torch.Tensor, torch.Tensor]:
  109. probs = logits_to_probs(
  110. logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
  111. )
  112. idx_next = multinomial_sample_one_no_sync(probs)
  113. return idx_next, probs
  114. def sample_agent(
  115. logits,
  116. previous_tokens: Optional[torch.Tensor] = None,
  117. **sampling_kwargs,
  118. ) -> Tuple[torch.Tensor, torch.Tensor]:
  119. probs = logits_to_probs_agent(
  120. logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
  121. )
  122. idx_next = multinomial_sample_one_no_sync_agent(probs)
  123. return idx_next, probs
def decode_one_token_ar_agent(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step of a batched agent ``DualARTransformer``.

    1. The slow transformer (``model.forward_generate``) yields main-token
       logits and hidden states.
    2. The main token is sampled with fixed near-greedy settings
       (temperature 0.1, top_p 0.1, repetition penalty 1.0, no history).
    3. The fast transformer's KV cache is zeroed, then it is run once per
       codebook. Each sampled codebook's embedding (``model.fast_embeddings``)
       is the input for the next codebook.
    4. Codebook rows are replaced with ``CODEBOOK_PAD_TOKEN_ID`` wherever the
       main token is not one of ``semantic_ids``.

    Args:
        model: agent dual-AR model.
        x: input for ``model.forward_generate`` (presumably
            ``(batch, num_codebooks + 1, seq)`` — TODO confirm).
        input_pos: positions for ``model.forward_generate``.
        semantic_ids: token ids counted as semantic tokens for the pad mask.
        previous_tokens: optional recent-token window for the codebook
            repetition penalty; ``previous_tokens[:, codebook_idx + 1]`` is
            used for codebook ``codebook_idx``.
        **sampling_kwargs: temperature / top_p / repetition_penalty used for
            the codebook rows (the main row uses the fixed values above).

    Returns:
        Main token and codebooks stacked along dim 1 (presumably
        ``(batch, num_codebooks + 1, 1)`` — TODO confirm).
    """
    x = model.forward_generate(x, input_pos)

    logits = x.logits
    hidden_states = x.hidden_states

    # Main token row: fixed near-greedy sampling regardless of caller kwargs.
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample_agent(
            logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    # Reset the fast transformer's KV cache before decoding this step's codebooks.
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    # Fast decoding: position codebook_idx consumes the previous embedding
    # (the slow hidden state at position 0) and samples that codebook.
    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample_agent(
            logits,
            previous_tokens=(
                previous_tokens[:, codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=1)
    # Pad out codebook rows whenever the main token is not a semantic token.
    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :],
        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
        CODEBOOK_PAD_TOKEN_ID,
    )

    return codebooks
  175. def decode_one_token_naive_agent(
  176. model: NaiveTransformer,
  177. x: torch.Tensor,
  178. input_pos: torch.Tensor,
  179. semantic_ids: list,
  180. previous_tokens: torch.Tensor = None,
  181. **sampling_kwargs,
  182. ) -> torch.Tensor:
  183. x = model.forward_generate(x, input_pos)
  184. codebooks = [
  185. sample(
  186. x.token_logits,
  187. previous_tokens=None, # Disable repetition penalty for the token codebook
  188. **sampling_kwargs,
  189. )[0]
  190. ]
  191. for i in range(model.config.num_codebooks):
  192. codebooks.append(
  193. sample_agent(
  194. x.codebook_logits[:, :, i],
  195. previous_tokens=(
  196. previous_tokens[:, i + 1] if previous_tokens is not None else None
  197. ),
  198. **sampling_kwargs,
  199. )[0]
  200. )
  201. codebooks = torch.stack(codebooks, dim=1)
  202. semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
  203. codebooks[:, 1:, :] = torch.masked_fill(
  204. codebooks[:, 1:, :],
  205. ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
  206. CODEBOOK_PAD_TOKEN_ID,
  207. )
  208. return codebooks
  209. def decode_one_token_ar(
  210. model: DualARTransformer,
  211. x: torch.Tensor,
  212. input_pos: torch.Tensor,
  213. semantic_ids: list,
  214. previous_tokens: torch.Tensor = None,
  215. **sampling_kwargs,
  216. ) -> torch.Tensor:
  217. x = model.forward_generate(x, input_pos)
  218. sampling_kwargs_main = sampling_kwargs.copy()
  219. # sampling_kwargs_main["temperature"] = 0.1
  220. # sampling_kwargs_main["top_p"] = 0.1
  221. # sampling_kwargs_main["repetition_penalty"] = 1.0
  222. codebooks = [
  223. sample(
  224. x.logits,
  225. previous_tokens=(
  226. previous_tokens[0] if previous_tokens is not None else None
  227. ), # Disable repetition penalty for the token codebook
  228. **sampling_kwargs_main,
  229. )[0]
  230. ]
  231. hidden_states = x.hidden_states
  232. # Cleanup the cache
  233. for layer in model.fast_layers:
  234. layer.attention.kv_cache.k_cache.fill_(0)
  235. layer.attention.kv_cache.v_cache.fill_(0)
  236. input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
  237. model.forward_generate_fast(hidden_states, input_pos)
  238. a = codebooks[0] - model.tokenizer.semantic_begin_id
  239. a[a < 0] = 0
  240. hidden_states = model.fast_embeddings(a)
  241. codebooks.append(a)
  242. for codebook_idx in range(1, model.config.num_codebooks):
  243. input_pos = torch.tensor(
  244. [codebook_idx], device=hidden_states.device, dtype=torch.long
  245. )
  246. logits = model.forward_generate_fast(hidden_states, input_pos)
  247. a = sample(
  248. logits,
  249. previous_tokens=(
  250. previous_tokens[codebook_idx + 1]
  251. if previous_tokens is not None
  252. else None
  253. ),
  254. **sampling_kwargs,
  255. )[0]
  256. hidden_states = model.fast_embeddings(a)
  257. codebooks.append(a)
  258. codebooks = torch.stack(codebooks, dim=0)
  259. # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
  260. # codebooks[1:, :] = torch.masked_fill(
  261. # codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
  262. # )
  263. # print(codebooks)
  264. return codebooks
  265. def decode_one_token_naive(
  266. model: NaiveTransformer,
  267. x: torch.Tensor,
  268. input_pos: torch.Tensor,
  269. previous_tokens: torch.Tensor = None,
  270. **sampling_kwargs,
  271. ) -> torch.Tensor:
  272. x = model.forward_generate(x, input_pos)
  273. sampling_kwargs_main = sampling_kwargs.copy()
  274. sampling_kwargs_main["temperature"] = 0.1
  275. sampling_kwargs_main["top_p"] = 0.1
  276. sampling_kwargs_main["repetition_penalty"] = 1.0
  277. codebooks = [
  278. sample(
  279. x.logits,
  280. previous_tokens=None, # Disable repetition penalty for the token codebook
  281. **sampling_kwargs_main,
  282. )[0]
  283. ]
  284. for i in range(model.config.num_codebooks):
  285. codebooks.append(
  286. sample(
  287. x.codebook_logits[:, :, i],
  288. previous_tokens=(
  289. previous_tokens[i + 1] if previous_tokens is not None else None
  290. ),
  291. **sampling_kwargs,
  292. )[0]
  293. )
  294. return torch.stack(codebooks, dim=0)
  295. def decode_n_tokens(
  296. model: NaiveTransformer,
  297. cur_token: torch.Tensor,
  298. input_pos: torch.Tensor,
  299. num_new_tokens: int,
  300. semantic_ids: list,
  301. decode_one_token=decode_one_token_naive,
  302. **sampling_kwargs,
  303. ):
  304. previous_tokens = torch.zeros(
  305. (model.config.num_codebooks + 1, model.config.max_seq_len),
  306. dtype=torch.int,
  307. device=cur_token.device,
  308. )
  309. for i in tqdm(range(num_new_tokens)):
  310. # We need to get windowed repeat penalty
  311. win_size = 16
  312. if i < win_size:
  313. window = previous_tokens[:, :win_size]
  314. else:
  315. window = previous_tokens[:, i - win_size : i]
  316. with (
  317. torch.backends.cuda.sdp_kernel(
  318. enable_flash=False, enable_mem_efficient=False, enable_math=True
  319. )
  320. if torch.cuda.is_available()
  321. else nullcontext()
  322. ): # Actually better for Inductor to codegen attention here
  323. next_token = decode_one_token(
  324. model=model,
  325. x=cur_token,
  326. input_pos=input_pos,
  327. previous_tokens=window,
  328. semantic_ids=semantic_ids,
  329. **sampling_kwargs,
  330. )
  331. input_pos += 1
  332. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  333. previous_tokens[:, i : i + 1] = next_token.view(
  334. model.config.num_codebooks + 1, -1
  335. )
  336. if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
  337. break
  338. return previous_tokens[:, : i + 1]
  339. @torch.no_grad()
  340. @torch.inference_mode()
  341. def generate(
  342. *,
  343. model: NaiveTransformer,
  344. prompt: torch.Tensor,
  345. max_new_tokens: int,
  346. decode_one_token=decode_one_token_naive,
  347. **sampling_kwargs,
  348. ) -> torch.Tensor:
  349. """
  350. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  351. """
  352. # create an empty tensor of the expected final shape and fill in the current tokens
  353. T = prompt.size(1)
  354. # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
  355. semantic_ids = [
  356. model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
  357. ]
  358. if max_new_tokens:
  359. if T + max_new_tokens > model.config.max_seq_len:
  360. max_new_tokens = model.config.max_seq_len - T
  361. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  362. T_new = T + max_new_tokens
  363. else:
  364. T_new = model.config.max_seq_len
  365. max_new_tokens = T_new - T
  366. device, dtype = prompt.device, prompt.dtype
  367. codebook_dim = 1 + model.config.num_codebooks
  368. # create an empty tensor of the expected final shape and fill in the current tokens
  369. empty = torch.empty(
  370. (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
  371. )
  372. empty[:, :T] = prompt
  373. seq = empty
  374. input_pos = torch.arange(0, T, device=device)
  375. # Use non-accelerated version for now, to avoid compilation overhead
  376. prefill_decode = (
  377. decode_one_token_naive
  378. if isinstance(model, NaiveTransformer)
  379. else decode_one_token_ar
  380. )
  381. next_token = prefill_decode(
  382. model,
  383. prompt.view(1, codebook_dim, -1),
  384. input_pos,
  385. semantic_ids=semantic_ids,
  386. **sampling_kwargs,
  387. )
  388. seq[:, T : T + 1] = next_token
  389. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  390. x = decode_n_tokens(
  391. model,
  392. next_token.view(1, codebook_dim, -1),
  393. input_pos,
  394. max_new_tokens - 1,
  395. decode_one_token=decode_one_token,
  396. semantic_ids=semantic_ids,
  397. **sampling_kwargs,
  398. )
  399. # x = torch.cat(generated_tokens, dim=1)
  400. seq = seq[:, : T + 1 + x.size(1)]
  401. seq[:, T + 1 :] = x
  402. return seq
  403. def decode_n_tokens_agent(
  404. model: NaiveTransformer,
  405. cur_token: torch.Tensor,
  406. input_pos: torch.Tensor,
  407. num_new_tokens: int,
  408. semantic_ids: list,
  409. im_end_id: int = 4,
  410. decode_one_token=decode_one_token_naive_agent,
  411. early_stop_threshold: float = 0.6,
  412. **sampling_kwargs,
  413. ):
  414. batch_size = cur_token.size(0)
  415. previous_tokens = torch.zeros(
  416. (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
  417. dtype=torch.int,
  418. device=cur_token.device,
  419. )
  420. finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
  421. finished = finished | (cur_token[:, 0, -1] == im_end_id)
  422. start_time = time.time()
  423. for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
  424. # We need to get windowed repeat penalty
  425. win_size = 16
  426. if i < win_size:
  427. window = previous_tokens[:, :, :win_size]
  428. else:
  429. window = previous_tokens[:, :, i - win_size : i]
  430. with sdpa_kernel(
  431. SDPBackend.MATH
  432. ): # Actually better for Inductor to codegen attention here
  433. next_token = decode_one_token(
  434. model=model,
  435. x=cur_token,
  436. input_pos=input_pos,
  437. previous_tokens=window,
  438. semantic_ids=semantic_ids,
  439. **sampling_kwargs,
  440. )
  441. input_pos += 1
  442. cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
  443. previous_tokens[:, :, i : i + 1] = next_token.view(
  444. batch_size, model.config.num_codebooks + 1, -1
  445. )
  446. yield cur_token.cpu()
  447. finished = finished | (cur_token[:, 0, -1] == im_end_id)
  448. if finished.all() or (
  449. 0 < early_stop_threshold < 1
  450. and finished.sum() >= round(batch_size * early_stop_threshold)
  451. ):
  452. break
  453. total_time = time.time() - start_time
  454. generated_tokens = i + 1
  455. tokens_per_second = (generated_tokens / total_time) * batch_size
  456. logger.info(
  457. f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
  458. )
  459. @torch.no_grad()
  460. @torch.inference_mode()
  461. def generate_agent(
  462. *,
  463. model: BaseTransformer,
  464. prompt: torch.Tensor,
  465. max_new_tokens: int,
  466. semantic_ids: list,
  467. im_end_id: int = 4,
  468. decode_one_token=decode_one_token_naive_agent,
  469. num_samples: int = 1,
  470. early_stop_threshold: float = 0.6,
  471. **sampling_kwargs,
  472. ):
  473. """
  474. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  475. """
  476. # create an empty tensor of the expected final shape and fill in the current tokens
  477. T = prompt.size(1)
  478. prompt = prompt[None].repeat(num_samples, 1, 1)
  479. if T >= model.config.max_seq_len:
  480. raise ValueError(
  481. f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
  482. )
  483. if max_new_tokens:
  484. if T + max_new_tokens > model.config.max_seq_len:
  485. max_new_tokens = model.config.max_seq_len - T
  486. logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
  487. T_new = T + max_new_tokens
  488. else:
  489. T_new = model.config.max_seq_len
  490. max_new_tokens = T_new - T
  491. device, dtype = prompt.device, prompt.dtype
  492. codebook_dim = 1 + model.config.num_codebooks
  493. input_pos = torch.arange(0, T, device=device)
  494. # Use non-accelerated version for now, to avoid compilation overhead
  495. prefill_decode = (
  496. decode_one_token_naive_agent
  497. if isinstance(model, NaiveTransformer)
  498. else decode_one_token_ar_agent
  499. )
  500. next_token = prefill_decode(
  501. model,
  502. prompt,
  503. input_pos,
  504. semantic_ids=semantic_ids,
  505. **sampling_kwargs,
  506. ).view(num_samples, codebook_dim, -1)
  507. yield next_token.cpu()
  508. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  509. yield from decode_n_tokens_agent(
  510. model,
  511. next_token,
  512. input_pos,
  513. max_new_tokens - 1,
  514. im_end_id=im_end_id,
  515. semantic_ids=semantic_ids,
  516. decode_one_token=decode_one_token,
  517. early_stop_threshold=early_stop_threshold,
  518. **sampling_kwargs,
  519. )
  520. def encode_tokens(
  521. tokenizer,
  522. string,
  523. device="cuda",
  524. prompt_tokens=None,
  525. num_codebooks=4,
  526. ):
  527. string = clean_text(string)
  528. messages = []
  529. messages.append(
  530. Message(
  531. role="user",
  532. parts=[TextPart(text=string)],
  533. cal_loss=False,
  534. )
  535. )
  536. if prompt_tokens is not None:
  537. if prompt_tokens.ndim == 3:
  538. assert (
  539. prompt_tokens.shape[0] == 1
  540. ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
  541. prompt_tokens = prompt_tokens[0]
  542. assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"
  543. if prompt_tokens.shape[0] > num_codebooks:
  544. logger.warning(
  545. f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
  546. )
  547. prompt_tokens = prompt_tokens[:num_codebooks]
  548. vq_part = VQPart(codes=prompt_tokens.to(device))
  549. messages.append(
  550. Message(
  551. role="assistant",
  552. parts=[TextPart(text="<|voice|>"), vq_part],
  553. cal_loss=False,
  554. )
  555. )
  556. else:
  557. messages.append(
  558. Message(
  559. role="assistant",
  560. parts=[TextPart(text="<|voice|>")],
  561. cal_loss=False,
  562. add_im_end=False,
  563. )
  564. )
  565. conversation = Conversation(messages=messages)
  566. # conversation.visualize(tokenizer)
  567. encoded = conversation.encode_for_inference(
  568. tokenizer=tokenizer,
  569. num_codebooks=num_codebooks,
  570. )
  571. return encoded.to(device)
  572. def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
  573. model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
  574. checkpoint_path, load_weights=True, is_agent=is_agent
  575. )
  576. model = model.to(device=device, dtype=precision)
  577. logger.info(f"Restored model from checkpoint")
  578. if isinstance(model, DualARTransformer):
  579. decode_one_token = (
  580. decode_one_token_ar_agent if is_agent else decode_one_token_ar
  581. )
  582. logger.info("Using DualARTransformer")
  583. else:
  584. decode_one_token = (
  585. decode_one_token_naive_agent if is_agent else decode_one_token_naive
  586. )
  587. logger.info("Using NaiveTransformer")
  588. if compile:
  589. logger.info("Compiling function...")
  590. decode_one_token = torch.compile(
  591. decode_one_token,
  592. fullgraph=True,
  593. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  594. mode="reduce-overhead" if torch.cuda.is_available() else None,
  595. )
  596. return model.eval(), decode_one_token
  597. @dataclass
  598. class GenerateResponse:
  599. action: Literal["sample", "next"]
  600. codes: Optional[torch.Tensor] = None
  601. text: Optional[str] = None
  602. def generate_long(
  603. *,
  604. model,
  605. device: str | torch.device,
  606. decode_one_token: callable,
  607. text: str,
  608. num_samples: int = 1,
  609. max_new_tokens: int = 0,
  610. top_p: int = 0.7,
  611. repetition_penalty: float = 1.5,
  612. temperature: float = 0.7,
  613. compile: bool = False,
  614. iterative_prompt: bool = True,
  615. max_length: int = 2048,
  616. chunk_length: int = 150,
  617. prompt_text: Optional[str | list[str]] = None,
  618. prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
  619. ):
  620. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  621. assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
  622. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  623. use_prompt = prompt_text is not None and prompt_tokens is not None
  624. if use_prompt and isinstance(prompt_text, str):
  625. prompt_text = [prompt_text]
  626. prompt_tokens = [prompt_tokens]
  627. assert use_prompt is False or len(prompt_text) == len(
  628. prompt_tokens
  629. ), "Prompt text and tokens must have the same length"
  630. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  631. tokenizer = model.tokenizer
  632. im_end_id = tokenizer.get_token_id("<|im_end|>")
  633. encoded = []
  634. texts = split_text(text, chunk_length) if iterative_prompt else [text]
  635. encoded_prompts = [
  636. Conversation(
  637. messages=[
  638. Message(
  639. role="system",
  640. parts=[TextPart(text="Speak out the provided text.")],
  641. cal_loss=False,
  642. )
  643. ]
  644. )
  645. .encode_for_inference(
  646. tokenizer=tokenizer,
  647. num_codebooks=model.config.num_codebooks,
  648. )
  649. .to(device)
  650. ]
  651. if use_prompt:
  652. for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
  653. encoded_prompts.append(
  654. encode_tokens(
  655. tokenizer,
  656. string=t,
  657. device=device,
  658. prompt_tokens=c,
  659. num_codebooks=model.config.num_codebooks,
  660. )
  661. )
  662. for idx, text in enumerate(texts):
  663. encoded.append(
  664. encode_tokens(
  665. tokenizer,
  666. string=text,
  667. device=device,
  668. num_codebooks=model.config.num_codebooks,
  669. )
  670. )
  671. logger.info(f"Encoded text: {text}")
  672. # Move temperature, top_p, repetition_penalty to device
  673. # This is important so that changing params doesn't trigger recompile
  674. temperature = torch.tensor(temperature, device=device, dtype=torch.float)
  675. top_p = torch.tensor(top_p, device=device, dtype=torch.float)
  676. repetition_penalty = torch.tensor(
  677. repetition_penalty, device=device, dtype=torch.float
  678. )
  679. for sample_idx in range(num_samples):
  680. if torch.cuda.is_available():
  681. torch.cuda.synchronize()
  682. global_encoded = []
  683. seg_idx = 0
  684. while seg_idx < len(encoded):
  685. logger.info(
  686. f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
  687. )
  688. seg = encoded[seg_idx]
  689. global_encoded.append(seg)
  690. lengths = reversed([seg.size(1) for seg in global_encoded])
  691. # Pick last 2000 tokens
  692. count = 0
  693. for i, length in enumerate(lengths):
  694. count += length
  695. if count + length > max_length - 1024 - sum(
  696. t.shape[1] for t in encoded_prompts
  697. ):
  698. break
  699. if i != 0 and i % 2 == 0:
  700. i -= 1
  701. # Rotate the list, always make sure first segment is included to avoid drift
  702. if i < len(global_encoded) - 2:
  703. partial_encoded = global_encoded[:2] + global_encoded[-i:]
  704. else:
  705. partial_encoded = global_encoded
  706. if use_prompt:
  707. partial_encoded = encoded_prompts + partial_encoded
  708. cat_encoded = torch.cat(partial_encoded, dim=1)
  709. prompt_length = cat_encoded.size(1)
  710. t0 = time.perf_counter()
  711. y = generate(
  712. model=model,
  713. prompt=cat_encoded,
  714. max_new_tokens=max_new_tokens,
  715. decode_one_token=decode_one_token,
  716. temperature=temperature,
  717. top_p=top_p,
  718. repetition_penalty=repetition_penalty,
  719. )
  720. if sample_idx == 0 and seg_idx == 0 and compile:
  721. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  722. if torch.cuda.is_available():
  723. torch.cuda.synchronize()
  724. t = time.perf_counter() - t0
  725. tokens_generated = y.size(1) - prompt_length
  726. tokens_sec = tokens_generated / t
  727. logger.info(
  728. f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  729. )
  730. logger.info(
  731. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  732. )
  733. if torch.cuda.is_available():
  734. logger.info(
  735. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  736. )
  737. # Put the generated tokens
  738. # since there is <im_end>, we remove last token
  739. codes = y[1:, prompt_length + 1 :].clone()
  740. assert (codes >= 0).all(), f"Negative code found"
  741. decoded = y[:, prompt_length:].clone()
  742. # But for global encoding, we should keep the <im_end> token
  743. global_encoded.append(decoded)
  744. assert (codes >= 0).all(), f"Negative code found: {codes}"
  745. yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
  746. seg_idx += 1
  747. # This indicates the end of the current sample
  748. yield GenerateResponse(action="next")
  749. @dataclass
  750. class WrappedGenerateResponse:
  751. status: Literal["success", "error"]
  752. response: Optional[GenerateResponse | Exception] = None
  753. @dataclass
  754. class GenerateRequest:
  755. request: dict
  756. response_queue: queue.Queue
  757. def launch_thread_safe_queue(
  758. checkpoint_path,
  759. device,
  760. precision,
  761. compile: bool = False,
  762. ):
  763. input_queue = queue.Queue()
  764. init_event = threading.Event()
  765. def worker():
  766. model, decode_one_token = load_model(
  767. checkpoint_path, device, precision, compile=compile
  768. )
  769. with torch.device(device):
  770. model.setup_caches(
  771. max_batch_size=1,
  772. max_seq_len=model.config.max_seq_len,
  773. dtype=next(model.parameters()).dtype,
  774. )
  775. init_event.set()
  776. while True:
  777. item: GenerateRequest | None = input_queue.get()
  778. if item is None:
  779. break
  780. kwargs = item.request
  781. response_queue = item.response_queue
  782. try:
  783. for chunk in generate_long(
  784. model=model, decode_one_token=decode_one_token, **kwargs
  785. ):
  786. response_queue.put(
  787. WrappedGenerateResponse(status="success", response=chunk)
  788. )
  789. except Exception as e:
  790. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  791. threading.Thread(target=worker, daemon=True).start()
  792. init_event.wait()
  793. return input_queue
  794. def launch_thread_safe_queue_agent(
  795. checkpoint_path,
  796. device,
  797. precision,
  798. compile: bool = False,
  799. ):
  800. input_queue = queue.Queue()
  801. init_event = threading.Event()
  802. tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
  803. config = BaseModelArgs.from_pretrained(checkpoint_path)
  804. def worker():
  805. model, decode_one_token = load_model(
  806. checkpoint_path, device, precision, compile=compile, is_agent=True
  807. )
  808. with torch.device(device):
  809. model.setup_caches(
  810. max_batch_size=1,
  811. max_seq_len=model.config.max_seq_len,
  812. dtype=next(model.parameters()).dtype,
  813. )
  814. init_event.set()
  815. while True:
  816. item: GenerateRequest | None = input_queue.get()
  817. if item is None:
  818. break
  819. kwargs = item.request
  820. response_queue = item.response_queue
  821. try:
  822. for token in generate_agent(
  823. model=model,
  824. decode_one_token=decode_one_token,
  825. **kwargs,
  826. ):
  827. response_queue.put(token)
  828. response_queue.put("stop")
  829. except Exception as e:
  830. import traceback
  831. logger.exception(f"Error in worker: {traceback.format_exc()}")
  832. response_queue.put("error")
  833. threading.Thread(target=worker, daemon=True).start()
  834. init_event.wait()
  835. return input_queue, tokenizer, config
  836. @click.command()
  837. @click.option(
  838. "--text",
  839. type=str,
  840. default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  841. )
  842. @click.option("--prompt-text", type=str, default=None, multiple=True)
  843. @click.option(
  844. "--prompt-tokens",
  845. type=click.Path(path_type=Path, exists=True),
  846. default=None,
  847. multiple=True,
  848. )
  849. @click.option("--num-samples", type=int, default=1)
  850. @click.option("--max-new-tokens", type=int, default=0)
  851. @click.option("--top-p", type=float, default=0.7)
  852. @click.option("--repetition-penalty", type=float, default=1.2)
  853. @click.option("--temperature", type=float, default=0.7)
  854. @click.option(
  855. "--checkpoint-path",
  856. type=click.Path(path_type=Path, exists=True),
  857. default="checkpoints/fish-speech-1.5",
  858. )
  859. @click.option("--device", type=str, default="cuda")
  860. @click.option("--compile/--no-compile", default=False)
  861. @click.option("--seed", type=int, default=42)
  862. @click.option("--half/--no-half", default=False)
  863. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  864. @click.option("--chunk-length", type=int, default=100)
  865. def main(
  866. text: str,
  867. prompt_text: Optional[list[str]],
  868. prompt_tokens: Optional[list[Path]],
  869. num_samples: int,
  870. max_new_tokens: int,
  871. top_p: int,
  872. repetition_penalty: float,
  873. temperature: float,
  874. checkpoint_path: Path,
  875. device: str,
  876. compile: bool,
  877. seed: int,
  878. half: bool,
  879. iterative_prompt: bool,
  880. chunk_length: int,
  881. ) -> None:
  882. precision = torch.half if half else torch.bfloat16
  883. if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
  884. raise ValueError(
  885. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  886. )
  887. logger.info("Loading model ...")
  888. t0 = time.time()
  889. model, decode_one_token = load_model(
  890. checkpoint_path, device, precision, compile=compile
  891. )
  892. with torch.device(device):
  893. model.setup_caches(
  894. max_batch_size=1,
  895. max_seq_len=model.config.max_seq_len,
  896. dtype=next(model.parameters()).dtype,
  897. )
  898. if torch.cuda.is_available():
  899. torch.cuda.synchronize()
  900. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  901. if prompt_tokens is not None:
  902. prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
  903. torch.manual_seed(seed)
  904. if torch.cuda.is_available():
  905. torch.cuda.manual_seed(seed)
  906. generator = generate_long(
  907. model=model,
  908. device=device,
  909. decode_one_token=decode_one_token,
  910. text=text,
  911. num_samples=num_samples,
  912. max_new_tokens=max_new_tokens,
  913. top_p=top_p,
  914. repetition_penalty=repetition_penalty,
  915. temperature=temperature,
  916. compile=compile,
  917. iterative_prompt=iterative_prompt,
  918. chunk_length=chunk_length,
  919. prompt_text=prompt_text,
  920. prompt_tokens=prompt_tokens,
  921. )
  922. idx = 0
  923. codes = []
  924. for response in generator:
  925. if response.action == "sample":
  926. codes.append(response.codes)
  927. logger.info(f"Sampled text: {response.text}")
  928. elif response.action == "next":
  929. if codes:
  930. np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
  931. logger.info(f"Saved codes to codes_{idx}.npy")
  932. logger.info(f"Next sample")
  933. codes = []
  934. idx += 1
  935. else:
  936. logger.error(f"Error: {response}")
  937. if __name__ == "__main__":
  938. main()