Bläddra i källkod

feat:开启compile

zhaohaipeng 1 månad sedan
förälder
incheckning
d965c0c412
1 ändrade filer med 98 tillägg och 7 borttagningar
  1. 98 7
      fish_speech/models/text2semantic/inference.py

+ 98 - 7
fish_speech/models/text2semantic/inference.py

@@ -93,7 +93,7 @@ def sample(
     return idx_next, probs
 
 
-def decode_one_token_ar(
+def decode_one_token_ar_old(
         model: DualARTransformer,
         x: torch.Tensor,
         input_pos: torch.Tensor,
@@ -180,6 +180,102 @@ def decode_one_token_ar(
 
     return codebooks.T
 
+@torch.inference_mode()
+def decode_one_token_ar(
+    model: DualARTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    temperature,
+    top_p,
+    top_k,
+    semantic_logit_bias,
+    audio_masks,
+    audio_parts,
+    previous_tokens=None,
+) -> torch.Tensor:
+    """Decode one step: sample the first token, then one code per codebook.
+
+    Args:
+        model: Provides ``forward_generate``, ``forward_generate_fast``,
+            ``fast_embeddings`` and ``config``.
+        x: Input forwarded to ``model.forward_generate``.
+        input_pos: Position(s) forwarded to ``model.forward_generate`` only.
+        temperature: Forwarded unchanged to every ``sample`` call.
+        top_p: Forwarded unchanged to every ``sample`` call.
+        top_k: Forwarded unchanged to every ``sample`` call.
+        semantic_logit_bias: Added to the logits before the first token is
+            sampled.
+        audio_masks: Forwarded to ``model.forward_generate``.
+        audio_parts: Forwarded to ``model.forward_generate``.
+        previous_tokens: Accepted for call compatibility; not used.
+
+    Returns:
+        The transpose of the collected tokens stacked along ``dim=1``: the
+        first sampled token followed by ``config.num_codebooks`` codes.
+    """
+    forward_result = model.forward_generate(
+        x,
+        input_pos,
+        audio_masks=audio_masks,
+        audio_parts=audio_parts,
+    )
+    hidden_states = forward_result.hidden_states
+
+    # The bias is applied once, before the single sampling of the first token.
+    token = sample(
+        forward_result.logits + semantic_logit_bias,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+    )[0]
+
+    # NOTE: previous_tokens is deliberately not consulted. The check that used
+    # to live here ended in `token = token` (no effect) while still branching
+    # in Python on tensor values, which forces a host sync and is
+    # data-dependent control flow for torch.compile.
+
+    config = model.config
+
+    # One allocation covers every codebook position; each step below takes a
+    # 1-element view, so no tensor is created inside the loop.
+    # NOTE(review): assumes forward_generate_fast takes a 1-element long
+    # tensor position that restarts at 0 on every call of this function,
+    # independent of the main model's input_pos - confirm against the model.
+    fast_positions = torch.arange(
+        config.num_codebooks,
+        device=hidden_states.device,
+        dtype=torch.long,
+    )
+
+    # Feed the main model's hidden state at fast position 0; the returned
+    # logits are not used.
+    model.forward_generate_fast(hidden_states, fast_positions[:1])
+
+    # The first code is derived from the sampled token rather than sampled:
+    # offset by semantic_begin_id and clamped into [0, codebook_size - 1].
+    first_code = torch.clamp(
+        token - config.semantic_begin_id,
+        0,
+        config.codebook_size - 1,
+    )
+    codebooks = [token, first_code]
+    hidden_states = model.fast_embeddings(first_code)
+
+    # Remaining codes: each step embeds the previous code and samples the next.
+    for i in range(1, config.num_codebooks):
+        logits = model.forward_generate_fast(
+            hidden_states,
+            fast_positions[i : i + 1],
+        )
+        code = sample(
+            logits,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+        )[0]
+        hidden_states = model.fast_embeddings(code)
+        codebooks.append(code)
+
+    return torch.stack(codebooks, dim=1).T
 
 def decode_n_tokens(
         model: DualARTransformer,
@@ -212,7 +308,7 @@ def decode_n_tokens(
 
     for i in tqdm(range(num_new_tokens)):
         f_start = time.perf_counter()
-        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+        with sdpa_kernel(SDPBackend.MATH):
             next_token = decode_one_token(
                 model=model,
                 x=cur_token,
@@ -392,11 +488,6 @@ def generate(
 
 def init_model(checkpoint_path, device, precision, compile=False):
 
-    torch.backends.cuda.enable_flash_sdp(False)
-    torch.backends.cuda.enable_math_sdp(True)
-    torch.backends.cuda.enable_mem_efficient_sdp(True)
-    torch.backends.cuda.enable_cudnn_sdp(True)
-
     model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
 
     logger.info(f"precision: {precision.__class__.__name__}")