|
@@ -94,15 +94,20 @@ def decode_one_token_ar(
|
|
|
**sampling_kwargs,
|
|
**sampling_kwargs,
|
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
|
x = model.forward_generate(x, input_pos)
|
|
x = model.forward_generate(x, input_pos)
|
|
|
|
|
+
|
|
|
|
|
+ sampling_kwargs_main = sampling_kwargs.copy()
|
|
|
|
|
+ sampling_kwargs_main["temperature"] = 0.1
|
|
|
|
|
+ sampling_kwargs_main["top_p"] = 0.1
|
|
|
|
|
+ sampling_kwargs_main["repetition_penalty"] = 1.0
|
|
|
|
|
+
|
|
|
codebooks = [
|
|
codebooks = [
|
|
|
sample(
|
|
sample(
|
|
|
- x.logits,
|
|
|
|
|
- previous_tokens=(
|
|
|
|
|
- previous_tokens[0] if previous_tokens is not None else None
|
|
|
|
|
- ), # Disable repetition penalty for the token codebook
|
|
|
|
|
- **sampling_kwargs,
|
|
|
|
|
|
|
+ x.logits,
|
|
|
|
|
+ previous_tokens=None, # Disable repetition penalty for the token codebook
|
|
|
|
|
+ **sampling_kwargs_main,
|
|
|
)[0]
|
|
)[0]
|
|
|
]
|
|
]
|
|
|
|
|
+
|
|
|
x = x.hidden_states
|
|
x = x.hidden_states
|
|
|
|
|
|
|
|
# Cleanup the cache
|
|
# Cleanup the cache
|
|
@@ -136,12 +141,18 @@ def decode_one_token_naive(
|
|
|
**sampling_kwargs,
|
|
**sampling_kwargs,
|
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
|
x = model.forward_generate(x, input_pos)
|
|
x = model.forward_generate(x, input_pos)
|
|
|
|
|
+ logits = x.token_logits
|
|
|
|
|
+
|
|
|
|
|
+ sampling_kwargs_main = sampling_kwargs.copy()
|
|
|
|
|
+ sampling_kwargs_main["temperature"] = 0.1
|
|
|
|
|
+ sampling_kwargs_main["top_p"] = 0.1
|
|
|
|
|
+ sampling_kwargs_main["repetition_penalty"] = 1.0
|
|
|
|
|
|
|
|
codebooks = [
|
|
codebooks = [
|
|
|
sample(
|
|
sample(
|
|
|
- x.token_logits,
|
|
|
|
|
|
|
+ logits,
|
|
|
previous_tokens=None, # Disable repetition penalty for the token codebook
|
|
previous_tokens=None, # Disable repetition penalty for the token codebook
|
|
|
- **sampling_kwargs,
|
|
|
|
|
|
|
+ **sampling_kwargs_main,
|
|
|
)[0]
|
|
)[0]
|
|
|
]
|
|
]
|
|
|
|
|
|