zhaohaipeng пре 1 месец
родитељ
комит
040b5a440d
1 измењених фајлова са 3 додато и 2 уклоњено
  1. 3 2
      fish_speech/models/text2semantic/inference.py

+ 3 - 2
fish_speech/models/text2semantic/inference.py

@@ -208,7 +208,7 @@ def decode_n_tokens(
 
 
     for i in tqdm(range(num_new_tokens)):
     for i in tqdm(range(num_new_tokens)):
         f_start = time.perf_counter()
         f_start = time.perf_counter()
-        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+        with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
             next_token = decode_one_token(
             next_token = decode_one_token(
                 model=model,
                 model=model,
                 x=cur_token,
                 x=cur_token,
@@ -750,10 +750,11 @@ def generate_long(
                 audio_masks=audio_masks,
                 audio_masks=audio_masks,
                 audio_parts=audio_parts,
                 audio_parts=audio_parts,
                 decode_one_token=decode_one_token,
                 decode_one_token=decode_one_token,
+                prompt_tokens=all_codes,
+
                 temperature=temperature,
                 temperature=temperature,
                 top_p=top_p,
                 top_p=top_p,
                 top_k=top_k,
                 top_k=top_k,
-                prompt_tokens=all_codes,
             )
             )
 
 
             if sample_idx == 0 and batch_idx == 0 and compile:
             if sample_idx == 0 and batch_idx == 0 and compile: