| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036 |
- import os
- import queue
- import re
- import threading
- import time
- import traceback
- from copy import deepcopy
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Callable, Literal, Optional, Tuple, Union
- import click
- import numpy as np
- import torch
- import torch._inductor.config
- from loguru import logger
- from tqdm import tqdm
- from fish_speech.content_sequence import (
- TextPart,
- VQPart,
- )
- from fish_speech.conversation import Conversation, Message
- from fish_speech.tokenizer import IM_END_TOKEN
# Silence the HF tokenizers fork warning before any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Inductor tuning knobs for the compiled decode path.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Cache compiled FX graphs across runs (guarded: not present in all torch versions).
    torch._inductor.config.fx_graph_cache = True

# Imported after the inductor flags above are set.
from torch.nn.attention import SDPBackend, sdpa_kernel

from fish_speech.models.text2semantic.llama import (
    DualARTransformer,
)
def multinomial_sample_one_no_sync(probs_sort):
    """Draw a single index from a categorical distribution without a device sync.

    Uses the exponential-race trick: for E ~ Exp(1), argmax(p / E) is
    distributed as multinomial(p), and argmax avoids the host
    synchronization that torch.multinomial would incur.
    """
    exp_noise = -torch.log(torch.rand_like(probs_sort))
    winner = torch.argmax(probs_sort / exp_noise, dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)
# Repetition Aware Sampling (RAS) hyper-parameters: when the token sampled at
# the normal temperature repeats within the recent window, a sample drawn at
# the "high" settings below is used instead.
RAS_WIN_SIZE = 10  # window for Repetition Aware Sampling
RAS_HIGH_TEMP = 1.0  # fallback temperature for the re-sample
RAS_HIGH_TOP_P = 0.9  # fallback top-p for the re-sample
def logits_to_probs(
    logits,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
) -> torch.Tensor:
    """Apply top-k / top-p filtering plus temperature to a 1-D logits vector.

    All masking is done with pure tensor ops (torch.where / scatter, no
    in-place masked_fill_) so the function stays torch.compile friendly.
    Note that ``top_k`` is a plain int, not a tensor.

    Returns:
        A probability distribution over the vocabulary.
    """
    ranked_logits, ranked_ids = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(
        torch.nn.functional.softmax(ranked_logits, dim=-1), dim=-1
    )
    positions = torch.arange(ranked_logits.shape[-1], device=ranked_logits.device)

    # Drop everything past the top-k rank or past the top-p cumulative mass.
    drop_ranked = (cumulative > top_p) | (positions >= top_k)
    drop_ranked[0] = False  # always keep the single most likely token

    # Map the rank-ordered mask back to vocabulary order.
    drop_vocab = drop_ranked.scatter(dim=-1, index=ranked_ids, src=drop_ranked)

    filtered = torch.where(drop_vocab, float("-Inf"), logits)
    filtered = filtered / torch.clip(temperature, min=1e-5)
    return torch.nn.functional.softmax(filtered, dim=-1)
