|
@@ -522,7 +522,7 @@ def init_model(checkpoint_path, device, precision, compile=False):
|
|
|
decode_one_token = torch.compile(
|
|
decode_one_token = torch.compile(
|
|
|
decode_one_token,
|
|
decode_one_token,
|
|
|
backend="inductor" if torch.cuda.is_available() else "aot_eager",
|
|
backend="inductor" if torch.cuda.is_available() else "aot_eager",
|
|
|
- mode="reduce-overhead" if torch.cuda.is_available() else None,
|
|
|
|
|
|
|
+ mode="default" if torch.cuda.is_available() else None,
|
|
|
fullgraph=False,
|
|
fullgraph=False,
|
|
|
)
|
|
)
|
|
|
|
|
|