inference.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146
  1. import os
  2. import queue
  3. import re
  4. import threading
  5. import time
  6. import traceback
  7. from copy import deepcopy
  8. from dataclasses import dataclass
  9. from pathlib import Path
  10. from typing import Callable, Literal, Optional, Tuple, Union, Any
  11. import click
  12. import numpy as np
  13. import torch._inductor.config
  14. from loguru import logger
  15. from tqdm import tqdm
  16. from fish_speech.content_sequence import (
  17. TextPart,
  18. VQPart,
  19. )
  20. from fish_speech.conversation import Conversation, Message
  21. from fish_speech.tokenizer import IM_END_TOKEN
# Process-wide settings applied when this module is imported.
# Disable HuggingFace `tokenizers` internal parallelism (presumably to avoid
# fork/thread warnings, since generation runs on a worker thread — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Inductor options that take effect when decode_one_token is wrapped in torch.compile.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
# `fx_graph_cache` does not exist in every torch release, so only enable it when present.
if hasattr(torch._inductor.config, "fx_graph_cache"):
    torch._inductor.config.fx_graph_cache = True
  27. from torch.nn.attention import SDPBackend, sdpa_kernel
  28. from fish_speech.models.text2semantic.llama import (
  29. DualARTransformer,
  30. )
  31. def multinomial_sample_one_no_sync(probs_sort):
  32. q = torch.rand_like(probs_sort)
  33. q = -torch.log(q)
  34. return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
  35. RAS_WIN_SIZE = 10 # window for Repetition Aware Sampling
  36. RAS_HIGH_TEMP = 1.0
  37. RAS_HIGH_TOP_P = 0.9
  38. def logits_to_probs(
  39. logits,
  40. temperature: torch.Tensor,
  41. top_p: torch.Tensor,
  42. top_k: int, # 注意: 我看到你传进来的是 int,这很关键
  43. ) -> torch.Tensor:
  44. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  45. cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
  46. indices = torch.arange(sorted_logits.shape[-1], device=sorted_logits.device)
  47. top_k_mask = indices >= top_k
  48. sorted_indices_to_remove = (cum_probs > top_p) | top_k_mask
  49. sorted_indices_to_remove[0] = False # 单元素修改问题不大,或者写成 | (indices != 0)
  50. indices_to_remove = sorted_indices_to_remove.scatter(
  51. dim=-1, index=sorted_indices, src=sorted_indices_to_remove
  52. )
  53. logits = torch.where(
  54. indices_to_remove, float("-Inf"), logits
  55. ) # 同样替换 masked_fill_ 为 torch.where
  56. logits = logits / torch.clip(temperature, min=1e-5)
  57. probs = torch.nn.functional.softmax(logits, dim=-1)
  58. return probs
  59. def sample(
  60. logits,
  61. temperature: torch.Tensor,
  62. top_p: torch.Tensor,
  63. top_k: int,
  64. ) -> Tuple[torch.Tensor, torch.Tensor]:
  65. probs = logits_to_probs(
  66. logits=logits[0, -1],
  67. temperature=temperature,
  68. top_p=top_p,
  69. top_k=top_k,
  70. )
  71. idx_next = multinomial_sample_one_no_sync(probs)
  72. return idx_next, probs
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
    semantic_logit_bias: torch.Tensor,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run one DualAR decoding step and return one token per stream.

    1. The slow transformer (``model.forward_generate``) consumes ``x`` at
       ``input_pos`` and returns logits and hidden states.
    2. The main token is sampled from ``logits + semantic_logit_bias``; the bias
       limits which vocabulary entries can be drawn.
    3. Repetition Aware Sampling (RAS): a second sample is always drawn with
       ``RAS_HIGH_TEMP`` / ``RAS_HIGH_TOP_P``. It replaces the normal sample when
       that sample is a semantic token already present in ``previous_tokens[0]``.
    4. The fast transformer (``model.forward_generate_fast``) is first called
       with the slow hidden state at fast position 0. Its result is discarded;
       presumably this primes the fast KV cache (confirm). It is then fed
       codebook 0 (main token minus ``semantic_begin_id``, clamped into the
       codebook range) and each newly sampled codebook, producing
       ``num_codebooks - 1`` more samples.

    Returns:
        ``codebooks.T`` with ``num_codebooks + 1`` rows: row 0 is the main
        vocabulary token and rows 1.. are codebook indices. With logits of
        shape (1, 1, vocab) each row has one column, i.e. (num_codebooks + 1, 1).
    """
    forward_result = model.forward_generate(
        x,
        input_pos,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
    )
    logits = forward_result.logits # (1, 1, vocab_size)
    hidden_states = forward_result.hidden_states
    # Constrained decoding: the bias is 0 for allowed tokens (semantic range + im_end, as built in generate) and -inf elsewhere
    biased_logits = logits + semantic_logit_bias
    # Normal sample
    main_token_normal = sample(
        biased_logits, temperature=temperature, top_p=top_p, top_k=top_k
    )[0]
    # RAS: always draw a high-temperature sample too, used as fallback if the token repeats
    high_temp = torch.tensor(
        RAS_HIGH_TEMP, device=temperature.device, dtype=temperature.dtype
    )
    high_top_p = torch.tensor(RAS_HIGH_TOP_P, device=top_p.device, dtype=top_p.dtype)
    main_token_high = sample(
        biased_logits, temperature=high_temp, top_p=high_top_p, top_k=top_k
    )[0]
    # Use the high-temp sample if: token is semantic AND token is in the previous window
    if previous_tokens is not None:
        # NOTE(review): decode_n_tokens passes a (num_codebooks + 1, window) buffer, so
        # previous_tokens[0] is the main-token history; generate/decode_n_tokens_optimized
        # pass (1, window, num_codebooks + 1), so previous_tokens[0] is the whole window
        # including codebook columns. Confirm both layouts are intended.
        in_window = (previous_tokens[0] == main_token_normal).any()
        # Use tensor ops (&, torch.where) instead of Python (and, if) — torch.compile requires no data-dependent branching
        is_semantic = (main_token_normal >= model.config.semantic_begin_id) & (
            main_token_normal <= model.config.semantic_end_id
        )
        should_use_high = in_window & is_semantic
        main_token_normal = torch.where(
            should_use_high, main_token_high, main_token_normal
        )
    codebooks = [main_token_normal]
    # Fast transformer position 0 takes the slow hidden state; its output is not used.
    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
    model.forward_generate_fast(hidden_states, input_pos)
    # Codebook 0 is the main token shifted into codebook index space.
    a = codebooks[0] - model.config.semantic_begin_id
    a = torch.clamp(a, min=0, max=model.config.codebook_size - 1)
    hidden_states = model.fast_embeddings(a)
    codebooks.append(a)
    for codebook_idx in range(1, model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        short_logits = logits # DualAR predicts config.codebook_size number of tokens
        # Sample the next codebook with the same settings; no logit bias for fast codebooks
        a = sample(
            short_logits,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)
    codebooks = torch.stack(codebooks, dim=1)
    # Only delete references, let Python GC handle cleanup
    del logits, hidden_states, forward_result
    return codebooks.T
  144. def decode_n_tokens(
  145. model: DualARTransformer,
  146. cur_token: torch.Tensor,
  147. input_pos: torch.Tensor,
  148. num_new_tokens: int,
  149. temperature: torch.Tensor,
  150. top_p: torch.Tensor,
  151. top_k: int,
  152. semantic_logit_bias: torch.Tensor,
  153. audio_masks: torch.Tensor,
  154. audio_parts: torch.Tensor,
  155. decode_one_token=decode_one_token_ar,
  156. ):
  157. start = time.perf_counter()
  158. # Rolling window for RAS (Repetition Aware Sampling)
  159. previous_tokens = torch.zeros(
  160. (model.config.num_codebooks + 1, RAS_WIN_SIZE),
  161. dtype=torch.int,
  162. device=cur_token.device,
  163. )
  164. step1 = time.perf_counter()
  165. # Accumulate all generated tokens (the actual output)
  166. new_tokens = []
  167. # [MODIFIED] Pre-fetch ID for efficiency loop
  168. im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
  169. step2 = time.perf_counter()
  170. for i in tqdm(range(num_new_tokens)):
  171. f_start = time.perf_counter()
  172. with sdpa_kernel(SDPBackend.MATH):
  173. next_token = decode_one_token(
  174. model=model,
  175. x=cur_token,
  176. input_pos=input_pos,
  177. previous_tokens=previous_tokens,
  178. temperature=temperature,
  179. top_p=top_p,
  180. top_k=top_k,
  181. semantic_logit_bias=semantic_logit_bias,
  182. audio_masks=audio_masks,
  183. audio_parts=audio_parts,
  184. ).clone()
  185. input_pos += 1
  186. cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
  187. # Roll RAS window left and insert new token at end
  188. previous_tokens = previous_tokens.roll(-1, dims=1)
  189. previous_tokens[:, -1] = next_token.view(model.config.num_codebooks + 1, -1)[
  190. :, 0
  191. ]
  192. new_tokens.append(next_token)
  193. f_end = time.perf_counter()
  194. logger.info(f"num_new_tokens for elapse: {f_end - f_start}")
  195. if cur_token[0, 0, -1] == im_end_id:
  196. break
  197. step3 = time.perf_counter()
  198. del cur_token
  199. logger.info(f"elapse step1: {step1 - start}, step2: {step2 - step1}, step3: {step3 - step2}")
  200. return torch.cat(new_tokens, dim=1)
  201. def decode_n_tokens_optimized(
  202. model: DualARTransformer,
  203. cur_token: torch.Tensor,
  204. input_pos: torch.Tensor,
  205. num_new_tokens: int,
  206. temperature: torch.Tensor,
  207. top_p: torch.Tensor,
  208. top_k: int,
  209. semantic_logit_bias: torch.Tensor,
  210. audio_masks: torch.Tensor,
  211. audio_parts: torch.Tensor,
  212. previous_tokens: torch.Tensor,
  213. im_end_id: Any,
  214. decode_one_token=decode_one_token_ar,
  215. ):
  216. """
  217. Optimized version:
  218. - no roll (ring buffer)
  219. - flash attention
  220. - reduced view/reshape
  221. """
  222. device = cur_token.device
  223. num_streams = model.config.num_codebooks + 1
  224. # =========================
  225. # 1. ring buffer index (替代 roll)
  226. # =========================
  227. history_len = previous_tokens.size(1)
  228. write_idx = history_len - 1
  229. new_tokens = []
  230. # =========================
  231. # 2. precompute reshape shape
  232. # =========================
  233. batch = 1
  234. # =========================
  235. # 3. main loop
  236. # =========================
  237. for i in range(num_new_tokens):
  238. # ⚡ use flash attention (重要优化)
  239. with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
  240. next_token = decode_one_token(
  241. model=model,
  242. x=cur_token,
  243. input_pos=input_pos,
  244. previous_tokens=previous_tokens,
  245. temperature=temperature,
  246. top_p=top_p,
  247. top_k=top_k,
  248. semantic_logit_bias=semantic_logit_bias,
  249. audio_masks=audio_masks,
  250. audio_parts=audio_parts,
  251. ).clone()
  252. # =========================
  253. # 4. update position
  254. # =========================
  255. input_pos += 1
  256. # =========================
  257. # 5. reshape once (reuse view logic)
  258. # =========================
  259. next_token_2d = next_token.view(num_streams, -1)
  260. cur_token = next_token_2d.unsqueeze(0)
  261. # =========================
  262. # 6. ring buffer update (NO roll)
  263. # =========================
  264. previous_tokens[:, write_idx] = next_token_2d[:, 0]
  265. write_idx = (write_idx + 1) % history_len
  266. # =========================
  267. # 7. store output
  268. # =========================
  269. new_tokens.append(next_token)
  270. # =========================
  271. # 8. EOS check
  272. # =========================
  273. if cur_token[0, 0, -1] == im_end_id:
  274. break
  275. return new_tokens
  276. @torch.no_grad()
  277. @torch.inference_mode()
  278. def generate(
  279. *,
  280. model: DualARTransformer,
  281. prompt: torch.Tensor,
  282. max_new_tokens: int,
  283. audio_masks: torch.Tensor,
  284. audio_parts: torch.Tensor,
  285. prompt_tokens = None,
  286. decode_one_token=decode_one_token_ar,
  287. num_samples: int = 1,
  288. **sampling_kwargs,
  289. ):
  290. """
  291. Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
  292. """
  293. # create an empty tensor of the expected final shape and fill in the current tokens
  294. start = time.perf_counter()
  295. T = prompt.size(1)
  296. prompt = prompt[None].repeat(num_samples, 1, 1)
  297. if T >= model.config.max_seq_len:
  298. raise ValueError(
  299. f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
  300. )
  301. if max_new_tokens:
  302. if T + max_new_tokens > model.config.max_seq_len:
  303. max_new_tokens = model.config.max_seq_len - T
  304. T_new = T + max_new_tokens
  305. else:
  306. T_new = model.config.max_seq_len
  307. max_new_tokens = T_new - T
  308. device = prompt.device
  309. dtype = next(
  310. model.parameters()
  311. ).dtype # model weight dtype (bfloat16), NOT prompt dtype (int32)
  312. step1 = time.perf_counter()
  313. # Critical fix: Only set up cache on first run or when necessary
  314. if not hasattr(model, "_cache_setup_done") or not model._cache_setup_done:
  315. with torch.device(device):
  316. model.setup_caches(
  317. max_batch_size=1, # Fixed to 1, avoid dynamic changes
  318. max_seq_len=model.config.max_seq_len,
  319. dtype=next(model.parameters()).dtype,
  320. )
  321. model._cache_setup_done = True
  322. step2 = time.perf_counter()
  323. codebook_dim = 1 + model.config.num_codebooks
  324. # Create new tensor each time, but try to reuse memory
  325. input_pos = torch.arange(0, T, device=device, dtype=torch.long)
  326. empty = torch.empty(
  327. (codebook_dim, model.config.max_seq_len), dtype=prompt.dtype, device=device
  328. )
  329. step3 = time.perf_counter()
  330. empty[:, :T] = prompt
  331. seq = empty
  332. temp_val = sampling_kwargs.get("temperature", 1.0)
  333. top_p_val = sampling_kwargs.get("top_p", 0.9)
  334. top_k_val = sampling_kwargs.get("top_k", 30)
  335. temperature = torch.tensor(temp_val, device=device, dtype=dtype)
  336. step4 = time.perf_counter()
  337. top_p = torch.tensor(top_p_val, device=device, dtype=dtype)
  338. step5 = time.perf_counter()
  339. # Build semantic logit bias: 0 for semantic tokens + im_end, -inf for all others
  340. vocab_size = model.config.vocab_size
  341. semantic_logit_bias = torch.full(
  342. (1, 1, vocab_size), float("-inf"), device=device, dtype=dtype
  343. )
  344. step6 = time.perf_counter()
  345. # [MODIFIED] Use config for semantic range
  346. semantic_logit_bias[
  347. 0, 0, model.config.semantic_begin_id: model.config.semantic_end_id + 1
  348. ] = 0.0
  349. # [MODIFIED] Use tokenizer.get_token_id (Wrapper method)
  350. semantic_logit_bias[0, 0, model.tokenizer.get_token_id(IM_END_TOKEN)] = 0.0
  351. step7 = time.perf_counter()
  352. prefill_decode = decode_one_token_ar
  353. first_token = prefill_decode(
  354. model,
  355. prompt.view(1, codebook_dim, -1),
  356. input_pos,
  357. temperature,
  358. top_p,
  359. top_k_val,
  360. semantic_logit_bias,
  361. audio_masks,
  362. audio_parts,
  363. )
  364. seq[:, T: T + 1] = first_token
  365. step8 = time.perf_counter()
  366. # Recreate input_pos
  367. input_pos = torch.tensor([T], device=device, dtype=torch.int)
  368. step9 = time.perf_counter()
  369. im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
  370. codebook_dim = 1 + model.config.num_codebooks
  371. window_size = 64
  372. previous_tokens = torch.zeros(
  373. (1, window_size, codebook_dim),
  374. device=device,
  375. dtype=first_token.dtype,
  376. )
  377. # =========================
  378. # 1. warm start prompt
  379. # =========================
  380. if prompt_tokens is not None:
  381. # 确保 shape = [B, T, C]
  382. if prompt_tokens.dim() == 2:
  383. prompt_tokens = prompt_tokens.unsqueeze(0)
  384. T = min(prompt_tokens.size(1), window_size)
  385. previous_tokens[:, -T:] = prompt_tokens[:, -T:]
  386. # =========================
  387. # 2. insert first token
  388. # =========================
  389. previous_tokens[:, -1, :] = first_token.view(codebook_dim)
  390. x = decode_n_tokens_optimized(
  391. model,
  392. first_token.view(1, codebook_dim, -1),
  393. input_pos,
  394. max_new_tokens - 1,
  395. temperature=temperature,
  396. top_p=top_p,
  397. top_k=top_k_val,
  398. semantic_logit_bias=semantic_logit_bias,
  399. audio_masks=audio_masks,
  400. audio_parts=audio_parts,
  401. im_end_id=im_end_id,
  402. previous_tokens=previous_tokens,
  403. decode_one_token=decode_one_token,
  404. )
  405. seq = seq[:, : T + 1 + x.size(1)]
  406. seq[:, T + 1:] = x
  407. step10 = time.perf_counter()
  408. # Clean up temporary variables
  409. del first_token, x, prompt, empty, input_pos
  410. step11 = time.perf_counter()
  411. logger.info(f"elapse "
  412. f"step1: {step1 - start}, step2: {step2 - step1}, step3: {step3 - step2} "
  413. f"step4: {step4 - step3}, step5: {step5 - step4} step6: {step6 - step5} "
  414. f"step7: {step7 - step6} step8: {step8 - step7} step9: {step9 - step8} "
  415. f"step10: {step10 - step9} step11: {step11 - step10} ")
  416. return seq
  417. def init_model(checkpoint_path, device, precision, compile=False):
  418. model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
  419. logger.info(f"precision: {precision.__class__.__name__}")
  420. model = model.to(device=device, dtype=precision)
  421. logger.info(f"Restored model from checkpoint")
  422. if isinstance(model, DualARTransformer):
  423. decode_one_token = decode_one_token_ar
  424. # prefill_n_tokens = decode_one_token_ar
  425. logger.info("Using DualARTransformer")
  426. else:
  427. raise ValueError("Unsupported model type")
  428. # Pre-create fixed parameter tensors to avoid runtime creation
  429. model.fixed_temperature = torch.tensor(0.7, device=device, dtype=torch.float)
  430. model.fixed_top_p = torch.tensor(0.7, device=device, dtype=torch.float)
  431. model.fixed_repetition_penalty = torch.tensor(1.5, device=device, dtype=torch.float)
  432. # Mark whether cache has been initialized
  433. model._cache_setup_done = False
  434. if compile:
  435. logger.info("Compiling function...")
  436. decode_one_token = torch.compile(
  437. decode_one_token,
  438. backend="inductor" if torch.cuda.is_available() else "aot_eager",
  439. mode="default" if torch.cuda.is_available() else None,
  440. fullgraph=True,
  441. )
  442. return model.eval(), decode_one_token
  443. @torch.inference_mode()
  444. def load_codec_model(codec_checkpoint_path, device, precision=torch.bfloat16):
  445. """Load the DAC codec model for audio encoding/decoding."""
  446. from hydra.utils import instantiate
  447. from omegaconf import OmegaConf
  448. config_path = Path(__file__).parent.parent.parent / "configs" / "modded_dac_vq.yaml"
  449. cfg = OmegaConf.load(str(config_path))
  450. codec = instantiate(cfg)
  451. state_dict = torch.load(codec_checkpoint_path, map_location="cpu")
  452. if "state_dict" in state_dict:
  453. state_dict = state_dict["state_dict"]
  454. if any("generator" in k for k in state_dict):
  455. state_dict = {
  456. k.replace("generator.", ""): v
  457. for k, v in state_dict.items()
  458. if "generator." in k
  459. }
  460. codec.load_state_dict(state_dict, strict=False)
  461. codec.eval()
  462. codec.to(device=device, dtype=precision)
  463. return codec
  464. @torch.inference_mode()
  465. def encode_audio(audio_path, codec, device):
  466. """Encode an audio file to VQ codes."""
  467. import torchaudio
  468. wav, sr = torchaudio.load(str(audio_path))
  469. if wav.shape[0] > 1:
  470. wav = wav.mean(dim=0, keepdim=True)
  471. wav = torchaudio.functional.resample(wav.to(device), sr, codec.sample_rate)[0]
  472. # Match codec model dtype (e.g. bfloat16)
  473. model_dtype = next(codec.parameters()).dtype
  474. audios = wav[None, None].to(dtype=model_dtype) # (1, 1, T)
  475. audio_lengths = torch.tensor([len(wav)], device=device, dtype=torch.long)
  476. indices, feature_lengths = codec.encode(audios, audio_lengths)
  477. return indices[0, :, : feature_lengths[0]] # (num_codebooks, T)
  478. @torch.inference_mode()
  479. def decode_to_audio(codes, codec):
  480. """Decode VQ codes to audio waveform."""
  481. # codes: (num_codebooks, T) -> (1, num_codebooks, T)
  482. audio = codec.from_indices(codes[None])
  483. return audio[0, 0] # (T,) mono waveform
  484. @dataclass
  485. class GenerateResponse:
  486. action: Literal["sample", "next"]
  487. codes: Optional[torch.Tensor] = None
  488. text: Optional[str] = None
  489. def split_text_by_speaker(text: str) -> list[str]:
  490. """
  491. Split text into turns based on <|speaker:X|> tags.
  492. Args:
  493. text: The full text with speaker tags
  494. Returns:
  495. List of speaker turns, each starting with <|speaker:X|>
  496. """
  497. pattern = r"(<\|speaker:\d+\|>)"
  498. parts = re.split(pattern, text)
  499. turns = []
  500. i = 0
  501. while i < len(parts):
  502. part = parts[i].strip()
  503. if re.match(pattern, part):
  504. if i + 1 < len(parts):
  505. turn = part + parts[i + 1]
  506. turns.append(turn.strip())
  507. i += 2
  508. else:
  509. turns.append(part)
  510. i += 1
  511. else:
  512. i += 1
  513. return turns
  514. def group_turns_into_batches(
  515. turns: list[str], max_speakers: int = 3, max_bytes: int = 300
  516. ) -> list[str]:
  517. """
  518. Group turns into batches based on speaker count or byte limit.
  519. Args:
  520. turns: List of speaker turns
  521. max_speakers: Maximum number of speakers per batch (default 3)
  522. max_bytes: Maximum UTF-8 bytes per batch (default 300)
  523. Returns:
  524. List of batched text strings
  525. """
  526. batches = []
  527. current_batch = []
  528. current_bytes = 0
  529. for turn in turns:
  530. turn_bytes = len(turn.encode("utf-8"))
  531. would_exceed_speakers = len(current_batch) >= max_speakers
  532. would_exceed_bytes = current_bytes + turn_bytes > max_bytes and current_batch
  533. if would_exceed_speakers or would_exceed_bytes:
  534. batches.append("\n".join(current_batch))
  535. current_batch = [turn]
  536. current_bytes = turn_bytes
  537. else:
  538. current_batch.append(turn)
  539. current_bytes += turn_bytes
  540. if current_batch:
  541. batches.append("\n".join(current_batch))
  542. return batches
  543. def generate_long(
  544. *,
  545. model,
  546. device: Union[str, torch.device],
  547. decode_one_token: Callable,
  548. text: str,
  549. num_samples: int = 1,
  550. max_new_tokens: int = 0,
  551. top_p: float = 0.9,
  552. top_k: int = 30,
  553. repetition_penalty: float = 1.1,
  554. temperature: float = 1.0,
  555. compile: bool = False,
  556. iterative_prompt: bool = True,
  557. chunk_length: int = 512,
  558. prompt_text: Optional[Union[str, list[str]]] = None,
  559. prompt_tokens: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
  560. ):
  561. assert 0 < top_p <= 1, "top_p must be in (0, 1]"
  562. assert 0 < temperature < 2, "temperature must be in (0, 2)"
  563. logger.info(f"generate_long.param.device: {device}")
  564. logger.info(f"generate_long.param.text: {text}")
  565. logger.info(f"generate_long.param.max_new_tokens: {max_new_tokens}")
  566. logger.info(f"generate_long.param.top_p: {top_p}")
  567. logger.info(f"generate_long.param.top_k: {top_k}")
  568. logger.info(f"generate_long.param.temperature: {temperature}")
  569. logger.info(f"generate_long.param.compile: {compile}")
  570. logger.info(f"generate_long.param.chunk_length: {chunk_length}")
  571. logger.info(f"generate_long.param.prompt_text: {prompt_text}")
  572. logger.info(f"generate_long.param.prompt_tokens: {prompt_tokens}")
  573. use_prompt = bool(prompt_text) and bool(prompt_tokens)
  574. if use_prompt and isinstance(prompt_text, str):
  575. prompt_text = [prompt_text]
  576. prompt_tokens = [prompt_tokens]
  577. if use_prompt:
  578. assert len(prompt_text) == len(
  579. prompt_tokens
  580. ), "Prompt text and tokens must have the same length"
  581. if prompt_tokens:
  582. prompt_tokens = [i.cpu() for i in prompt_tokens]
  583. model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
  584. tokenizer = model.tokenizer
  585. max_length = model.config.max_seq_len
  586. # Build base conversation with system message
  587. base_conversation = Conversation()
  588. all_codes = None
  589. if use_prompt:
  590. # Auto-add speaker tags to prompt texts that don't have them
  591. tagged_prompt_text = []
  592. for i, t in enumerate(prompt_text):
  593. if not re.search(r"<\|speaker:\d+\|>", t):
  594. tagged_prompt_text.append(f"<|speaker:{i}|>{t}")
  595. else:
  596. tagged_prompt_text.append(t)
  597. system_parts = [
  598. TextPart(
  599. text="convert the provided text to speech reference to the following:\n\nText:\n",
  600. cal_loss=False,
  601. ),
  602. ]
  603. reference_text = "\n".join(tagged_prompt_text)
  604. system_parts.append(TextPart(text=reference_text, cal_loss=False))
  605. system_parts.append(TextPart(text="\n\nSpeech:\n", cal_loss=False))
  606. all_codes = torch.cat([c for c in prompt_tokens], dim=1)
  607. system_parts.append(VQPart(codes=all_codes, cal_loss=False))
  608. # torch.save(all_codes, "debug_vq_codes.pt")
  609. else:
  610. system_parts = [
  611. TextPart(text="convert the provided text to speech", cal_loss=False)
  612. ]
  613. base_conversation.append(
  614. Message(
  615. role="system",
  616. parts=system_parts,
  617. cal_loss=False,
  618. add_im_start=True,
  619. add_im_end=True,
  620. )
  621. )
  622. # Split text by speaker and group into batches
  623. turns = split_text_by_speaker(text)
  624. if turns:
  625. batches = group_turns_into_batches(
  626. turns, max_speakers=5, max_bytes=chunk_length
  627. )
  628. else:
  629. batches = [text]
  630. logger.info(f"Split into {len(turns)} turns, grouped into {len(batches)} batches")
  631. for sample_idx in range(num_samples):
  632. if torch.cuda.is_available():
  633. torch.cuda.synchronize()
  634. t0 = time.perf_counter()
  635. # Deep copy base conversation for this sample
  636. conversation = deepcopy(base_conversation)
  637. for batch_idx, batch_text in enumerate(batches):
  638. logger.info(
  639. f"--- Sample {sample_idx}, Batch {batch_idx} "
  640. f"({len(batch_text.encode('utf-8'))} bytes) ---"
  641. )
  642. logger.info(f"Batch text: {batch_text}")
  643. # Add user message
  644. conversation.append(
  645. Message(
  646. role="user",
  647. parts=[TextPart(text=batch_text, cal_loss=False)],
  648. cal_loss=False,
  649. add_im_start=True,
  650. add_im_end=True,
  651. )
  652. )
  653. # Deep copy for generation (don't pollute original conversation)
  654. conversation_gen = deepcopy(conversation)
  655. conversation_gen.append(
  656. Message(
  657. role="assistant",
  658. parts=[],
  659. cal_loss=False,
  660. modality="voice",
  661. add_im_start=True,
  662. add_im_end=False,
  663. )
  664. )
  665. logger.info("Visualizing prompt structure:")
  666. conversation_gen.visualize(
  667. tokenizer,
  668. merge_audio_tokens=True,
  669. merge_semantic_tokens=True,
  670. )
  671. encoded, audio_masks, audio_parts = conversation_gen.encode_for_inference(
  672. tokenizer, num_codebooks=model.config.num_codebooks
  673. )
  674. logger.info(f"Encoded prompt shape: {encoded.shape}")
  675. if audio_parts is not None:
  676. logger.info(f"Audio parts shape: {audio_parts.shape}")
  677. if audio_masks is not None:
  678. logger.info(
  679. f"Audio masks non-zero count: {torch.count_nonzero(audio_masks)}"
  680. )
  681. if encoded.size(1) > max_length - 2048:
  682. raise ValueError(
  683. f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}"
  684. )
  685. encoded = encoded.to(device=device)
  686. prompt_length = encoded.size(1)
  687. y = generate(
  688. model=model,
  689. prompt=encoded,
  690. max_new_tokens=max_new_tokens,
  691. audio_masks=audio_masks,
  692. audio_parts=audio_parts,
  693. decode_one_token=decode_one_token,
  694. temperature=temperature,
  695. top_p=top_p,
  696. top_k=top_k,
  697. prompt_tokens=all_codes,
  698. )
  699. if sample_idx == 0 and batch_idx == 0 and compile:
  700. logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
  701. if torch.cuda.is_available():
  702. torch.cuda.synchronize()
  703. t_batch = time.perf_counter() - t0
  704. tokens_generated = y.size(1) - prompt_length
  705. tokens_sec = tokens_generated / t_batch if t_batch > 0 else 0
  706. logger.info(
  707. f"Batch {batch_idx}: Generated {tokens_generated} tokens in "
  708. f"{t_batch:.02f} seconds, {tokens_sec:.02f} tokens/sec"
  709. )
  710. logger.info(
  711. f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
  712. )
  713. # Extract generated codes
  714. codes = y[1:, prompt_length:-1].clone()
  715. assert (codes >= 0).all(), f"Negative code found: {codes}"
  716. # Add assistant message with generated codes back to conversation
  717. conversation.append(
  718. Message(
  719. role="assistant",
  720. parts=[VQPart(codes=codes.cpu(), cal_loss=False)],
  721. cal_loss=False,
  722. modality="voice",
  723. add_im_start=True,
  724. add_im_end=True,
  725. )
  726. )
  727. yield GenerateResponse(action="sample", codes=codes, text=batch_text)
  728. MAX_HISTORY_TURNS = 2 # 只保留最近 2 轮 user/assistant
  729. assistant_indices = [i for i, m in enumerate(conversation.messages) if m.role == "assistant"]
  730. if len(assistant_indices) > MAX_HISTORY_TURNS:
  731. drop = assistant_indices[0]
  732. # 移除最早的 user+assistant 对,保留 system 消息
  733. conversation = Conversation([m for i, m in enumerate(conversation.messages)
  734. if i not in (drop - 1, drop)])
  735. # Cleanup
  736. del y, encoded
  737. if torch.cuda.is_available():
  738. torch.cuda.empty_cache()
  739. import gc
  740. gc.collect()
  741. if torch.cuda.is_available():
  742. logger.info(
  743. f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
  744. )
  745. yield GenerateResponse(action="next")
  746. @dataclass
  747. class WrappedGenerateResponse:
  748. status: Literal["success", "error"]
  749. response: Optional[Union[GenerateResponse, Exception]] = None
  750. @dataclass
  751. class GenerateRequest:
  752. request: dict
  753. response_queue: queue.Queue
  754. def launch_thread_safe_queue(
  755. checkpoint_path,
  756. device,
  757. precision,
  758. compile: bool = False,
  759. ):
  760. input_queue = queue.Queue()
  761. init_event = threading.Event()
  762. def worker():
  763. model, decode_one_token = init_model(
  764. checkpoint_path, device, precision, compile=compile
  765. )
  766. with torch.device(device):
  767. model.setup_caches(
  768. max_batch_size=1,
  769. max_seq_len=model.config.max_seq_len,
  770. dtype=next(model.parameters()).dtype,
  771. )
  772. init_event.set()
  773. while True:
  774. item: GenerateRequest | None = input_queue.get()
  775. if item is None:
  776. break
  777. kwargs = item.request
  778. response_queue = item.response_queue
  779. try:
  780. for chunk in generate_long(
  781. model=model, decode_one_token=decode_one_token, **kwargs
  782. ):
  783. response_queue.put(
  784. WrappedGenerateResponse(status="success", response=chunk)
  785. )
  786. # Only clear cache after complete request batch
  787. if torch.cuda.is_available():
  788. torch.cuda.empty_cache()
  789. except Exception as e:
  790. logger.error(traceback.format_exc())
  791. response_queue.put(WrappedGenerateResponse(status="error", response=e))
  792. # Clear cache on error
  793. if torch.cuda.is_available():
  794. torch.cuda.empty_cache()
  795. threading.Thread(target=worker, daemon=True).start()
  796. init_event.wait()
  797. return input_queue
# ============================================
# =============== Original code ==============
# ============================================
  801. @click.command()
  802. @click.option(
  803. "--text",
  804. type=str,
  805. default="<|speaker:0|>你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
  806. )
  807. @click.option("--prompt-text", type=str, default=None, multiple=True)
  808. @click.option(
  809. "--prompt-tokens",
  810. type=click.Path(path_type=Path, exists=True),
  811. default=None,
  812. multiple=True,
  813. )
  814. @click.option(
  815. "--prompt-audio",
  816. type=click.Path(path_type=Path, exists=True),
  817. default=None,
  818. multiple=True,
  819. )
  820. @click.option("--output", type=click.Path(path_type=Path), default=None)
  821. @click.option("--num-samples", type=int, default=1)
  822. @click.option("--max-new-tokens", type=int, default=0)
  823. @click.option("--top-p", type=float, default=0.9)
  824. @click.option("--top-k", type=int, default=30)
  825. @click.option("--temperature", type=float, default=1.0)
  826. @click.option(
  827. "--checkpoint-path",
  828. type=click.Path(path_type=Path, exists=True),
  829. default="checkpoints/s2-pro",
  830. )
  831. @click.option("--device", type=str, default="cuda")
  832. @click.option("--compile/--no-compile", default=False)
  833. @click.option("--seed", type=int, default=42)
  834. @click.option("--half/--no-half", default=False)
  835. @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
  836. @click.option("--chunk-length", type=int, default=300)
  837. @click.option("--output-dir", type=Path, default="output")
  838. def main(
  839. text: str,
  840. prompt_text: Optional[tuple[str, ...]],
  841. prompt_tokens: Optional[tuple[Path, ...]],
  842. prompt_audio: Optional[tuple[Path, ...]],
  843. output: Optional[Path],
  844. num_samples: int,
  845. max_new_tokens: int,
  846. top_p: float,
  847. top_k: int,
  848. temperature: float,
  849. checkpoint_path: Path,
  850. device: str,
  851. compile: bool,
  852. seed: int,
  853. half: bool,
  854. iterative_prompt: bool,
  855. chunk_length: int,
  856. output_dir: Path,
  857. ) -> None:
  858. os.makedirs(output_dir, exist_ok=True)
  859. precision = torch.half if half else torch.bfloat16
  860. if prompt_text and not prompt_audio and not prompt_tokens:
  861. raise ValueError(
  862. "--prompt-text requires either --prompt-audio or --prompt-tokens"
  863. )
  864. if prompt_text and prompt_tokens and len(prompt_text) != len(prompt_tokens):
  865. raise ValueError(
  866. f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
  867. )
  868. if prompt_text and prompt_audio and len(prompt_text) != len(prompt_audio):
  869. raise ValueError(
  870. f"Number of prompt text ({len(prompt_text)}) and prompt audio ({len(prompt_audio)}) should be the same"
  871. )
  872. logger.info("Loading model ...")
  873. t0 = time.time()
  874. model, decode_one_token = init_model(
  875. checkpoint_path, device, precision, compile=compile
  876. )
  877. with torch.device(device):
  878. model.setup_caches(
  879. max_batch_size=1,
  880. max_seq_len=model.config.max_seq_len,
  881. dtype=next(model.parameters()).dtype,
  882. )
  883. if torch.cuda.is_available():
  884. torch.cuda.synchronize()
  885. logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
  886. codec = None
  887. codec_checkpoint = checkpoint_path / "codec.pth"
  888. # Handle prompt: --prompt-audio takes priority over --prompt-tokens
  889. prompt_tokens_list = None
  890. if prompt_audio:
  891. logger.info("Loading codec model for audio encoding...")
  892. codec = load_codec_model(codec_checkpoint, device, precision)
  893. prompt_tokens_list = [
  894. encode_audio(p, codec, device).cpu() for p in prompt_audio
  895. ]
  896. logger.info(f"Encoded {len(prompt_audio)} audio file(s) to VQ codes")
  897. elif prompt_tokens is not None:
  898. prompt_tokens_list = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
  899. torch.manual_seed(seed)
  900. if torch.cuda.is_available():
  901. torch.cuda.manual_seed(seed)
  902. generator = generate_long(
  903. model=model,
  904. device=device,
  905. decode_one_token=decode_one_token,
  906. text=text,
  907. num_samples=num_samples,
  908. max_new_tokens=max_new_tokens,
  909. top_p=top_p,
  910. top_k=top_k,
  911. temperature=temperature,
  912. compile=compile,
  913. iterative_prompt=iterative_prompt,
  914. chunk_length=chunk_length,
  915. prompt_text=list(prompt_text) if prompt_text else None,
  916. prompt_tokens=prompt_tokens_list,
  917. )
  918. idx = 0
  919. codes = []
  920. for response in generator:
  921. if response.action == "sample":
  922. codes.append(response.codes)
  923. logger.info(f"Sampled text: {response.text}")
  924. elif response.action == "next":
  925. if codes:
  926. merged_codes = torch.cat(codes, dim=1)
  927. codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
  928. np.save(codes_npy_path, merged_codes.cpu().numpy())
  929. logger.info(f"Saved codes to {codes_npy_path}")
  930. # Decode to wav if --output is specified
  931. if output:
  932. if codec is None:
  933. logger.info("Loading codec model for audio decoding...")
  934. codec = load_codec_model(codec_checkpoint, device, precision)
  935. audio = decode_to_audio(merged_codes.to(device), codec)
  936. import soundfile as sf
  937. out_path = (
  938. str(output)
  939. if num_samples == 1
  940. else str(output.with_stem(f"{output.stem}_{idx}"))
  941. )
  942. sf.write(out_path, audio.cpu().float().numpy(), codec.sample_rate)
  943. logger.info(f"Saved audio to {out_path}")
  944. logger.info(f"Next sample")
  945. codes = []
  946. idx += 1
  947. else:
  948. logger.error(f"Error: {response}")
  949. if __name__ == "__main__":
  950. main()