@@ -522,7 +522,7 @@ def init_model(checkpoint_path, device, precision, compile=False):
decode_one_token = torch.compile(
decode_one_token,
backend="inductor" if torch.cuda.is_available() else "aot_eager",
- mode="reduce-overhead" if torch.cuda.is_available() else None,
+ mode="default" if torch.cuda.is_available() else None,
fullgraph=False,
)