Kaynağa Gözat

[fix]:fix future warning: sdp_kernel

fix future warning: sdp_kernel
Ardent Illumina, 1 yıl önce — ebeveyn işleme: b3c39d3205
1 değiştirilmiş dosya ile 13 ekleme ve 12 silme
      fish_speech/models/text2semantic/inference.py
      fish_speech/models/text2semantic/inference.py

+ 13 - 12
fish_speech/models/text2semantic/inference.py

@@ -13,6 +13,7 @@ import torch
 import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
@@ -23,7 +24,12 @@ from fish_speech.conversation import (
     TextPart,
     VQPart,
 )
-from fish_speech.models.text2semantic.llama import BaseModelArgs
+from fish_speech.models.text2semantic.llama import (
+    BaseModelArgs,
+    BaseTransformer,
+    DualARTransformer,
+    NaiveTransformer,
+)
 from fish_speech.text import clean_text, split_text
 from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
 
@@ -36,15 +42,6 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
     torch._inductor.config.fx_graph_cache = True
 
 
-from torch.nn.attention import SDPBackend, sdpa_kernel
-
-from fish_speech.models.text2semantic.llama import (
-    BaseTransformer,
-    DualARTransformer,
-    NaiveTransformer,
-)
-
-
 def multinomial_sample_one_no_sync(
     probs_sort,
 ):  # Does multinomial sampling without a cuda synchronization
@@ -372,8 +369,12 @@ def decode_n_tokens(
             window = previous_tokens[:, i - win_size : i]
 
         with (
-            torch.backends.cuda.sdp_kernel(
-                enable_flash=False, enable_mem_efficient=False, enable_math=True
+            sdpa_kernel(
+                [
+                    SDPBackend.FLASH_ATTENTION,
+                    SDPBackend.EFFICIENT_ATTENTION,
+                    SDPBackend.MATH,
+                ]
             )
             if torch.cuda.is_available()
             else nullcontext()