|
@@ -94,6 +94,7 @@ def decode_one_token_ar(
|
|
|
**sampling_kwargs,
|
|
**sampling_kwargs,
|
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
|
x = model.forward_generate(x, input_pos)
|
|
x = model.forward_generate(x, input_pos)
|
|
|
|
|
+ logits = x.token_logits
|
|
|
|
|
|
|
|
sampling_kwargs_main = sampling_kwargs.copy()
|
|
sampling_kwargs_main = sampling_kwargs.copy()
|
|
|
sampling_kwargs_main["temperature"] = 0.1
|
|
sampling_kwargs_main["temperature"] = 0.1
|