|
|
@@ -298,7 +298,7 @@ def encode_tokens(
|
|
|
tokens = torch.tensor([tokens], dtype=torch.int, device=device)
|
|
|
|
|
|
# Codebooks
|
|
|
- zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
|
|
|
+ zeros = torch.zeros((8, tokens.size(1)), dtype=torch.int, device=device)
|
|
|
prompt = torch.cat((tokens, zeros), dim=0)
|
|
|
|
|
|
if prompt_tokens is None:
|
|
|
@@ -368,7 +368,7 @@ def load_model(config_name, checkpoint_path, device, precision):
|
|
|
"--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
|
|
|
)
|
|
|
@click.option("--num-samples", type=int, default=1)
|
|
|
-@click.option("--max_new_tokens", type=int, default=0)
|
|
|
+@click.option("--max-new-tokens", type=int, default=0)
|
|
|
@click.option("--top-k", type=int, default=None)
|
|
|
@click.option("--top-p", type=float, default=0.5)
|
|
|
@click.option("--repetition-penalty", type=float, default=1.5)
|