def sample(
    logits,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample the next token from the last position of a (1, T, V) logits tensor.

    Returns:
        (sampled token index, filtered probability distribution).
    """
    distribution = logits_to_probs(
        logits=logits[0, -1],
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
    chosen = multinomial_sample_one_no_sync(distribution)
    return chosen, distribution
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
    semantic_logit_bias: torch.Tensor,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Decode a single timestep with the dual-AR transformer.

    Runs the slow (semantic) transformer once, then the fast transformer
    once per remaining codebook. ``semantic_logit_bias`` constrains the
    slow head to semantic tokens + <|im_end|>; ``previous_tokens`` enables
    Repetition Aware Sampling (RAS).

    Returns:
        (num_codebooks + 1, 1) tensor: semantic token plus one token per codebook.
    """
    forward_result = model.forward_generate(
        x,
        input_pos,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
    )
    logits = forward_result.logits  # (1, 1, vocab_size)
    hidden_states = forward_result.hidden_states

    # Apply constrained decoding: only allow semantic tokens + im_end
    biased_logits = logits + semantic_logit_bias

    # Normal sample
    main_token_normal = sample(
        biased_logits, temperature=temperature, top_p=top_p, top_k=top_k
    )[0]

    # RAS: also sample with high temp to use as fallback if token repeats
    high_temp = torch.tensor(
        RAS_HIGH_TEMP, device=temperature.device, dtype=temperature.dtype
    )
    high_top_p = torch.tensor(RAS_HIGH_TOP_P, device=top_p.device, dtype=top_p.dtype)
    main_token_high = sample(
        biased_logits, temperature=high_temp, top_p=high_top_p, top_k=top_k
    )[0]

    # Use high-temp sample if: token is semantic AND token is in previous window
    if previous_tokens is not None:
        in_window = (previous_tokens[0] == main_token_normal).any()
        # Use tensor ops (&, torch.where) instead of Python (and, if) — torch.compile requires no data-dependent branching
        is_semantic = (main_token_normal >= model.config.semantic_begin_id) & (
            main_token_normal <= model.config.semantic_end_id
        )
        should_use_high = in_window & is_semantic
        main_token_normal = torch.where(
            should_use_high, main_token_high, main_token_normal
        )

    codebooks = [main_token_normal]
    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
    # NOTE(review): the fast-transformer logits at position 0 are discarded;
    # the first codebook token is derived arithmetically from the semantic
    # token below. Confirm this tying is intentional.
    model.forward_generate_fast(hidden_states, input_pos)
    a = codebooks[0] - model.config.semantic_begin_id
    # Clamp so non-semantic tokens (e.g. im_end) map to a valid codebook index.
    a = torch.clamp(a, min=0, max=model.config.codebook_size - 1)
    hidden_states = model.fast_embeddings(a)
    codebooks.append(a)

    # Remaining codebooks are sampled autoregressively from the fast transformer.
    for codebook_idx in range(1, model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        short_logits = logits  # DualAR predicts config.codebook_size number of tokens

        # Convert logits to probs (no constrain for fast codebooks)
        a = sample(
            short_logits,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=1)
    # Only delete references, let Python GC handle cleanup
    del logits, hidden_states, forward_result
    return codebooks.T
def decode_n_tokens(
    model: DualARTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    temperature: torch.Tensor,
    top_p: torch.Tensor,
    top_k: int,
    semantic_logit_bias: torch.Tensor,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    decode_one_token=decode_one_token_ar,
):
    """Autoregressively decode up to ``num_new_tokens`` timesteps.

    Starts from ``cur_token`` (the last decoded token) at KV-cache position
    ``input_pos`` and stops early when the semantic stream emits
    ``<|im_end|>``.

    Returns:
        (num_codebooks + 1, n) tensor of all generated tokens, including
        the terminating im_end column when generation stopped early.
    """
    # Rolling window for RAS (Repetition Aware Sampling)
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, RAS_WIN_SIZE),
        dtype=torch.int,
        device=cur_token.device,
    )

    # Accumulate all generated tokens (the actual output)
    new_tokens = []

    # [MODIFIED] Pre-fetch ID for efficiency loop
    im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)

    for i in tqdm(range(num_new_tokens)):
        # Force the math SDPA backend for the decode step — presumably for
        # torch.compile/backend compatibility; confirm before changing.
        with sdpa_kernel(SDPBackend.MATH):
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=previous_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                semantic_logit_bias=semantic_logit_bias,
                audio_masks=audio_masks,
                audio_parts=audio_parts,
            ).clone()

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)

        # Roll RAS window left and insert new token at end
        previous_tokens = previous_tokens.roll(-1, dims=1)
        previous_tokens[:, -1] = next_token.view(model.config.num_codebooks + 1, -1)[
            :, 0
        ]
        new_tokens.append(next_token)

        # Stop as soon as the semantic stream produces the end-of-message token.
        if cur_token[0, 0, -1] == im_end_id:
            break

    del cur_token
    return torch.cat(new_tokens, dim=1)
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: DualARTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    audio_masks: torch.Tensor,
    audio_parts: torch.Tensor,
    decode_one_token=decode_one_token_ar,
    num_samples: int = 1,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Args:
        model: the DualAR text-to-semantic transformer.
        prompt: (1 + num_codebooks, T) prompt token tensor.
        max_new_tokens: 0 means "generate up to model.config.max_seq_len".
        audio_masks / audio_parts: audio conditioning forwarded to the model.
        decode_one_token: per-step decode function (possibly torch.compile'd).
        num_samples: NOTE(review) only the prompt repeat uses this; the
            decode path below assumes batch size 1 — confirm before using > 1.
        **sampling_kwargs: temperature / top_p / top_k.

    Returns:
        (1 + num_codebooks, T + n) tensor containing the prompt followed by
        the generated tokens.

    Raises:
        ValueError: if the prompt already exceeds max_seq_len.
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    prompt = prompt[None].repeat(num_samples, 1, 1)

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    # Clamp the generation budget so T + max_new_tokens never exceeds max_seq_len.
    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device = prompt.device
    dtype = next(
        model.parameters()
    ).dtype  # model weight dtype (bfloat16), NOT prompt dtype (int32)

    # Critical fix: Only set up cache on first run or when necessary
    if not hasattr(model, "_cache_setup_done") or not model._cache_setup_done:
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,  # Fixed to 1, avoid dynamic changes
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        model._cache_setup_done = True

    codebook_dim = 1 + model.config.num_codebooks

    # Create new tensor each time, but try to reuse memory
    input_pos = torch.arange(0, T, device=device, dtype=torch.long)
    empty = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=prompt.dtype, device=device
    )
    empty[:, :T] = prompt
    seq = empty

    # Sampling hyper-parameters become device tensors (keeps compiled graphs static).
    temp_val = sampling_kwargs.get("temperature", 1.0)
    top_p_val = sampling_kwargs.get("top_p", 0.9)
    top_k_val = sampling_kwargs.get("top_k", 30)

    temperature = torch.tensor(temp_val, device=device, dtype=dtype)
    top_p = torch.tensor(top_p_val, device=device, dtype=dtype)

    # Build semantic logit bias: 0 for semantic tokens + im_end, -inf for all others
    vocab_size = model.config.vocab_size
    semantic_logit_bias = torch.full(
        (1, 1, vocab_size), float("-inf"), device=device, dtype=dtype
    )
    # [MODIFIED] Use config for semantic range
    semantic_logit_bias[
        0, 0, model.config.semantic_begin_id : model.config.semantic_end_id + 1
    ] = 0.0
    # [MODIFIED] Use tokenizer.get_token_id (Wrapper method)
    semantic_logit_bias[0, 0, model.tokenizer.get_token_id(IM_END_TOKEN)] = 0.0

    # Prefill: run the whole prompt through once to obtain the first new token.
    prefill_decode = decode_one_token_ar

    first_token = prefill_decode(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        temperature,
        top_p,
        top_k_val,
        semantic_logit_bias,
        audio_masks,
        audio_parts,
    )
    seq[:, T : T + 1] = first_token

    # Recreate input_pos
    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        first_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k_val,
        semantic_logit_bias=semantic_logit_bias,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
        decode_one_token=decode_one_token,
    )
    # Trim the preallocated buffer to the actually generated length.
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    # Clean up temporary variables
    del first_token, x, prompt, empty, input_pos
    return seq
