Explorar el Código

feat:添加耗时统计

zhaohaipeng hace 1 mes
padre
commit
5e6e36d912
Se han modificado 1 ficheros con 2 adiciones y 97 borrados
  1. 2 97
      fish_speech/models/text2semantic/inference.py

+ 2 - 97
fish_speech/models/text2semantic/inference.py

@@ -208,7 +208,7 @@ def decode_n_tokens(
 
     for i in tqdm(range(num_new_tokens)):
         f_start = time.perf_counter()
-        with sdpa_kernel(SDPBackend.MATH):
+        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
             next_token = decode_one_token(
                 model=model,
                 x=cur_token,
@@ -242,101 +242,6 @@ def decode_n_tokens(
 
     return torch.cat(new_tokens, dim=1)
 
-
-def decode_n_tokens_optimized(
-        model: DualARTransformer,
-        cur_token: torch.Tensor,
-        input_pos: torch.Tensor,
-        num_new_tokens: int,
-        temperature: torch.Tensor,
-        top_p: torch.Tensor,
-        top_k: int,
-        semantic_logit_bias: torch.Tensor,
-        audio_masks: torch.Tensor,
-        audio_parts: torch.Tensor,
-        decode_one_token=decode_one_token_ar,
-):
-    """
-    Optimized version:
-    - no roll (ring buffer)
-    - flash attention
-    - reduced view/reshape
-    """
-
-    device = cur_token.device
-    num_streams = model.config.num_codebooks + 1
-
-    # =========================
-    # 1. ring buffer index (替代 roll)
-    # =========================
-    previous_tokens = torch.zeros(
-        (model.config.num_codebooks + 1, RAS_WIN_SIZE),
-        dtype=torch.int,
-        device=cur_token.device,
-    )
-    history_len = previous_tokens.size(1)
-    write_idx = history_len - 1
-
-    new_tokens = []
-
-    # =========================
-    # 2. precompute reshape shape
-    # =========================
-    batch = 1
-    im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
-
-    # =========================
-    # 3. main loop
-    # =========================
-    for i in range(num_new_tokens):
-
-        # ⚡ use flash attention (重要优化)
-        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
-            next_token = decode_one_token(
-                model=model,
-                x=cur_token,
-                input_pos=input_pos,
-                previous_tokens=previous_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                semantic_logit_bias=semantic_logit_bias,
-                audio_masks=audio_masks,
-                audio_parts=audio_parts,
-            ).clone()
-
-        # =========================
-        # 4. update position
-        # =========================
-        input_pos += 1
-
-        # =========================
-        # 5. reshape once (reuse view logic)
-        # =========================
-        next_token_2d = next_token.view(num_streams, -1)
-
-        cur_token = next_token_2d.unsqueeze(0)
-
-        # =========================
-        # 6. ring buffer update (NO roll)
-        # =========================
-        previous_tokens[:, write_idx] = next_token_2d[:, 0]
-        write_idx = (write_idx + 1) % history_len
-
-        # =========================
-        # 7. store output
-        # =========================
-        new_tokens.append(next_token)
-
-        # =========================
-        # 8. EOS check
-        # =========================
-        if cur_token[0, 0, -1] == im_end_id:
-            break
-
-    return new_tokens
-
-
 @torch.no_grad()
 @torch.inference_mode()
 def generate(
@@ -476,7 +381,7 @@ def generate(
     # =========================
     previous_tokens[:, -1, :] = first_token.view(codebook_dim)
 
-    x = decode_n_tokens_optimized(
+    x = decode_n_tokens(
         model,
         first_token.view(1, codebook_dim, -1),
         input_pos,