|
@@ -308,7 +308,7 @@ def decode_n_tokens(
|
|
|
|
|
|
|
|
for i in tqdm(range(num_new_tokens)):
|
|
for i in tqdm(range(num_new_tokens)):
|
|
|
f_start = time.perf_counter()
|
|
f_start = time.perf_counter()
|
|
|
- with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
|
|
|
|
|
|
|
+ with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
|
|
|
next_token = decode_one_token(
|
|
next_token = decode_one_token(
|
|
|
model=model,
|
|
model=model,
|
|
|
x=cur_token,
|
|
x=cur_token,
|