def init_model(checkpoint_path, device, precision, compile=False, quantize=False):
    """Load the DualAR text-to-semantic model and select its decode function.

    Args:
        checkpoint_path: directory containing the pretrained weights.
        device: target device string or torch.device.
        precision: torch dtype to cast the weights to (e.g. torch.bfloat16).
        compile: wrap the per-token decode function with torch.compile.
        quantize: replace Linear layers with bitsandbytes INT8 layers.

    Returns:
        (model in eval mode, decode_one_token callable)

    Raises:
        ImportError: if quantize=True but bitsandbytes is not installed.
        ValueError: if the loaded model is not a DualARTransformer.
    """
    model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)

    # Fix: log the dtype itself. The previous `precision.__class__.__name__`
    # always printed "dtype" and carried no information.
    logger.info(f"precision: {precision}")
    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    # Apply INT8 quantization if requested
    if quantize:
        try:
            import bitsandbytes as bnb

            logger.info("Applying INT8 quantization with bitsandbytes...")

            def replace_linear_with_int8(module):
                # Recursively swap every torch.nn.Linear for an 8-bit version.
                for name, child in module.named_children():
                    if isinstance(child, torch.nn.Linear):
                        # Create 8-bit linear layer
                        int8_layer = bnb.nn.Linear8bitLt(
                            child.in_features,
                            child.out_features,
                            bias=child.bias is not None,
                            has_fp16_weights=False,
                            threshold=6.0,
                        )
                        # Copy weights
                        int8_layer.weight = bnb.nn.Int8Params(
                            child.weight.data,
                            requires_grad=False,
                            has_fp16_weights=False,
                        )
                        if child.bias is not None:
                            int8_layer.bias = child.bias
                        setattr(module, name, int8_layer)
                    else:
                        replace_linear_with_int8(child)

            replace_linear_with_int8(model)
            logger.info("INT8 quantization applied successfully")
        except ImportError:
            logger.error(
                "bitsandbytes not installed. Install with: pip install bitsandbytes"
            )
            raise

    if isinstance(model, DualARTransformer):
        decode_one_token = decode_one_token_ar
        logger.info("Using DualARTransformer")
    else:
        raise ValueError("Unsupported model type")

    # Pre-create fixed parameter tensors to avoid runtime creation
    model.fixed_temperature = torch.tensor(0.7, device=device, dtype=torch.float)
    model.fixed_top_p = torch.tensor(0.7, device=device, dtype=torch.float)
    model.fixed_repetition_penalty = torch.tensor(1.5, device=device, dtype=torch.float)

    # Mark whether cache has been initialized
    model._cache_setup_done = False

    # Disable compile if quantization is enabled (bitsandbytes INT8 is incompatible with torch.compile)
    if compile and not quantize:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token,
            backend="inductor" if torch.cuda.is_available() else "aot_eager",
            mode="default" if torch.cuda.is_available() else None,
            fullgraph=True,
        )
    elif compile and quantize:
        logger.warning(
            "torch.compile disabled when quantization is enabled (bitsandbytes compatibility)"
        )

    return model.eval(), decode_one_token
