|
|
@@ -43,11 +43,11 @@ def logits_to_probs(
|
|
|
):
|
|
|
if previous_tokens is not None and repetition_penalty != 1.0:
|
|
|
previous_tokens = previous_tokens.long()
|
|
|
- score = torch.gather(logits, dim=-1, index=previous_tokens)
|
|
|
+ score = torch.gather(logits, dim=0, index=previous_tokens)
|
|
|
score = torch.where(
|
|
|
score < 0, score * repetition_penalty, score / repetition_penalty
|
|
|
)
|
|
|
- logits.scatter_(dim=-1, index=previous_tokens, src=score)
|
|
|
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
|
|
|
|
|
|
if top_p is not None and top_p < 1.0:
|
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
|
|
@@ -110,7 +110,9 @@ def decode_one_token(
|
|
|
codebooks.append(
|
|
|
sample(
|
|
|
logits.codebook_logits[:, :, i],
|
|
|
- previous_tokens=previous_tokens[i + 1],
|
|
|
+ previous_tokens=previous_tokens[i + 1]
|
|
|
+ if previous_tokens is not None
|
|
|
+ else None,
|
|
|
**sampling_kwargs,
|
|
|
)[0]
|
|
|
)
|
|
|
@@ -162,6 +164,13 @@ def decode_n_tokens(
|
|
|
)
|
|
|
|
|
|
for i in tqdm(range(num_new_tokens)):
|
|
|
+ # We need to get windowed repeat penalty
|
|
|
+ win_size = 16
|
|
|
+ if i < win_size:
|
|
|
+ window = previous_tokens[:, :win_size]
|
|
|
+ else:
|
|
|
+ window = previous_tokens[:, i - win_size : i]
|
|
|
+
|
|
|
with torch.backends.cuda.sdp_kernel(
|
|
|
enable_flash=False, enable_mem_efficient=False, enable_math=True
|
|
|
): # Actually better for Inductor to codegen attention here
|
|
|
@@ -169,7 +178,7 @@ def decode_n_tokens(
|
|
|
model,
|
|
|
cur_token,
|
|
|
input_pos,
|
|
|
- previous_tokens,
|
|
|
+ window,
|
|
|
**sampling_kwargs,
|
|
|
)
|
|
|
|
|
|
@@ -340,7 +349,7 @@ def load_model(config_name, checkpoint_path, device, precision):
|
|
|
@click.option("--max_new_tokens", type=int, default=0)
|
|
|
@click.option("--top-k", type=int, default=None)
|
|
|
@click.option("--top-p", type=float, default=0.5)
|
|
|
-@click.option("--repetition-penalty", type=float, default=1.1)
|
|
|
+@click.option("--repetition-penalty", type=float, default=1.5)
|
|
|
@click.option("--temperature", type=float, default=0.7)
|
|
|
@click.option(
|
|
|
"--checkpoint-path",
|