|
|
@@ -93,7 +93,7 @@ def sample(
|
|
|
return idx_next, probs
|
|
|
|
|
|
|
|
|
-def decode_one_token_ar(
|
|
|
+def decode_one_token_ar_old(
|
|
|
model: DualARTransformer,
|
|
|
x: torch.Tensor,
|
|
|
input_pos: torch.Tensor,
|
|
|
@@ -180,6 +180,102 @@ def decode_one_token_ar(
|
|
|
|
|
|
return codebooks.T
|
|
|
|
|
|
@torch.inference_mode()
def decode_one_token_ar(
    model: "DualARTransformer",
    x: torch.Tensor,
    input_pos: torch.Tensor,
    temperature,
    top_p,
    top_k,
    semantic_logit_bias,
    audio_masks,
    audio_parts,
    previous_tokens=None,
) -> torch.Tensor:
    """Decode one autoregressive step: a main token plus its codebook codes.

    The main token is sampled once from the biased logits returned by
    ``model.forward_generate``. The remaining codes are then sampled one at a
    time from ``model.forward_generate_fast``; each step is fed the
    ``model.fast_embeddings`` of the code chosen in the previous step.

    Args:
        model: Object exposing ``forward_generate``, ``forward_generate_fast``,
            ``fast_embeddings`` and a ``config`` with ``semantic_begin_id``,
            ``codebook_size`` and ``num_codebooks``.
        x: Input forwarded unchanged to ``model.forward_generate``.
        input_pos: Position argument forwarded to ``model.forward_generate``
            and to the first ``model.forward_generate_fast`` call.
        temperature: Forwarded to every ``sample`` call.
        top_p: Forwarded to every ``sample`` call.
        top_k: Forwarded to every ``sample`` call.
        semantic_logit_bias: Added to the main logits before sampling.
        audio_masks: Forwarded to ``model.forward_generate``.
        audio_parts: Forwarded to ``model.forward_generate``.
        previous_tokens: Accepted for interface compatibility only; it does
            not influence the result (see the note in the body).

    Returns:
        The transpose of the sampled values stacked along ``dim=1``: the main
        token first, followed by ``model.config.num_codebooks`` code indices.
    """
    config = model.config

    # 1. Main forward pass.
    forward_result = model.forward_generate(
        x,
        input_pos,
        audio_masks=audio_masks,
        audio_parts=audio_parts,
    )
    hidden_states = forward_result.hidden_states

    # 2. Fold the bias into the logits up front so that a single sample() call
    #    is enough for the main token (this is the core optimization here).
    biased_logits = forward_result.logits + semantic_logit_bias
    token = sample(
        biased_logits,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )[0]

    # NOTE(review): a repetition check against ``previous_tokens[0]`` used to
    # sit here, but its only effect was ``token = token`` while its condition
    # forced a truth-value evaluation of the comparison result on every step.
    # It was removed as dead code. If a real correction (e.g. a reroll) is
    # wanted, implement it at this point.

    sampled = [token]

    # 3. First fast-decoder call. Its return value is discarded; presumably it
    #    is invoked for its effect on the fast decoder's internal state
    #    (TODO confirm). ``input_pos`` is reused as-is to avoid creating a new
    #    tensor.
    # NOTE(review): this passes the main-model position, whereas the loop
    # below passes a codebook index -- confirm that this is intended.
    model.forward_generate_fast(hidden_states, input_pos)

    # 4. Turn the main token id into a code index, clamped into
    #    [0, codebook_size - 1].
    code = token - config.semantic_begin_id
    code = torch.clamp(code, 0, config.codebook_size - 1)
    hidden_states = model.fast_embeddings(code)
    sampled.append(code)

    # 5. Sample the remaining codes, one fast-decoder call per codebook index.
    for codebook_idx in range(1, config.num_codebooks):
        # The position is passed as a plain int to avoid torch.tensor creation.
        # NOTE(review): confirm forward_generate_fast accepts an int as well as
        # the tensor it receives in the first call above.
        fast_logits = model.forward_generate_fast(hidden_states, codebook_idx)
        code = sample(
            fast_logits,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )[0]
        hidden_states = model.fast_embeddings(code)
        sampled.append(code)

    # 6. Stack along dim=1 and transpose, matching the original return layout.
    return torch.stack(sampled, dim=1).T
|
|
|
|
|
|
def decode_n_tokens(
|
|
|
model: DualARTransformer,
|
|
|
@@ -212,7 +308,7 @@ def decode_n_tokens(
|
|
|
|
|
|
for i in tqdm(range(num_new_tokens)):
|
|
|
f_start = time.perf_counter()
|
|
|
- with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
|
|
|
+ with sdpa_kernel(SDPBackend.MATH):
|
|
|
next_token = decode_one_token(
|
|
|
model=model,
|
|
|
x=cur_token,
|
|
|
@@ -392,11 +488,6 @@ def generate(
|
|
|
|
|
|
def init_model(checkpoint_path, device, precision, compile=False):
|
|
|
|
|
|
- torch.backends.cuda.enable_flash_sdp(False)
|
|
|
- torch.backends.cuda.enable_math_sdp(True)
|
|
|
- torch.backends.cuda.enable_mem_efficient_sdp(True)
|
|
|
- torch.backends.cuda.enable_cudnn_sdp(True)
|
|
|
-
|
|
|
model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
|
|
|
|
|
|
logger.info(f"precision: {precision.__class__.__name__}")
|