|
@@ -2,6 +2,7 @@ import os
|
|
|
import queue
|
|
import queue
|
|
|
import threading
|
|
import threading
|
|
|
import time
|
|
import time
|
|
|
|
|
+from contextlib import nullcontext
|
|
|
from dataclasses import dataclass
|
|
from dataclasses import dataclass
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from typing import Literal, Optional, Tuple, Union
|
|
from typing import Literal, Optional, Tuple, Union
|
|
@@ -181,8 +182,12 @@ def decode_n_tokens(
|
|
|
else:
|
|
else:
|
|
|
window = previous_tokens[:, i - win_size : i]
|
|
window = previous_tokens[:, i - win_size : i]
|
|
|
|
|
|
|
|
- with torch.backends.cuda.sdp_kernel(
|
|
|
|
|
- enable_flash=False, enable_mem_efficient=False, enable_math=True
|
|
|
|
|
|
|
+ with (
|
|
|
|
|
+ torch.backends.cuda.sdp_kernel(
|
|
|
|
|
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
|
|
|
|
|
+ )
|
|
|
|
|
+ if torch.cuda.is_available()
|
|
|
|
|
+ else nullcontext()
|
|
|
): # Actually better for Inductor to codegen attention here
|
|
): # Actually better for Inductor to codegen attention here
|
|
|
next_token = decode_one_token(
|
|
next_token = decode_one_token(
|
|
|
model=model,
|
|
model=model,
|