@@ -94,6 +94,7 @@ def decode_one_token_ar(
**sampling_kwargs,
) -> torch.Tensor:
x = model.forward_generate(x, input_pos)
+ logits = x.token_logits
sampling_kwargs_main = sampling_kwargs.copy()
sampling_kwargs_main["temperature"] = 0.1