inference.py

import os
import queue
import re
import threading
import time
import traceback
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Literal, Optional, Tuple, Union

import click
import numpy as np
import torch
import torch._inductor.config
from loguru import logger
from tqdm import tqdm

from fish_speech.content_sequence import (
    TextPart,
    VQPart,
)
from fish_speech.conversation import Conversation, Message
from fish_speech.tokenizer import IM_END_TOKEN

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    torch._inductor.config.fx_graph_cache = True

from torch.nn.attention import SDPBackend, sdpa_kernel

from fish_speech.models.text2semantic.llama import (
    DualARTransformer,
)


def multinomial_sample_one_no_sync(probs_sort):
    q = torch.rand_like(probs_sort)
    q = -torch.log(q)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
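

# Note (informal): dividing the probabilities by i.i.d. Exponential(1) noise and taking
# the argmax draws an index with probability proportional to probs_sort, i.e. it matches
# torch.multinomial(probs_sort, 1) in distribution while avoiding the device
# synchronization the function name refers to. Illustrative check:
#   probs = torch.tensor([0.1, 0.2, 0.7])
#   multinomial_sample_one_no_sync(probs)  # yields index 2 roughly 70% of the time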


RAS_WIN_SIZE = 10  # window for Repetition Aware Sampling
RAS_HIGH_TEMP = 1.0
RAS_HIGH_TOP_P = 0.9


def logits_to_probs(
    logits,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,  # note: this is passed in as a plain int, which is critical
) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(
        torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
    )
    indices = torch.arange(sorted_logits.shape[-1], device=sorted_logits.device)
    top_k_mask = indices >= top_k
    sorted_indices_to_remove = (cum_probs > top_p) | top_k_mask
    # Single-element in-place write is fine here; could also be written as | (indices != 0)
    sorted_indices_to_remove[0] = False
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    logits = torch.where(
        indices_to_remove, float("-Inf"), logits
    )  # likewise, torch.where replaces masked_fill_
    logits = logits / torch.clip(temperature, min=1e-5)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
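

# Illustration (hypothetical numbers): with sorted probabilities [0.4, 0.3, 0.2, 0.1]
# and top_p=0.75, the cumulative sums are [0.4, 0.7, 0.9, 1.0], so only the first two
# tokens survive nucleus filtering (the top token is always kept regardless of top_p).
# top_k additionally caps the number of surviving candidates, and the remaining logits
# are re-normalized by a softmax after temperature scaling.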


def sample(
    logits,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    probs = logits_to_probs(
        logits=logits[0, -1],
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
    semantic_logit_bias: torch.Tensor,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    forward_result = model.forward_generate(
        x,
        input_pos,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
    )
    logits = forward_result.logits  # (1, 1, vocab_size)
    hidden_states = forward_result.hidden_states

    # Apply constrained decoding: only allow semantic tokens + im_end
    biased_logits = logits + semantic_logit_bias

    # Normal sample
    main_token_normal = sample(
        biased_logits, temperature=temperature, top_p=top_p, top_k=top_k
    )[0]

    # RAS: also sample with high temp to use as fallback if the token repeats
    high_temp = torch.tensor(
        RAS_HIGH_TEMP, device=temperature.device, dtype=temperature.dtype
    )
    high_top_p = torch.tensor(RAS_HIGH_TOP_P, device=top_p.device, dtype=top_p.dtype)
    main_token_high = sample(
        biased_logits, temperature=high_temp, top_p=high_top_p, top_k=top_k
    )[0]

    # Use the high-temp sample if the token is semantic AND appears in the previous window
    if previous_tokens is not None:
        in_window = (previous_tokens[0] == main_token_normal).any()
        # Use tensor ops (&, torch.where) instead of Python (and, if):
        # torch.compile requires no data-dependent branching
        is_semantic = (main_token_normal >= model.config.semantic_begin_id) & (
            main_token_normal <= model.config.semantic_end_id
        )
        should_use_high = in_window & is_semantic
        main_token_normal = torch.where(
            should_use_high, main_token_high, main_token_normal
        )

    codebooks = [main_token_normal]

    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
    model.forward_generate_fast(hidden_states, input_pos)
    a = codebooks[0] - model.config.semantic_begin_id
    a = torch.clamp(a, min=0, max=model.config.codebook_size - 1)
    hidden_states = model.fast_embeddings(a)
    codebooks.append(a)

    for codebook_idx in range(1, model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        short_logits = logits  # DualAR predicts config.codebook_size tokens
        # Convert logits to probs (no constraint for fast codebooks)
        a = sample(
            short_logits,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=1)
    # Only delete references, let Python GC handle cleanup
    del logits, hidden_states, forward_result
    return codebooks.T


def decode_n_tokens(
    model: DualARTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
    semantic_logit_bias: torch.Tensor,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    decode_one_token=decode_one_token_ar,
):
    # Rolling window for RAS (Repetition Aware Sampling)
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, RAS_WIN_SIZE),
        dtype=torch.int,
        device=cur_token.device,
    )
    # Accumulate all generated tokens (the actual output)
    new_tokens = []
    # [MODIFIED] Pre-fetch the im_end token id once, outside the decode loop
    im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)

    for i in tqdm(range(num_new_tokens)):
        with sdpa_kernel(SDPBackend.MATH):
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=previous_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                semantic_logit_bias=semantic_logit_bias,
                audio_masks=audio_masks,
                audio_parts=audio_parts,
            ).clone()

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)

        # Roll RAS window left and insert the new token at the end
        previous_tokens = previous_tokens.roll(-1, dims=1)
        previous_tokens[:, -1] = next_token.view(model.config.num_codebooks + 1, -1)[
            :, 0
        ]
        new_tokens.append(next_token)

        if cur_token[0, 0, -1] == im_end_id:
            break

    del cur_token
    return torch.cat(new_tokens, dim=1)


@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: DualARTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    decode_one_token=decode_one_token_ar,
    num_samples: int = 1,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate
    as many tokens as requested.
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    prompt = prompt[None].repeat(num_samples, 1, 1)

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device = prompt.device
    dtype = next(
        model.parameters()
    ).dtype  # model weight dtype (bfloat16), NOT prompt dtype (int32)

    # Critical fix: Only set up cache on first run or when necessary
    if not hasattr(model, "_cache_setup_done") or not model._cache_setup_done:
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,  # Fixed to 1, avoid dynamic changes
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        model._cache_setup_done = True

    codebook_dim = 1 + model.config.num_codebooks
    # Create new tensor each time, but try to reuse memory
    input_pos = torch.arange(0, T, device=device, dtype=torch.long)
    empty = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=prompt.dtype, device=device
    )
    empty[:, :T] = prompt
    seq = empty

    temp_val = sampling_kwargs.get("temperature", 1.0)
    top_p_val = sampling_kwargs.get("top_p", 0.9)
    top_k_val = sampling_kwargs.get("top_k", 30)
    temperature = torch.tensor(temp_val, device=device, dtype=dtype)
    top_p = torch.tensor(top_p_val, device=device, dtype=dtype)

    # Build semantic logit bias: 0 for semantic tokens + im_end, -inf for all others
    vocab_size = model.config.vocab_size
    semantic_logit_bias = torch.full(
        (1, 1, vocab_size), float("-inf"), device=device, dtype=dtype
    )
    # [MODIFIED] Use config for semantic range
    semantic_logit_bias[
        0, 0, model.config.semantic_begin_id : model.config.semantic_end_id + 1
    ] = 0.0
    # [MODIFIED] Use tokenizer.get_token_id (wrapper method)
    semantic_logit_bias[0, 0, model.tokenizer.get_token_id(IM_END_TOKEN)] = 0.0

    prefill_decode = decode_one_token_ar
    first_token = prefill_decode(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        temperature,
        top_p,
        top_k_val,
        semantic_logit_bias,
        audio_masks,
        audio_parts,
    )
    seq[:, T : T + 1] = first_token

    # Recreate input_pos
    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        first_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k_val,
        semantic_logit_bias=semantic_logit_bias,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
        decode_one_token=decode_one_token,
    )
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    # Clean up temporary variables
    del first_token, x, prompt, empty, input_pos
    return seq


def init_model(checkpoint_path, device, precision, compile=False, quantize=False):
    model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
    logger.info(f"precision: {precision}")
    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    # Apply INT8 quantization if requested
    if quantize:
        try:
            import bitsandbytes as bnb

            logger.info("Applying INT8 quantization with bitsandbytes...")

            # Replace all Linear layers with 8-bit quantized versions
            def replace_linear_with_int8(module):
                for name, child in module.named_children():
                    if isinstance(child, torch.nn.Linear):
                        # Create 8-bit linear layer
                        int8_layer = bnb.nn.Linear8bitLt(
                            child.in_features,
                            child.out_features,
                            bias=child.bias is not None,
                            has_fp16_weights=False,
                            threshold=6.0,
                        )
                        # Copy weights
                        int8_layer.weight = bnb.nn.Int8Params(
                            child.weight.data,
                            requires_grad=False,
                            has_fp16_weights=False,
                        )
                        if child.bias is not None:
                            int8_layer.bias = child.bias
                        setattr(module, name, int8_layer)
                    else:
                        replace_linear_with_int8(child)

            replace_linear_with_int8(model)
            logger.info("INT8 quantization applied successfully")
        except ImportError:
            logger.error(
                "bitsandbytes not installed. Install with: pip install bitsandbytes"
            )
            raise

    if isinstance(model, DualARTransformer):
        decode_one_token = decode_one_token_ar
        logger.info("Using DualARTransformer")
    else:
        raise ValueError("Unsupported model type")

    # Pre-create fixed parameter tensors to avoid runtime creation
    model.fixed_temperature = torch.tensor(0.7, device=device, dtype=torch.float)
    model.fixed_top_p = torch.tensor(0.7, device=device, dtype=torch.float)
    model.fixed_repetition_penalty = torch.tensor(1.5, device=device, dtype=torch.float)

    # Mark whether the cache has been initialized
    model._cache_setup_done = False

    # Disable compile if quantization is enabled
    # (bitsandbytes INT8 is incompatible with torch.compile)
    if compile and not quantize:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token,
            backend="inductor" if torch.cuda.is_available() else "aot_eager",
            mode="default" if torch.cuda.is_available() else None,
            fullgraph=True,
        )
    elif compile and quantize:
        logger.warning(
            "torch.compile disabled when quantization is enabled "
            "(bitsandbytes compatibility)"
        )

    return model.eval(), decode_one_token


@torch.inference_mode()
def load_codec_model(codec_checkpoint_path, device, precision=torch.bfloat16):
    """Load the DAC codec model for audio encoding/decoding."""
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    config_path = Path(__file__).parent.parent.parent / "configs" / "modded_dac_vq.yaml"
    cfg = OmegaConf.load(str(config_path))
    codec = instantiate(cfg)

    state_dict = torch.load(codec_checkpoint_path, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    if any("generator" in k for k in state_dict):
        state_dict = {
            k.replace("generator.", ""): v
            for k, v in state_dict.items()
            if "generator." in k
        }

    codec.load_state_dict(state_dict, strict=False)
    codec.eval()
    codec.to(device=device, dtype=precision)
    return codec


@torch.inference_mode()
def encode_audio(audio_path, codec, device):
    """Encode an audio file to VQ codes."""
    import torchaudio

    wav, sr = torchaudio.load(str(audio_path))
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    wav = torchaudio.functional.resample(wav.to(device), sr, codec.sample_rate)[0]

    # Match codec model dtype (e.g. bfloat16)
    model_dtype = next(codec.parameters()).dtype
    audios = wav[None, None].to(dtype=model_dtype)  # (1, 1, T)
    audio_lengths = torch.tensor([len(wav)], device=device, dtype=torch.long)
    indices, feature_lengths = codec.encode(audios, audio_lengths)
    return indices[0, :, : feature_lengths[0]]  # (num_codebooks, T)


@torch.inference_mode()
def decode_to_audio(codes, codec):
    """Decode VQ codes to audio waveform."""
    # codes: (num_codebooks, T) -> (1, num_codebooks, T)
    audio = codec.from_indices(codes[None])
    return audio[0, 0]  # (T,) mono waveform


@dataclass
class GenerateResponse:
    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None
    text: Optional[str] = None


def split_text_by_speaker(text: str) -> list[str]:
    """
    Split text into turns based on <|speaker:X|> tags.

    Args:
        text: The full text with speaker tags

    Returns:
        List of speaker turns, each starting with <|speaker:X|>
    """
    pattern = r"(<\|speaker:\d+\|>)"
    parts = re.split(pattern, text)

    turns = []
    i = 0
    while i < len(parts):
        part = parts[i].strip()
        if re.match(pattern, part):
            if i + 1 < len(parts):
                turn = part + parts[i + 1]
                turns.append(turn.strip())
                i += 2
            else:
                turns.append(part)
                i += 1
        else:
            i += 1
    return turns
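

# Quick illustration (hypothetical input) of how the splitter behaves:
#   split_text_by_speaker("<|speaker:0|>Hello.<|speaker:1|>Hi there.")
#   -> ['<|speaker:0|>Hello.', '<|speaker:1|>Hi there.']
# Text appearing before the first speaker tag has no tag of its own and is dropped.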


def group_turns_into_batches(
    turns: list[str], max_speakers: int = 3, max_bytes: int = 300
) -> list[str]:
    """
    Group turns into batches based on speaker count or byte limit.

    Args:
        turns: List of speaker turns
        max_speakers: Maximum number of speakers per batch (default 3)
        max_bytes: Maximum UTF-8 bytes per batch (default 300)

    Returns:
        List of batched text strings
    """
    batches = []
    current_batch = []
    current_bytes = 0

    for turn in turns:
        turn_bytes = len(turn.encode("utf-8"))
        would_exceed_speakers = len(current_batch) >= max_speakers
        would_exceed_bytes = current_bytes + turn_bytes > max_bytes and current_batch

        if would_exceed_speakers or would_exceed_bytes:
            batches.append("\n".join(current_batch))
            current_batch = [turn]
            current_bytes = turn_bytes
        else:
            current_batch.append(turn)
            current_bytes += turn_bytes

    if current_batch:
        batches.append("\n".join(current_batch))
    return batches
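

# Worked example (illustrative numbers): with the defaults max_speakers=3 and
# max_bytes=300, four turns of roughly 80 UTF-8 bytes each become two batches: the
# first three turns share a batch (about 240 bytes, 3 speakers), and the fourth starts
# a new one because the speaker cap is reached. A single turn larger than max_bytes is
# never split; it simply occupies its own batch.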


def generate_long(
    *,
    model,
    device: Union[str, torch.device],
    decode_one_token: Callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.9,
    top_k: int = 30,
    repetition_penalty: float = 1.1,
    temperature: float = 1.0,
    compile: bool = False,
    iterative_prompt: bool = True,
    chunk_length: int = 512,
    prompt_text: Optional[Union[str, list[str]]] = None,
    prompt_tokens: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
):
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    logger.info(f"generate_long.param.device: {device}")
    logger.info(f"generate_long.param.text: {text}")
    logger.info(f"generate_long.param.max_new_tokens: {max_new_tokens}")
    logger.info(f"generate_long.param.top_p: {top_p}")
    logger.info(f"generate_long.param.top_k: {top_k}")
    logger.info(f"generate_long.param.temperature: {temperature}")
    logger.info(f"generate_long.param.compile: {compile}")
    logger.info(f"generate_long.param.chunk_length: {chunk_length}")
    logger.info(f"generate_long.param.prompt_text: {prompt_text}")
    logger.info(f"generate_long.param.prompt_tokens: {prompt_tokens}")

    use_prompt = bool(prompt_text) and bool(prompt_tokens)
    if use_prompt and isinstance(prompt_text, str):
        prompt_text = [prompt_text]
        prompt_tokens = [prompt_tokens]
    if use_prompt:
        assert len(prompt_text) == len(
            prompt_tokens
        ), "Prompt text and tokens must have the same length"
    if prompt_tokens:
        prompt_tokens = [i.cpu() for i in prompt_tokens]

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tokenizer = model.tokenizer
    max_length = model.config.max_seq_len

    # Build base conversation with system message
    base_conversation = Conversation()
    if use_prompt:
        # Auto-add speaker tags to prompt texts that don't have them
        tagged_prompt_text = []
        for i, t in enumerate(prompt_text):
            if not re.search(r"<\|speaker:\d+\|>", t):
                tagged_prompt_text.append(f"<|speaker:{i}|>{t}")
            else:
                tagged_prompt_text.append(t)

        system_parts = [
            TextPart(
                text="convert the provided text to speech reference to the following:\n\nText:\n",
                cal_loss=False,
            ),
        ]
        reference_text = "\n".join(tagged_prompt_text)
        system_parts.append(TextPart(text=reference_text, cal_loss=False))
        system_parts.append(TextPart(text="\n\nSpeech:\n", cal_loss=False))
        all_codes = torch.cat([c for c in prompt_tokens], dim=1)
        system_parts.append(VQPart(codes=all_codes, cal_loss=False))
        # torch.save(all_codes, "debug_vq_codes.pt")
    else:
        system_parts = [
            TextPart(text="convert the provided text to speech", cal_loss=False)
        ]

    base_conversation.append(
        Message(
            role="system",
            parts=system_parts,
            cal_loss=False,
            add_im_start=True,
            add_im_end=True,
        )
    )

    # Split text by speaker and group into batches
    turns = split_text_by_speaker(text)
    if turns:
        batches = group_turns_into_batches(
            turns, max_speakers=5, max_bytes=chunk_length
        )
    else:
        batches = [text]
    logger.info(f"Split into {len(turns)} turns, grouped into {len(batches)} batches")

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t0 = time.perf_counter()

        # Deep copy base conversation for this sample
        conversation = deepcopy(base_conversation)

        for batch_idx, batch_text in enumerate(batches):
            logger.info(
                f"--- Sample {sample_idx}, Batch {batch_idx} "
                f"({len(batch_text.encode('utf-8'))} bytes) ---"
            )
            logger.info(f"Batch text: {batch_text}")

            # Add user message
            conversation.append(
                Message(
                    role="user",
                    parts=[TextPart(text=batch_text, cal_loss=False)],
                    cal_loss=False,
                    add_im_start=True,
                    add_im_end=True,
                )
            )

            # Deep copy for generation (don't pollute original conversation)
            conversation_gen = deepcopy(conversation)
            conversation_gen.append(
                Message(
                    role="assistant",
                    parts=[],
                    cal_loss=False,
                    modality="voice",
                    add_im_start=True,
                    add_im_end=False,
                )
            )

            logger.info("Visualizing prompt structure:")
            conversation_gen.visualize(
                tokenizer,
                merge_audio_tokens=True,
                merge_semantic_tokens=True,
            )

            encoded, audio_masks, audio_parts = conversation_gen.encode_for_inference(
                tokenizer, num_codebooks=model.config.num_codebooks
            )
            logger.info(f"Encoded prompt shape: {encoded.shape}")
            if audio_parts is not None:
                logger.info(f"Audio parts shape: {audio_parts.shape}")
            if audio_masks is not None:
                logger.info(
                    f"Audio masks non-zero count: {torch.count_nonzero(audio_masks)}"
                )

            if encoded.size(1) > max_length - 2048:
                raise ValueError(
                    f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}"
                )

            encoded = encoded.to(device=device)
            prompt_length = encoded.size(1)

            y = generate(
                model=model,
                prompt=encoded,
                max_new_tokens=max_new_tokens,
                audio_masks=audio_masks,
                audio_parts=audio_parts,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )

            if sample_idx == 0 and batch_idx == 0 and compile:
                logger.info(
                    f"Compilation time: {time.perf_counter() - t0:.2f} seconds"
                )

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t_batch = time.perf_counter() - t0
            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t_batch if t_batch > 0 else 0
            logger.info(
                f"Batch {batch_idx}: Generated {tokens_generated} tokens in "
                f"{t_batch:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            # Extract generated codes
            codes = y[1:, prompt_length:-1].clone()
            assert (codes >= 0).all(), f"Negative code found: {codes}"

            # Add assistant message with generated codes back to conversation
            conversation.append(
                Message(
                    role="assistant",
                    parts=[VQPart(codes=codes.cpu(), cal_loss=False)],
                    cal_loss=False,
                    modality="voice",
                    add_im_start=True,
                    add_im_end=True,
                )
            )

            yield GenerateResponse(action="sample", codes=codes, text=batch_text)
            MAX_HISTORY_TURNS = 2  # keep only the most recent 2 user/assistant rounds
            assistant_indices = [
                i for i, m in enumerate(conversation.messages) if m.role == "assistant"
            ]
            if len(assistant_indices) > MAX_HISTORY_TURNS:
                drop = assistant_indices[0]
                # Drop the earliest user+assistant pair, keeping the system message
                conversation = Conversation(
                    [
                        m
                        for i, m in enumerate(conversation.messages)
                        if i not in (drop - 1, drop)
                    ]
                )
            # Cleanup
            del y, encoded
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            import gc

            gc.collect()

        if torch.cuda.is_available():
            logger.info(
                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
            )

        yield GenerateResponse(action="next")


@dataclass
class WrappedGenerateResponse:
    status: Literal["success", "error"]
    response: Optional[Union[GenerateResponse, Exception]] = None


@dataclass
class GenerateRequest:
    request: dict
    response_queue: queue.Queue


def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
    num_workers: int = 1,
    quantize: bool = False,
):
    input_queue = queue.Queue()
    init_events = [threading.Event() for _ in range(num_workers)]

    def worker(worker_id, init_event):
        logger.info(f"Worker {worker_id} starting, loading model...")
        model, decode_one_token = init_model(
            checkpoint_path, device, precision, compile=compile, quantize=quantize
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        logger.info(f"Worker {worker_id} initialized")
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
                # Only clear cache after complete request batch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                logger.error(traceback.format_exc())
                response_queue.put(WrappedGenerateResponse(status="error", response=e))
                # Clear cache on error
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    for i in range(num_workers):
        threading.Thread(target=worker, args=(i, init_events[i]), daemon=True).start()

    for event in init_events:
        event.wait()

    logger.info(f"All {num_workers} workers initialized successfully")
    return input_queue
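

# Illustrative sketch of driving the worker queue from a client thread (checkpoint path
# and text are placeholders). The request dict is forwarded as keyword arguments to
# generate_long, so at minimum "device" and "text" must be provided:
#
#   input_queue = launch_thread_safe_queue("checkpoints/s2-pro", "cuda", torch.bfloat16)
#   response_queue = queue.Queue()
#   input_queue.put(
#       GenerateRequest(
#           request={"device": "cuda", "text": "<|speaker:0|>Hello."},
#           response_queue=response_queue,
#       )
#   )
#   wrapped = response_queue.get()  # WrappedGenerateResponse
#   if wrapped.status == "success" and wrapped.response.action == "sample":
#       codes = wrapped.response.codes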


@click.command()
@click.option(
    "--text",
    type=str,
    default="<|speaker:0|>你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option(
    "--prompt-audio",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--output", type=click.Path(path_type=Path), default=None)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.9)
@click.option("--top-k", type=int, default=30)
@click.option("--temperature", type=float, default=1.0)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/s2-pro",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=300)
@click.option("--output-dir", type=Path, default="output")
def main(
    text: str,
    prompt_text: Optional[tuple[str, ...]],
    prompt_tokens: Optional[tuple[Path, ...]],
    prompt_audio: Optional[tuple[Path, ...]],
    output: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
    output_dir: Path,
) -> None:
    os.makedirs(output_dir, exist_ok=True)
    precision = torch.half if half else torch.bfloat16

    if prompt_text and not prompt_audio and not prompt_tokens:
        raise ValueError(
            "--prompt-text requires either --prompt-audio or --prompt-tokens"
        )
    if prompt_text and prompt_tokens and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt tokens "
            f"({len(prompt_tokens)}) should be the same"
        )
    if prompt_text and prompt_audio and len(prompt_text) != len(prompt_audio):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt audio "
            f"({len(prompt_audio)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = init_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    codec = None
    codec_checkpoint = checkpoint_path / "codec.pth"

    # Handle prompt: --prompt-audio takes priority over --prompt-tokens
    prompt_tokens_list = None
    if prompt_audio:
        logger.info("Loading codec model for audio encoding...")
        codec = load_codec_model(codec_checkpoint, device, precision)
        prompt_tokens_list = [
            encode_audio(p, codec, device).cpu() for p in prompt_audio
        ]
        logger.info(f"Encoded {len(prompt_audio)} audio file(s) to VQ codes")
    elif prompt_tokens:
        prompt_tokens_list = [torch.from_numpy(np.load(p)) for p in prompt_tokens]

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=list(prompt_text) if prompt_text else None,
        prompt_tokens=prompt_tokens_list,
    )

    idx = 0
    codes = []

    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                merged_codes = torch.cat(codes, dim=1)
                codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
                np.save(codes_npy_path, merged_codes.cpu().numpy())
                logger.info(f"Saved codes to {codes_npy_path}")

                # Decode to wav if --output is specified
                if output:
                    if codec is None:
                        logger.info("Loading codec model for audio decoding...")
                        codec = load_codec_model(codec_checkpoint, device, precision)
                    audio = decode_to_audio(merged_codes.to(device), codec)

                    import soundfile as sf

                    out_path = (
                        str(output)
                        if num_samples == 1
                        else str(output.with_stem(f"{output.stem}_{idx}"))
                    )
                    sf.write(out_path, audio.cpu().float().numpy(), codec.sample_rate)
                    logger.info(f"Saved audio to {out_path}")

            logger.info("Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")


if __name__ == "__main__":
    main()
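

# Example invocation (illustrative; paths and reference audio are placeholders):
#
#   python inference.py \
#       --text "<|speaker:0|>Hello there." \
#       --prompt-audio ref.wav \
#       --prompt-text "transcript of ref.wav" \
#       --checkpoint-path checkpoints/s2-pro \
#       --output out.wav
#
# Without --output, only the raw VQ codes are written to <output-dir>/codes_<idx>.npy;
# with it, the codec also decodes them to a waveform.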