|
|
@@ -605,7 +605,7 @@ def launch_thread_safe_queue(
|
|
|
multiple=True,
|
|
|
)
|
|
|
@click.option("--num-samples", type=int, default=1)
|
|
|
-@click.option("--max-new-tokens", type=int, default=0)
|
|
|
+@click.option("--max-new-tokens", type=int, default=1024)
|
|
|
@click.option("--top-p", type=float, default=0.7)
|
|
|
@click.option("--repetition-penalty", type=float, default=1.2)
|
|
|
@click.option("--temperature", type=float, default=0.7)
|
|
|
@@ -650,7 +650,10 @@ def main(
|
|
|
model, decode_one_token = load_model(
|
|
|
checkpoint_path, device, precision, compile=compile
|
|
|
)
|
|
|
-
|
|
|
+ with torch.device(device):
|
|
|
+ model.setup_caches(
|
|
|
+ max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
|
|
|
+ )
|
|
|
if torch.cuda.is_available():
|
|
|
torch.cuda.synchronize()
|
|
|
|