Przeglądaj źródła

feat:关闭compile

zhaohaipeng 1 miesiąc temu
rodzic
commit
a18b1ac088
1 zmienionych plików z 15 dodań i 41 usunięć
  1. 15 41
      fish_speech/models/text2semantic/inference.py

+ 15 - 41
fish_speech/models/text2semantic/inference.py

@@ -7,10 +7,11 @@ import traceback
 from copy import deepcopy
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Literal, Optional, Tuple, Union, Any
+from typing import Callable, Literal, Optional, Tuple, Union
 
 import click
 import numpy as np
+import torch
 import torch._inductor.config
 from loguru import logger
 from tqdm import tqdm
@@ -29,10 +30,13 @@ torch._inductor.config.triton.unique_kernel_names = True
 if hasattr(torch._inductor.config, "fx_graph_cache"):
     torch._inductor.config.fx_graph_cache = True
 
+
 from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from fish_speech.models.text2semantic.llama import (
+    BaseTransformer,
     DualARTransformer,
+    NaiveTransformer,
 )
 
 
@@ -245,16 +249,15 @@ def decode_n_tokens(
 @torch.no_grad()
 @torch.inference_mode()
 def generate(
-        *,
-        model: DualARTransformer,
-        prompt: torch.Tensor,
-        max_new_tokens: int,
-        audio_masks: torch.Tensor,
-        audio_parts: torch.Tensor,
-        prompt_tokens = None,
-        decode_one_token=decode_one_token_ar,
-        num_samples: int = 1,
-        **sampling_kwargs,
+    *,
+    model: DualARTransformer,
+    prompt: torch.Tensor,
+    max_new_tokens: int,
+    audio_masks: torch.Tensor,
+    audio_parts: torch.Tensor,
+    decode_one_token=decode_one_token_ar,
+    num_samples: int = 1,
+    **sampling_kwargs,
 ):
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
@@ -337,6 +340,7 @@ def generate(
     step7 = time.perf_counter()
 
     prefill_decode = decode_one_token_ar
+
     first_token = prefill_decode(
         model,
         prompt.view(1, codebook_dim, -1),
@@ -355,32 +359,6 @@ def generate(
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     step9 = time.perf_counter()
 
-    im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
-    codebook_dim = 1 + model.config.num_codebooks
-    window_size = 64
-
-    previous_tokens = torch.zeros(
-        (1, window_size, codebook_dim),
-        device=device,
-        dtype=first_token.dtype,
-    )
-
-    # =========================
-    # 1. warm start prompt
-    # =========================
-    if prompt_tokens is not None:
-        # 确保 shape = [B, T, C]
-        if prompt_tokens.dim() == 2:
-            prompt_tokens = prompt_tokens.unsqueeze(0)
-
-        T = min(prompt_tokens.size(1), window_size)
-        previous_tokens[:, -T:] = prompt_tokens[:, -T:]
-
-    # =========================
-    # 2. insert first token
-    # =========================
-    previous_tokens[:, -1, :] = first_token.view(codebook_dim)
-
     x = decode_n_tokens(
         model,
         first_token.view(1, codebook_dim, -1),
@@ -627,8 +605,6 @@ def generate_long(
     # Build base conversation with system message
     base_conversation = Conversation()
 
-    all_codes = None
-
     if use_prompt:
         # Auto-add speaker tags to prompt texts that don't have them
         tagged_prompt_text = []
@@ -750,8 +726,6 @@ def generate_long(
                 audio_masks=audio_masks,
                 audio_parts=audio_parts,
                 decode_one_token=decode_one_token,
-                prompt_tokens=all_codes,
-
                 temperature=temperature,
                 top_p=top_p,
                 top_k=top_k,