|
@@ -93,7 +93,7 @@ def sample(
|
|
|
return idx_next, probs
|
|
return idx_next, probs
|
|
|
|
|
|
|
|
|
|
|
|
|
-def decode_one_token_ar_old(
|
|
|
|
|
|
|
+def decode_one_token_ar(
|
|
|
model: DualARTransformer,
|
|
model: DualARTransformer,
|
|
|
x: torch.Tensor,
|
|
x: torch.Tensor,
|
|
|
input_pos: torch.Tensor,
|
|
input_pos: torch.Tensor,
|
|
@@ -181,7 +181,7 @@ def decode_one_token_ar_old(
|
|
|
return codebooks.T
|
|
return codebooks.T
|
|
|
|
|
|
|
|
@torch.inference_mode()
|
|
@torch.inference_mode()
|
|
|
-def decode_one_token_ar(
|
|
|
|
|
|
|
+def decode_one_token_ar_optimize(
|
|
|
model,
|
|
model,
|
|
|
x,
|
|
x,
|
|
|
input_pos,
|
|
input_pos,
|
|
@@ -522,8 +522,8 @@ def init_model(checkpoint_path, device, precision, compile=False):
|
|
|
decode_one_token = torch.compile(
|
|
decode_one_token = torch.compile(
|
|
|
decode_one_token,
|
|
decode_one_token,
|
|
|
backend="inductor" if torch.cuda.is_available() else "aot_eager",
|
|
backend="inductor" if torch.cuda.is_available() else "aot_eager",
|
|
|
- mode="default" if torch.cuda.is_available() else None,
|
|
|
|
|
- fullgraph=True,
|
|
|
|
|
|
|
+ mode="reduce-overhead" if torch.cuda.is_available() else None,
|
|
|
|
|
+ fullgraph=False,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
return model.eval(), decode_one_token
|
|
return model.eval(), decode_one_token
|