Quellcode durchsuchen

Fix inference speed. (#928)

* Fix inference speed.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PoTaTo vor 1 Jahr
Ursprung
Commit
3eb2f3299f
1 geänderte Datei mit 12 neuen und 13 gelöschten Zeilen
  1. 12 13
      fish_speech/models/text2semantic/inference.py

+ 12 - 13
fish_speech/models/text2semantic/inference.py

@@ -13,7 +13,6 @@ import torch
 import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
-from torch.nn.attention import SDPBackend, sdpa_kernel
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
@@ -24,12 +23,7 @@ from fish_speech.conversation import (
     TextPart,
     VQPart,
 )
-from fish_speech.models.text2semantic.llama import (
-    BaseModelArgs,
-    BaseTransformer,
-    DualARTransformer,
-    NaiveTransformer,
-)
+from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text
 from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
 
@@ -42,6 +36,15 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
     torch._inductor.config.fx_graph_cache = True
 
 
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+from fish_speech.models.text2semantic.llama import (
+    BaseTransformer,
+    DualARTransformer,
+    NaiveTransformer,
+)
+
+
 def multinomial_sample_one_no_sync(
     probs_sort,
 ):  # Does multinomial sampling without a cuda synchronization
@@ -369,12 +372,8 @@ def decode_n_tokens(
             window = previous_tokens[:, i - win_size : i]
 
         with (
-            sdpa_kernel(
-                [
-                    SDPBackend.FLASH_ATTENTION,
-                    SDPBackend.EFFICIENT_ATTENTION,
-                    SDPBackend.MATH,
-                ]
+            torch.backends.cuda.sdp_kernel(
+                enable_flash=False, enable_mem_efficient=False, enable_math=True
             )
             if torch.cuda.is_available()
             else nullcontext()