|
|
@@ -237,6 +237,16 @@ def generate(
|
|
|
# create an empty tensor of the expected final shape and fill in the current tokens
|
|
|
T = prompt.size(1)
|
|
|
|
|
|
+ if max_new_tokens:
|
|
|
+ if T + max_new_tokens > model.config.max_seq_len:
|
|
|
+ max_new_tokens = model.config.max_seq_len - T
|
|
|
+ logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
|
|
|
+
|
|
|
+ T_new = T + max_new_tokens
|
|
|
+ else:
|
|
|
+ T_new = model.config.max_seq_len
|
|
|
+ max_new_tokens = T_new - T
|
|
|
+
|
|
|
device, dtype = prompt.device, prompt.dtype
|
|
|
|
|
|
codebook_dim = 1 + model.config.num_codebooks
|
|
|
@@ -565,7 +575,9 @@ def launch_thread_safe_queue(
|
|
|
)
|
|
|
with torch.device(device):
|
|
|
model.setup_caches(
|
|
|
- max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
|
|
|
+ max_batch_size=1,
|
|
|
+ max_seq_len=model.config.max_seq_len,
|
|
|
+ dtype=next(model.parameters()).dtype,
|
|
|
)
|
|
|
init_event.set()
|
|
|
|
|
|
@@ -607,7 +619,7 @@ def launch_thread_safe_queue(
|
|
|
multiple=True,
|
|
|
)
|
|
|
@click.option("--num-samples", type=int, default=1)
|
|
|
-@click.option("--max-new-tokens", type=int, default=1024)
|
|
|
+@click.option("--max-new-tokens", type=int, default=0)
|
|
|
@click.option("--top-p", type=float, default=0.7)
|
|
|
@click.option("--repetition-penalty", type=float, default=1.2)
|
|
|
@click.option("--temperature", type=float, default=0.7)
|
|
|
@@ -654,7 +666,9 @@ def main(
|
|
|
)
|
|
|
with torch.device(device):
|
|
|
model.setup_caches(
|
|
|
- max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
|
|
|
+ max_batch_size=1,
|
|
|
+ max_seq_len=model.config.max_seq_len,
|
|
|
+ dtype=next(model.parameters()).dtype,
|
|
|
)
|
|
|
if torch.cuda.is_available():
|
|
|
torch.cuda.synchronize()
|