Browse Source

feat:修改sdp backend

zhaohaipeng 1 month ago
parent
commit
a62cc45b4a
1 changed files with 1 additions and 1 deletions
  1. 1 1
      fish_speech/models/text2semantic/inference.py

+ 1 - 1
fish_speech/models/text2semantic/inference.py

@@ -308,7 +308,7 @@ def decode_n_tokens(
 
 
     for i in tqdm(range(num_new_tokens)):
     for i in tqdm(range(num_new_tokens)):
         f_start = time.perf_counter()
         f_start = time.perf_counter()
-        with sdpa_kernel(SDPBackend.MATH):
+        with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
             next_token = decode_one_token(
             next_token = decode_one_token(
                 model=model,
                 model=model,
                 x=cur_token,
                 x=cur_token,