zhaohaipeng 1 месяц назад
Родитель
Commit
040b5a440d
1 изменённый файл: 3 добавлено и 2 удалено
  1. 3 2
      fish_speech/models/text2semantic/inference.py

+ 3 - 2
fish_speech/models/text2semantic/inference.py

@@ -208,7 +208,7 @@ def decode_n_tokens(
 
     for i in tqdm(range(num_new_tokens)):
         f_start = time.perf_counter()
-        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+        with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
             next_token = decode_one_token(
                 model=model,
                 x=cur_token,
@@ -750,10 +750,11 @@ def generate_long(
                 audio_masks=audio_masks,
                 audio_parts=audio_parts,
                 decode_one_token=decode_one_token,
+                prompt_tokens=all_codes,
+
                 temperature=temperature,
                 top_p=top_p,
                 top_k=top_k,
-                prompt_tokens=all_codes,
             )
 
             if sample_idx == 0 and batch_idx == 0 and compile: