Bläddra i källkod

Allow disable interactive mode

Lengyue 2 år sedan
förälder
incheckning
bbb492d7eb
1 ändrade filer med 28 tillägg och 12 borttagningar
  1. 28 12
      tools/llama/generate.py

+ 28 - 12
tools/llama/generate.py

@@ -402,6 +402,7 @@ def load_model(config_name, checkpoint_path, device, precision):
 @click.option("--speaker", type=str, default=None)
 @click.option("--order", type=str, default="zh,jp,en")
 @click.option("--half/--no-half", default=False)
+@click.option("--iterative-prompt/--no-iterative-prompt", default=False)
 def main(
     text: str,
     prompt_text: Optional[str],
@@ -421,6 +422,7 @@ def main(
     speaker: Optional[str],
     order: str,
     half: bool,
+    iterative_prompt: bool,
 ) -> None:
     device = "cuda"
 
@@ -443,18 +445,17 @@ def main(
 
     use_prompt = prompt_text is not None and prompt_tokens is not None
 
-    encoded = encode_tokens(
-        tokenizer,
-        text,
-        bos=False if use_prompt else True,
-        device=device,
-        use_g2p=use_g2p,
-        speaker=None if use_prompt else speaker,
-        order=order,
-        num_codebooks=model.config.num_codebooks,
-    )
-
-    if use_prompt:
+    if use_prompt and iterative_prompt:
+        encoded = encode_tokens(
+            tokenizer,
+            text,
+            bos=False,
+            device=device,
+            use_g2p=use_g2p,
+            speaker=None,
+            order=order,
+            num_codebooks=model.config.num_codebooks,
+        )
         encoded_prompt = encode_tokens(
             tokenizer,
             prompt_text,
@@ -467,6 +468,21 @@ def main(
             num_codebooks=model.config.num_codebooks,
         )
         encoded = torch.cat((encoded_prompt, encoded), dim=1)
+    else:
+        if prompt_text:
+            text = prompt_text + text
+
+        encoded = encode_tokens(
+            tokenizer,
+            text,
+            bos=True,
+            device=device,
+            use_g2p=use_g2p,
+            speaker=speaker,
+            order=order,
+            prompt_tokens=prompt_tokens,
+            num_codebooks=model.config.num_codebooks,
+        )
 
     prompt_length = encoded.size(1)
     logger.info(f"Encoded prompt shape: {encoded.shape}")