Lengyue 1 год назад
Родитель
Commit
2e8c40f3b7
1 изменённый файл с 43 добавлениями и 24 удалениями
  1. 43 24
      tools/llama/generate.py

+ 43 - 24
tools/llama/generate.py

@@ -440,30 +440,42 @@ def generate_long(
     max_length: int = 2048,
     chunk_length: int = 150,
     speaker: Optional[str] = None,
-    prompt_text: Optional[str] = None,
-    prompt_tokens: Optional[torch.Tensor] = None,
+    prompt_text: Optional[str | list[str]] = None,
+    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
 ):
     assert 0 < top_p <= 1, "top_p must be in (0, 1]"
     assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
     assert 0 < temperature < 2, "temperature must be in (0, 2)"
 
+    use_prompt = prompt_text is not None and prompt_tokens is not None
+    if use_prompt and isinstance(prompt_text, str):
+        prompt_text = [prompt_text]
+        prompt_tokens = [prompt_tokens]
+
+    assert use_prompt is False or len(prompt_text) == len(
+        prompt_tokens
+    ), "Prompt text and tokens must have the same length"
+
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
-    use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
+    encoded_prompts = []
 
     if use_prompt:
-        encoded_prompts = encode_tokens(
-            tokenizer,
-            prompt_text,
-            prompt_tokens=prompt_tokens,
-            bos=True,
-            device=device,
-            speaker=speaker,
-            num_codebooks=model.config.num_codebooks,
-        )
+        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
+            encoded_prompts.append(
+                encode_tokens(
+                    tokenizer,
+                    string=t,
+                    bos=idx == 0,
+                    device=device,
+                    prompt_tokens=c,
+                    speaker=speaker,
+                    num_codebooks=model.config.num_codebooks,
+                )
+            )
 
     for idx, text in enumerate(texts):
         encoded.append(
@@ -507,7 +519,9 @@ def generate_long(
             count = 0
             for i, length in enumerate(lengths):
                 count += length
-                if count + length > max_length - 1024:
+                if count + length > max_length - 1024 - sum(
+                    t.shape[1] for t in encoded_prompts
+                ):
                     break
 
             if i != 0 and i % 2 == 0:
@@ -520,7 +534,7 @@ def generate_long(
                 partial_encoded = global_encoded
 
             if use_prompt:
-                partial_encoded = [encoded_prompts] + partial_encoded
+                partial_encoded = encoded_prompts + partial_encoded
 
             cat_encoded = torch.cat(partial_encoded, dim=1)
             prompt_length = cat_encoded.size(1)
@@ -643,9 +657,12 @@ def launch_thread_safe_queue(
     type=str,
     default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
 )
-@click.option("--prompt-text", type=str, default=None)
+@click.option("--prompt-text", type=str, default=None, multiple=True)
 @click.option(
-    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
+    "--prompt-tokens",
+    type=click.Path(path_type=Path, exists=True),
+    default=None,
+    multiple=True,
 )
 @click.option("--num-samples", type=int, default=1)
 @click.option("--max-new-tokens", type=int, default=0)
@@ -665,11 +682,11 @@ def launch_thread_safe_queue(
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
 @click.option("--max-length", type=int, default=2048)
-@click.option("--chunk-length", type=int, default=30)
+@click.option("--chunk-length", type=int, default=150)
 def main(
     text: str,
-    prompt_text: Optional[str],
-    prompt_tokens: Optional[Path],
+    prompt_text: Optional[list[str]],
+    prompt_tokens: Optional[list[Path]],
     num_samples: int,
     max_new_tokens: int,
     top_p: int,
@@ -690,6 +707,11 @@ def main(
 
     precision = torch.half if half else torch.bfloat16
 
+    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
+        raise ValueError(
+            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
+        )
+
     logger.info("Loading model ...")
     t0 = time.time()
     model, decode_one_token = load_model(
@@ -701,11 +723,8 @@ def main(
 
     logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
 
-    prompt_tokens = (
-        torch.from_numpy(np.load(prompt_tokens)).to(device)
-        if prompt_tokens is not None
-        else None
-    )
+    if prompt_tokens is not None:
+        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
 
     tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     torch.manual_seed(seed)