|
|
@@ -94,7 +94,6 @@ def decode_one_token_ar(
|
|
|
**sampling_kwargs,
|
|
|
) -> torch.Tensor:
|
|
|
x = model.forward_generate(x, input_pos)
|
|
|
- logits = x.token_logits
|
|
|
|
|
|
sampling_kwargs_main = sampling_kwargs.copy()
|
|
|
sampling_kwargs_main["temperature"] = 0.1
|
|
|
@@ -103,7 +102,7 @@ def decode_one_token_ar(
|
|
|
|
|
|
codebooks = [
|
|
|
sample(
|
|
|
- logits,
|
|
|
+ x.logits,
|
|
|
previous_tokens=None, # Disable repetition penalty for the token codebook
|
|
|
**sampling_kwargs_main,
|
|
|
)[0]
|
|
|
@@ -142,7 +141,6 @@ def decode_one_token_naive(
|
|
|
**sampling_kwargs,
|
|
|
) -> torch.Tensor:
|
|
|
x = model.forward_generate(x, input_pos)
|
|
|
- logits = x.token_logits
|
|
|
|
|
|
sampling_kwargs_main = sampling_kwargs.copy()
|
|
|
sampling_kwargs_main["temperature"] = 0.1
|
|
|
@@ -151,7 +149,7 @@ def decode_one_token_naive(
|
|
|
|
|
|
codebooks = [
|
|
|
sample(
|
|
|
- logits,
|
|
|
+ x.logits,
|
|
|
previous_tokens=None, # Disable repetition penalty for the token codebook
|
|
|
**sampling_kwargs_main,
|
|
|
)[0]
|