Procházet zdrojové kódy

Fix cache max_seq_len (#568)

* fix max_seq_len

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* another one

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix max new tokens

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama před 1 rokem
rodič
revize
ad55185ec3
4 změnil soubory, kde provedl 22 přidání a 7 odebrání
  1. 1 0
      install_env.bat
  2. 2 2
      tools/api.py
  3. 17 3
      tools/llama/generate.py
  4. 2 2
      tools/webui.py

+ 1 - 0
install_env.bat

@@ -144,6 +144,7 @@ call :download_and_install "triton_windows-0.1.0-py3-none-any.whl" ^
 
 endlocal
 echo "Environment Check: Success."
+:end
 pause
 
 goto :EOF

+ 2 - 2
tools/api.py

@@ -220,7 +220,7 @@ def inference(req: ServeTTSRequest):
         compile=args.compile,
         iterative_prompt=req.chunk_length > 0,
         chunk_length=req.chunk_length,
-        max_length=2048,
+        max_length=4096,
         prompt_tokens=prompt_tokens,
         prompt_text=prompt_texts,
     )
@@ -424,7 +424,7 @@ if __name__ == "__main__":
                 text="Hello world.",
                 references=[],
                 reference_id=None,
-                max_new_tokens=1024,
+                max_new_tokens=0,
                 chunk_length=200,
                 top_p=0.7,
                 repetition_penalty=1.2,

+ 17 - 3
tools/llama/generate.py

@@ -237,6 +237,16 @@ def generate(
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
 
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
+    else:
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T
+
     device, dtype = prompt.device, prompt.dtype
 
     codebook_dim = 1 + model.config.num_codebooks
@@ -565,7 +575,9 @@ def launch_thread_safe_queue(
         )
         with torch.device(device):
             model.setup_caches(
-                max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+                max_batch_size=1,
+                max_seq_len=model.config.max_seq_len,
+                dtype=next(model.parameters()).dtype,
             )
         init_event.set()
 
@@ -607,7 +619,7 @@ def launch_thread_safe_queue(
     multiple=True,
 )
 @click.option("--num-samples", type=int, default=1)
-@click.option("--max-new-tokens", type=int, default=1024)
+@click.option("--max-new-tokens", type=int, default=0)
 @click.option("--top-p", type=float, default=0.7)
 @click.option("--repetition-penalty", type=float, default=1.2)
 @click.option("--temperature", type=float, default=0.7)
@@ -654,7 +666,9 @@ def main(
     )
     with torch.device(device):
         model.setup_caches(
-            max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+            max_batch_size=1,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
         )
     if torch.cuda.is_available():
         torch.cuda.synchronize()

+ 2 - 2
tools/webui.py

@@ -286,7 +286,7 @@ def build_app():
                             label=i18n("Maximum tokens per batch, 0 means no limit"),
                             minimum=0,
                             maximum=2048,
-                            value=1024,  # 0 means no limit
+                            value=0,  # 0 means no limit
                             step=8,
                         )
 
@@ -505,7 +505,7 @@ if __name__ == "__main__":
             enable_reference_audio=False,
             reference_audio=None,
             reference_text="",
-            max_new_tokens=1024,
+            max_new_tokens=0,
             chunk_length=200,
             top_p=0.7,
             repetition_penalty=1.2,