@torch.inference_mode()
def load_codec_model(codec_checkpoint_path, device, precision=torch.bfloat16):
    """Load the DAC codec model for audio encoding/decoding."""
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    config_file = (
        Path(__file__).parent.parent.parent / "configs" / "modded_dac_vq.yaml"
    )
    codec = instantiate(OmegaConf.load(str(config_file)))

    weights = torch.load(codec_checkpoint_path, map_location="cpu")
    if "state_dict" in weights:
        weights = weights["state_dict"]
    # Some checkpoints store the codec under a "generator." prefix — strip it.
    if any("generator" in k for k in weights):
        weights = {
            k.replace("generator.", ""): v
            for k, v in weights.items()
            if "generator." in k
        }

    codec.load_state_dict(weights, strict=False)
    codec.eval()
    codec.to(device=device, dtype=precision)
    return codec
@torch.inference_mode()
def encode_audio(audio_path, codec, device):
    """Encode an audio file to VQ codes."""
    import torchaudio

    waveform, sample_rate = torchaudio.load(str(audio_path))
    # Downmix multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    waveform = torchaudio.functional.resample(
        waveform.to(device), sample_rate, codec.sample_rate
    )[0]

    # Match codec model dtype (e.g. bfloat16)
    codec_dtype = next(codec.parameters()).dtype
    batch = waveform[None, None].to(dtype=codec_dtype)  # (1, 1, T)
    lengths = torch.tensor([len(waveform)], device=device, dtype=torch.long)

    indices, feature_lengths = codec.encode(batch, lengths)
    return indices[0, :, : feature_lengths[0]]  # (num_codebooks, T)
