|
|
@@ -402,6 +402,7 @@ def load_model(config_name, checkpoint_path, device, precision):
|
|
|
@click.option("--speaker", type=str, default=None)
|
|
|
@click.option("--order", type=str, default="zh,jp,en")
|
|
|
@click.option("--half/--no-half", default=False)
|
|
|
+@click.option("--iterative-prompt/--no-iterative-prompt", default=False)
|
|
|
def main(
|
|
|
text: str,
|
|
|
prompt_text: Optional[str],
|
|
|
@@ -421,6 +422,7 @@ def main(
|
|
|
speaker: Optional[str],
|
|
|
order: str,
|
|
|
half: bool,
|
|
|
+ iterative_prompt: bool,
|
|
|
) -> None:
|
|
|
device = "cuda"
|
|
|
|
|
|
@@ -443,18 +445,17 @@ def main(
|
|
|
|
|
|
use_prompt = prompt_text is not None and prompt_tokens is not None
|
|
|
|
|
|
- encoded = encode_tokens(
|
|
|
- tokenizer,
|
|
|
- text,
|
|
|
- bos=False if use_prompt else True,
|
|
|
- device=device,
|
|
|
- use_g2p=use_g2p,
|
|
|
- speaker=None if use_prompt else speaker,
|
|
|
- order=order,
|
|
|
- num_codebooks=model.config.num_codebooks,
|
|
|
- )
|
|
|
-
|
|
|
- if use_prompt:
|
|
|
+ if use_prompt and iterative_prompt:
|
|
|
+ encoded = encode_tokens(
|
|
|
+ tokenizer,
|
|
|
+ text,
|
|
|
+ bos=False,
|
|
|
+ device=device,
|
|
|
+ use_g2p=use_g2p,
|
|
|
+ speaker=None,
|
|
|
+ order=order,
|
|
|
+ num_codebooks=model.config.num_codebooks,
|
|
|
+ )
|
|
|
encoded_prompt = encode_tokens(
|
|
|
tokenizer,
|
|
|
prompt_text,
|
|
|
@@ -467,6 +468,21 @@ def main(
|
|
|
num_codebooks=model.config.num_codebooks,
|
|
|
)
|
|
|
encoded = torch.cat((encoded_prompt, encoded), dim=1)
|
|
|
+ else:
|
|
|
+ if prompt_text:
|
|
|
+ text = prompt_text + text
|
|
|
+
|
|
|
+ encoded = encode_tokens(
|
|
|
+ tokenizer,
|
|
|
+ text,
|
|
|
+ bos=True,
|
|
|
+ device=device,
|
|
|
+ use_g2p=use_g2p,
|
|
|
+ speaker=speaker,
|
|
|
+ order=order,
|
|
|
+ prompt_tokens=prompt_tokens,
|
|
|
+ num_codebooks=model.config.num_codebooks,
|
|
|
+ )
|
|
|
|
|
|
prompt_length = encoded.size(1)
|
|
|
logger.info(f"Encoded prompt shape: {encoded.shape}")
|