|
|
@@ -260,12 +260,12 @@ def encode_tokens(
|
|
|
string,
|
|
|
bos=True,
|
|
|
device="cuda",
|
|
|
- prompt_string=None,
|
|
|
+ prompt_text=None,
|
|
|
prompt_tokens=None,
|
|
|
use_g2p=False,
|
|
|
):
|
|
|
- if prompt_string is not None:
|
|
|
- string = prompt_string + " " + string
|
|
|
+ if prompt_text is not None:
|
|
|
+ string = prompt_text + " " + string
|
|
|
|
|
|
if use_g2p:
|
|
|
prompt = g2p(string)
|
|
|
@@ -353,7 +353,7 @@ def load_model(config_name, checkpoint_path, device, precision):
|
|
|
|
|
|
@click.command()
|
|
|
@click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
|
|
|
-@click.option("--prompt-string", type=str, default=None)
|
|
|
+@click.option("--prompt-text", type=str, default=None)
|
|
|
@click.option(
|
|
|
"--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
|
|
|
)
|
|
|
@@ -375,7 +375,7 @@ def load_model(config_name, checkpoint_path, device, precision):
|
|
|
@click.option("--seed", type=int, default=42)
|
|
|
def main(
|
|
|
text: str,
|
|
|
- prompt_string: Optional[str],
|
|
|
+ prompt_text: Optional[str],
|
|
|
prompt_tokens: Optional[Path],
|
|
|
num_samples: int,
|
|
|
max_new_tokens: int,
|
|
|
@@ -410,7 +410,7 @@ def main(
|
|
|
encoded = encode_tokens(
|
|
|
tokenizer,
|
|
|
text,
|
|
|
- prompt_string=prompt_string,
|
|
|
+ prompt_text=prompt_text,
|
|
|
prompt_tokens=prompt_tokens,
|
|
|
bos=True,
|
|
|
device=device,
|