|
|
@@ -264,16 +264,12 @@ def encode_tokens(
|
|
|
string,
|
|
|
bos=True,
|
|
|
device="cuda",
|
|
|
- prompt_text=None,
|
|
|
prompt_tokens=None,
|
|
|
use_g2p=False,
|
|
|
speaker=None,
|
|
|
order="zh,jp,en",
|
|
|
num_codebooks=4,
|
|
|
):
|
|
|
- if prompt_text is not None:
|
|
|
- string = prompt_text + " " + string
|
|
|
-
|
|
|
if use_g2p:
|
|
|
order = order.split(",")
|
|
|
prompt = g2p(string, order=order)
|
|
|
@@ -306,6 +302,12 @@ def encode_tokens(
|
|
|
return prompt
|
|
|
|
|
|
# Get prompt tokens
|
|
|
+ if prompt_tokens.ndim == 3:
|
|
|
+ assert (
|
|
|
+ prompt_tokens.shape[0] == 1
|
|
|
+ ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
|
|
|
+ prompt_tokens = prompt_tokens[0]
|
|
|
+
|
|
|
assert prompt_tokens.ndim == 2
|
|
|
data = prompt_tokens + 2
|
|
|
|
|
|
@@ -432,18 +434,34 @@ def main(
|
|
|
if prompt_tokens is not None
|
|
|
else None
|
|
|
)
|
|
|
+
|
|
|
+ use_prompt = prompt_text is not None and prompt_tokens is not None
|
|
|
+
|
|
|
encoded = encode_tokens(
|
|
|
tokenizer,
|
|
|
text,
|
|
|
- prompt_text=prompt_text,
|
|
|
- prompt_tokens=prompt_tokens,
|
|
|
- bos=True,
|
|
|
+ bos=False if use_prompt else True,
|
|
|
device=device,
|
|
|
use_g2p=use_g2p,
|
|
|
- speaker=speaker,
|
|
|
+ speaker=None if use_prompt else speaker,
|
|
|
order=order,
|
|
|
num_codebooks=model.config.num_codebooks,
|
|
|
)
|
|
|
+
|
|
|
+ if use_prompt:
|
|
|
+ encoded_prompt = encode_tokens(
|
|
|
+ tokenizer,
|
|
|
+ prompt_text,
|
|
|
+ prompt_tokens=prompt_tokens,
|
|
|
+ bos=True,
|
|
|
+ device=device,
|
|
|
+ use_g2p=use_g2p,
|
|
|
+ speaker=speaker,
|
|
|
+ order=order,
|
|
|
+ num_codebooks=model.config.num_codebooks,
|
|
|
+ )
|
|
|
+ encoded = torch.cat((encoded_prompt, encoded), dim=1)
|
|
|
+
|
|
|
prompt_length = encoded.size(1)
|
|
|
logger.info(f"Encoded prompt shape: {encoded.shape}")
|
|
|
|