|
|
@@ -393,7 +393,7 @@ def generate(
|
|
|
def init_model(checkpoint_path, device, precision, compile=False):
|
|
|
|
|
|
torch.backends.cuda.enable_flash_sdp(False)
|
|
|
- torch.backends.cuda.enable_math_sdp(False)
|
|
|
+ torch.backends.cuda.enable_math_sdp(True)
|
|
|
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
|
|
torch.backends.cuda.enable_cudnn_sdp(True)
|
|
|
|
|
|
@@ -421,11 +421,18 @@ def init_model(checkpoint_path, device, precision, compile=False):
|
|
|
|
|
|
if compile:
|
|
|
logger.info("Compiling function...")
|
|
|
+ # decode_one_token = torch.compile(
|
|
|
+ # decode_one_token,
|
|
|
+ # backend="inductor" if torch.cuda.is_available() else "aot_eager",
|
|
|
+ # mode="default" if torch.cuda.is_available() else None,
|
|
|
+ # fullgraph=True,
|
|
|
+ # )
|
|
|
+
|
|
|
decode_one_token = torch.compile(
|
|
|
decode_one_token,
|
|
|
backend="inductor" if torch.cuda.is_available() else "aot_eager",
|
|
|
- mode="default" if torch.cuda.is_available() else None,
|
|
|
- fullgraph=True,
|
|
|
+ mode="reduce-overhead" if torch.cuda.is_available() else None,
|
|
|
+ fullgraph=False,
|
|
|
)
|
|
|
|
|
|
return model.eval(), decode_one_token
|