@torch.inference_mode()
def decode_to_audio(codes, codec):
    """Decode VQ codes to audio waveform."""
    # (num_codebooks, T) -> (1, num_codebooks, T) batch for the codec.
    batched = codes[None]
    decoded = codec.from_indices(batched)
    # First batch item, first channel -> (T,) mono waveform.
    return decoded[0, 0]
@dataclass
class GenerateResponse:
    """One event yielded by the streaming generator.

    ``action == "sample"`` carries generated codes for one text batch;
    ``action == "next"`` marks the end of one sample's generation.
    """

    action: Literal["sample", "next"]
    codes: Optional[torch.Tensor] = None  # generated VQ codes (for "sample")
    text: Optional[str] = None  # the batch text that produced the codes
def split_text_by_speaker(text: str) -> list[str]:
    """
    Split text into turns based on <|speaker:X|> tags.

    Text appearing before the first speaker tag is discarded.

    Args:
        text: The full text with speaker tags

    Returns:
        List of speaker turns, each starting with <|speaker:X|>
    """
    pattern = r"(<\|speaker:\d+\|>)"
    pieces = re.split(pattern, text)

    result: list[str] = []
    idx = 0
    total = len(pieces)
    while idx < total:
        fragment = pieces[idx].strip()
        if not re.match(pattern, fragment):
            # Untagged fragment (e.g. leading text) — skip it.
            idx += 1
            continue
        if idx + 1 < total:
            # Tag + the text that follows it form one turn.
            result.append((fragment + pieces[idx + 1]).strip())
            idx += 2
        else:
            # Trailing tag with no text after it.
            result.append(fragment)
            idx += 1
    return result
def group_turns_into_batches(
    turns: list[str], max_speakers: int = 3, max_bytes: int = 300
) -> list[str]:
    """
    Group turns into batches based on speaker count or byte limit.

    Args:
        turns: List of speaker turns
        max_speakers: Maximum number of speakers per batch (default 3)
        max_bytes: Maximum UTF-8 bytes per batch (default 300)

    Returns:
        List of batched text strings
    """
    batches: list[str] = []
    pending: list[str] = []
    pending_bytes = 0

    for turn in turns:
        size = len(turn.encode("utf-8"))
        # Flush when adding this turn would exceed either limit
        # (the byte limit never flushes an empty batch).
        over_speakers = len(pending) >= max_speakers
        over_bytes = bool(pending) and pending_bytes + size > max_bytes
        if over_speakers or over_bytes:
            batches.append("\n".join(pending))
            pending = [turn]
            pending_bytes = size
        else:
            pending.append(turn)
            pending_bytes += size

    if pending:
        batches.append("\n".join(pending))
    return batches
