@@ -97,7 +97,7 @@ def decode_one_token(
codebooks = [
sample(
logits.token_logits,
- previous_tokens=previous_tokens[0],
+ previous_tokens=None, # Disable repetition penalty for the token codebook
**sampling_kwargs,
)[0]
]