Browse Source

feat:添加耗时统计

zhaohaipeng 1 month ago
parent
commit
ed30624b1f
1 changed file with 6 additions and 4 deletions
  1. 6 4
      fish_speech/models/text2semantic/inference.py

+ 6 - 4
fish_speech/models/text2semantic/inference.py

@@ -254,8 +254,6 @@ def decode_n_tokens_optimized(
         semantic_logit_bias: torch.Tensor,
         semantic_logit_bias: torch.Tensor,
         audio_masks: torch.Tensor,
         audio_masks: torch.Tensor,
         audio_parts: torch.Tensor,
         audio_parts: torch.Tensor,
-        previous_tokens: torch.Tensor,
-        im_end_id: Any,
         decode_one_token=decode_one_token_ar,
         decode_one_token=decode_one_token_ar,
 ):
 ):
     """
     """
@@ -271,6 +269,11 @@ def decode_n_tokens_optimized(
     # =========================
     # =========================
     # 1. ring buffer index (替代 roll)
     # 1. ring buffer index (替代 roll)
     # =========================
     # =========================
+    previous_tokens = torch.zeros(
+        (model.config.num_codebooks + 1, RAS_WIN_SIZE),
+        dtype=torch.int,
+        device=cur_token.device,
+    )
     history_len = previous_tokens.size(1)
     history_len = previous_tokens.size(1)
     write_idx = history_len - 1
     write_idx = history_len - 1
 
 
@@ -280,6 +283,7 @@ def decode_n_tokens_optimized(
     # 2. precompute reshape shape
     # 2. precompute reshape shape
     # =========================
     # =========================
     batch = 1
     batch = 1
+    im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
 
 
     # =========================
     # =========================
     # 3. main loop
     # 3. main loop
@@ -483,8 +487,6 @@ def generate(
         semantic_logit_bias=semantic_logit_bias,
         semantic_logit_bias=semantic_logit_bias,
         audio_masks=audio_masks,
         audio_masks=audio_masks,
         audio_parts=audio_parts,
         audio_parts=audio_parts,
-        im_end_id=im_end_id,
-        previous_tokens=previous_tokens,
         decode_one_token=decode_one_token,
         decode_one_token=decode_one_token,
     )
     )
     seq = seq[:, : T + 1 + x.size(1)]
     seq = seq[:, : T + 1 + x.size(1)]