def generate_long(
    *,
    model,
    device: Union[str, torch.device],
    decode_one_token: Callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.9,
    top_k: int = 30,
    repetition_penalty: float = 1.1,
    temperature: float = 1.0,
    compile: bool = False,
    iterative_prompt: bool = True,
    chunk_length: int = 512,
    prompt_text: Optional[Union[str, list[str]]] = None,
    prompt_tokens: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
):
    """Stream semantic-token generation for a long, multi-speaker text.

    The text is split into speaker turns, grouped into batches of at most
    ``chunk_length`` bytes / 5 speakers, and each batch is generated in turn
    while keeping a bounded conversation history.

    Yields:
        GenerateResponse(action="sample", codes=..., text=...) for each batch,
        then GenerateResponse(action="next") at the end of every sample.

    Raises:
        ValueError: if an encoded prompt is too long for the model.
    """
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    logger.info(f"generate_long.param.device: {device}")
    logger.info(f"generate_long.param.text: {text}")
    logger.info(f"generate_long.param.max_new_tokens: {max_new_tokens}")
    logger.info(f"generate_long.param.top_p: {top_p}")
    logger.info(f"generate_long.param.top_k: {top_k}")
    logger.info(f"generate_long.param.temperature: {temperature}")
    logger.info(f"generate_long.param.compile: {compile}")
    logger.info(f"generate_long.param.chunk_length: {chunk_length}")
    logger.info(f"generate_long.param.prompt_text: {prompt_text}")
    logger.info(f"generate_long.param.prompt_tokens: {prompt_tokens}")

    # Fix: normalize each prompt argument independently. The previous code
    # wrapped BOTH in lists only when prompt_text was a str, which produced a
    # nested list when prompt_tokens was already a list, and evaluating
    # bool() on a bare multi-element tensor raised a RuntimeError.
    if isinstance(prompt_text, str):
        prompt_text = [prompt_text]
    if isinstance(prompt_tokens, torch.Tensor):
        prompt_tokens = [prompt_tokens]

    use_prompt = bool(prompt_text) and bool(prompt_tokens)
    if use_prompt:
        assert len(prompt_text) == len(
            prompt_tokens
        ), "Prompt text and tokens must have the same length"

    if prompt_tokens:
        # Keep reference codes on CPU until they are encoded into the prompt.
        prompt_tokens = [i.cpu() for i in prompt_tokens]

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tokenizer = model.tokenizer
    max_length = model.config.max_seq_len

    # Build base conversation with system message
    base_conversation = Conversation()
    if use_prompt:
        # Auto-add speaker tags to prompt texts that don't have them
        tagged_prompt_text = []
        for i, t in enumerate(prompt_text):
            if not re.search(r"<\|speaker:\d+\|>", t):
                tagged_prompt_text.append(f"<|speaker:{i}|>{t}")
            else:
                tagged_prompt_text.append(t)

        system_parts = [
            TextPart(
                text="convert the provided text to speech reference to the following:\n\nText:\n",
                cal_loss=False,
            ),
        ]
        reference_text = "\n".join(tagged_prompt_text)
        system_parts.append(TextPart(text=reference_text, cal_loss=False))
        system_parts.append(TextPart(text="\n\nSpeech:\n", cal_loss=False))
        all_codes = torch.cat([c for c in prompt_tokens], dim=1)
        system_parts.append(VQPart(codes=all_codes, cal_loss=False))
    else:
        system_parts = [
            TextPart(text="convert the provided text to speech", cal_loss=False)
        ]

    base_conversation.append(
        Message(
            role="system",
            parts=system_parts,
            cal_loss=False,
            add_im_start=True,
            add_im_end=True,
        )
    )

    # Split text by speaker and group into batches
    turns = split_text_by_speaker(text)
    if turns:
        batches = group_turns_into_batches(
            turns, max_speakers=5, max_bytes=chunk_length
        )
    else:
        batches = [text]

    logger.info(f"Split into {len(turns)} turns, grouped into {len(batches)} batches")

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t0 = time.perf_counter()

        # Deep copy base conversation for this sample
        conversation = deepcopy(base_conversation)

        for batch_idx, batch_text in enumerate(batches):
            logger.info(
                f"--- Sample {sample_idx}, Batch {batch_idx} "
                f"({len(batch_text.encode('utf-8'))} bytes) ---"
            )
            logger.info(f"Batch text: {batch_text}")

            # Add user message
            conversation.append(
                Message(
                    role="user",
                    parts=[TextPart(text=batch_text, cal_loss=False)],
                    cal_loss=False,
                    add_im_start=True,
                    add_im_end=True,
                )
            )

            # Deep copy for generation (don't pollute original conversation)
            conversation_gen = deepcopy(conversation)
            conversation_gen.append(
                Message(
                    role="assistant",
                    parts=[],
                    cal_loss=False,
                    modality="voice",
                    add_im_start=True,
                    add_im_end=False,
                )
            )

            logger.info("Visualizing prompt structure:")
            conversation_gen.visualize(
                tokenizer,
                merge_audio_tokens=True,
                merge_semantic_tokens=True,
            )

            encoded, audio_masks, audio_parts = conversation_gen.encode_for_inference(
                tokenizer, num_codebooks=model.config.num_codebooks
            )
            logger.info(f"Encoded prompt shape: {encoded.shape}")
            if audio_parts is not None:
                logger.info(f"Audio parts shape: {audio_parts.shape}")
            if audio_masks is not None:
                logger.info(
                    f"Audio masks non-zero count: {torch.count_nonzero(audio_masks)}"
                )

            # Leave head-room below max_seq_len for the generated tokens.
            if encoded.size(1) > max_length - 2048:
                raise ValueError(
                    f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}"
                )

            encoded = encoded.to(device=device)
            prompt_length = encoded.size(1)

            y = generate(
                model=model,
                prompt=encoded,
                max_new_tokens=max_new_tokens,
                audio_masks=audio_masks,
                audio_parts=audio_parts,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )

            if sample_idx == 0 and batch_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t_batch = time.perf_counter() - t0
            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t_batch if t_batch > 0 else 0
            logger.info(
                f"Batch {batch_idx}: Generated {tokens_generated} tokens in "
                f"{t_batch:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            # Extract generated codes: drop the text row and the trailing im_end column.
            codes = y[1:, prompt_length:-1].clone()
            assert (codes >= 0).all(), f"Negative code found: {codes}"

            # Add assistant message with generated codes back to conversation
            conversation.append(
                Message(
                    role="assistant",
                    parts=[VQPart(codes=codes.cpu(), cal_loss=False)],
                    cal_loss=False,
                    modality="voice",
                    add_im_start=True,
                    add_im_end=True,
                )
            )

            yield GenerateResponse(action="sample", codes=codes, text=batch_text)

            # Keep only the most recent user/assistant rounds so the context
            # does not grow without bound across batches.
            MAX_HISTORY_TURNS = 2
            assistant_indices = [
                i for i, m in enumerate(conversation.messages) if m.role == "assistant"
            ]
            if len(assistant_indices) > MAX_HISTORY_TURNS:
                drop = assistant_indices[0]
                # Drop the earliest user+assistant pair; the system message stays.
                conversation = Conversation(
                    [
                        m
                        for i, m in enumerate(conversation.messages)
                        if i not in (drop - 1, drop)
                    ]
                )

            # Cleanup
            del y, encoded
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            import gc

            gc.collect()

        if torch.cuda.is_available():
            logger.info(
                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
            )

        yield GenerateResponse(action="next")
@dataclass
class WrappedGenerateResponse:
    """Envelope placed on a response queue: either a result or the raised error."""

    status: Literal["success", "error"]
    response: Optional[Union[GenerateResponse, Exception]] = None  # payload or exception
@dataclass
class GenerateRequest:
    """A generation job submitted to the worker queue."""

    request: dict  # keyword arguments forwarded to generate_long
    response_queue: queue.Queue  # where WrappedGenerateResponse items are placed
def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
    num_workers: int = 1,
    quantize: bool = False,
):
    """Spawn daemon worker threads that serve GenerateRequest items.

    Each worker loads its own model, then blocks on the shared queue.
    Returns once every worker has finished initializing.

    Returns:
        The request queue; put GenerateRequest items on it (None stops a worker).
    """
    request_queue = queue.Queue()
    ready_flags = [threading.Event() for _ in range(num_workers)]

    def worker_loop(worker_id: int, ready: threading.Event):
        logger.info(f"Worker {worker_id} starting, loading model...")
        model, decode_one_token = init_model(
            checkpoint_path, device, precision, compile=compile, quantize=quantize
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        logger.info(f"Worker {worker_id} initialized")
        ready.set()

        while True:
            job: GenerateRequest | None = request_queue.get()
            if job is None:  # poison pill shuts this worker down
                break
            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **job.request
                ):
                    job.response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
                # Only clear cache after complete request batch
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception as e:
                logger.error(traceback.format_exc())
                job.response_queue.put(
                    WrappedGenerateResponse(status="error", response=e)
                )
                # Clear cache on error
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    for worker_id in range(num_workers):
        threading.Thread(
            target=worker_loop, args=(worker_id, ready_flags[worker_id]), daemon=True
        ).start()

    for flag in ready_flags:
        flag.wait()

    logger.info(f"All {num_workers} workers initialized successfully")
    return request_queue
