
Fix breakdown infer (#534)

* fully support ormsgpack

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* dependency

* torch==2.4.1 windows compilable

* Update docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* remove autorerank

* api usage

* back slash

* fix docs

* Fix infer warmup params

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* max_new_tokens=1024

* Fix break down infer

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama committed 1 year ago
parent commit 6e95d2ae3e

1 changed file with 5 additions and 2 deletions

tools/llama/generate.py (+5, -2)

@@ -605,7 +605,7 @@ def launch_thread_safe_queue(
     multiple=True,
 )
 @click.option("--num-samples", type=int, default=1)
-@click.option("--max-new-tokens", type=int, default=0)
+@click.option("--max-new-tokens", type=int, default=1024)
 @click.option("--top-p", type=float, default=0.7)
 @click.option("--repetition-penalty", type=float, default=1.2)
 @click.option("--temperature", type=float, default=0.7)
@@ -650,7 +650,10 @@ def main(
     model, decode_one_token = load_model(
         checkpoint_path, device, precision, compile=compile
     )
-
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+        )
     if torch.cuda.is_available():
         torch.cuda.synchronize()