@@ -341,7 +341,8 @@ def init_model(checkpoint_path, device, precision, compile=False):
     decode_one_token = torch.compile(
         decode_one_token,
         # mode="max-autotune-no-cudagraphs",
-        mode="reduce-overhead",
+        backend="inductor" if torch.cuda.is_available() else "aot_eager",
+        mode="reduce-overhead" if torch.cuda.is_available() else None,
         fullgraph=True,
     )