|
|
@@ -208,7 +208,7 @@ def decode_n_tokens(
|
|
|
|
|
|
for i in tqdm(range(num_new_tokens)):
|
|
|
f_start = time.perf_counter()
|
|
|
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
|
|
+ with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
|
|
|
next_token = decode_one_token(
|
|
|
model=model,
|
|
|
x=cur_token,
|
|
|
@@ -750,10 +750,11 @@ def generate_long(
|
|
|
audio_masks=audio_masks,
|
|
|
audio_parts=audio_parts,
|
|
|
decode_one_token=decode_one_token,
|
|
|
+ prompt_tokens=all_codes,
|
|
|
+
|
|
|
temperature=temperature,
|
|
|
top_p=top_p,
|
|
|
top_k=top_k,
|
|
|
- prompt_tokens=all_codes,
|
|
|
)
|
|
|
|
|
|
if sample_idx == 0 and batch_idx == 0 and compile:
|