@click.command()
@click.option(
    "--text",
    type=str,
    default="<|speaker:0|>你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option(
    "--prompt-audio",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--output", type=click.Path(path_type=Path), default=None)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.9)
@click.option("--top-k", type=int, default=30)
@click.option("--temperature", type=float, default=1.0)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/s2-pro",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=300)
@click.option("--output-dir", type=Path, default="output")
def main(
    text: str,
    prompt_text: Optional[tuple[str, ...]],
    prompt_tokens: Optional[tuple[Path, ...]],
    prompt_audio: Optional[tuple[Path, ...]],
    output: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
    output_dir: Path,
) -> None:
    """CLI entry point: generate speech codes (and optionally audio) from text."""
    os.makedirs(output_dir, exist_ok=True)
    precision = torch.half if half else torch.bfloat16

    # Validate prompt argument combinations up front.
    if prompt_text and not prompt_audio and not prompt_tokens:
        raise ValueError(
            "--prompt-text requires either --prompt-audio or --prompt-tokens"
        )
    if prompt_text and prompt_tokens and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )
    if prompt_text and prompt_audio and len(prompt_text) != len(prompt_audio):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt audio ({len(prompt_audio)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = init_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    codec = None
    codec_checkpoint = checkpoint_path / "codec.pth"

    # Handle prompt: --prompt-audio takes priority over --prompt-tokens
    prompt_tokens_list = None
    if prompt_audio:
        logger.info("Loading codec model for audio encoding...")
        codec = load_codec_model(codec_checkpoint, device, precision)
        prompt_tokens_list = [
            encode_audio(p, codec, device).cpu() for p in prompt_audio
        ]
        logger.info(f"Encoded {len(prompt_audio)} audio file(s) to VQ codes")
    elif prompt_tokens:
        # Fix: click's multiple=True always yields a tuple (possibly empty),
        # never None — the old `is not None` check was always true and turned
        # "no prompt tokens" into an empty list instead of None.
        prompt_tokens_list = [torch.from_numpy(np.load(p)) for p in prompt_tokens]

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=list(prompt_text) if prompt_text else None,
        prompt_tokens=prompt_tokens_list,
    )

    idx = 0
    codes = []
    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                merged_codes = torch.cat(codes, dim=1)
                codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
                np.save(codes_npy_path, merged_codes.cpu().numpy())
                logger.info(f"Saved codes to {codes_npy_path}")

                # Decode to wav if --output is specified
                if output:
                    if codec is None:
                        logger.info("Loading codec model for audio decoding...")
                        codec = load_codec_model(codec_checkpoint, device, precision)
                    audio = decode_to_audio(merged_codes.to(device), codec)
                    import soundfile as sf

                    out_path = (
                        str(output)
                        if num_samples == 1
                        else str(output.with_stem(f"{output.stem}_{idx}"))
                    )
                    sf.write(out_path, audio.cpu().float().numpy(), codec.sample_rate)
                    logger.info(f"Saved audio to {out_path}")
            logger.info("Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")


if __name__ == "__main__":
    main()
|