@@ -269,6 +269,7 @@ def encode_tokens(
     use_g2p=False,
     speaker=None,
     order="zh,jp,en",
+    num_codebooks=4,
 ):
     if prompt_text is not None:
         string = prompt_text + " " + string
@@ -298,7 +299,7 @@ def encode_tokens(
     tokens = torch.tensor([tokens], dtype=torch.int, device=device)
 
     # Codebooks
-    zeros = torch.zeros((4, tokens.size(1)), dtype=torch.int, device=device)
+    zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
     prompt = torch.cat((tokens, zeros), dim=0)
 
     if prompt_tokens is None:
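A sketch (not part of the patch) of the layout these two hunks build: the text prompt becomes a (1 + num_codebooks) x T matrix, with token ids in the first row and one zero placeholder row per acoustic codebook below it. The values here are illustrative:

    import torch

    # Illustrative only: 5 text tokens, the new default of 4 codebooks.
    num_codebooks = 4
    tokens = torch.tensor([[7, 42, 9, 13, 5]], dtype=torch.int)  # (1, T)

    # One zero placeholder row per codebook, matching the text length.
    zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int)
    prompt = torch.cat((tokens, zeros), dim=0)
    print(prompt.shape)  # torch.Size([5, 5]), i.e. (1 + num_codebooks, T)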
@@ -308,11 +309,18 @@ def encode_tokens(
     assert prompt_tokens.ndim == 2
     data = prompt_tokens + 2
 
-    zeros = (
-        torch.zeros((1, data.size(1)), dtype=torch.int, device=device)
-        + tokenizer.pad_token_id
-    )  # 32311 is the <pad> token
-    data = torch.cat((zeros, data), dim=0)
+    if prompt_tokens.shape[0] > num_codebooks:
+        logger.warning(
+            f"Prompt tokens have {prompt_tokens.shape[0]} codebooks, but the model expects {num_codebooks}; keeping only the first {num_codebooks}"
+        )
+        data = data[:num_codebooks]
+
+    # Since version 1.0, <s:xxx> tokens replace <semantic>
+    main_tokens = [f"<s:{i}>" for i in data[0]]
+    main_token_ids = tokenizer.convert_tokens_to_ids(main_tokens)
+    main_token_ids = torch.tensor([main_token_ids], dtype=torch.int, device=device)
+
+    data = torch.cat((main_token_ids, data), dim=0)
     prompt = torch.cat((prompt, data), dim=1)
 
     return prompt
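A stand-alone sketch of the new <s:xxx> mapping in this hunk. The stub vocab below is hypothetical; in the patch the ids come from tokenizer.convert_tokens_to_ids on the model's real vocabulary:

    import torch

    # Hypothetical stand-in for tokenizer.convert_tokens_to_ids.
    vocab = {f"<s:{i}>": 32000 + i for i in range(1024)}

    prompt_tokens = torch.tensor([[3, 17, 5], [8, 2, 9]], dtype=torch.int)
    data = prompt_tokens + 2  # same +2 shift as the patch

    # Re-express the first codebook as <s:xxx> text tokens...
    main_tokens = [f"<s:{i}>" for i in data[0]]  # ['<s:5>', '<s:19>', '<s:7>']
    main_token_ids = torch.tensor([[vocab[t] for t in main_tokens]], dtype=torch.int)

    # ...and stack them on top of the shifted codebook rows, where the old
    # code put a row of <pad> tokens instead.
    data = torch.cat((main_token_ids, data), dim=0)  # (1 + num_codebooks, T)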
@@ -434,6 +442,7 @@ def main(
         use_g2p=use_g2p,
         speaker=speaker,
         order=order,
+        num_codebooks=model.config.num_codebooks,
     )
     prompt_length = encoded.size(1)
     logger.info(f"Encoded prompt shape: {encoded.shape}")
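Design note: the call site now derives the codebook count from the loaded checkpoint (model.config.num_codebooks) instead of assuming 4, while the parameter's default of 4 keeps existing callers working unchanged. Any external caller of encode_tokens that passes multi-codebook prompt_tokens should pass the checkpoint's value as well, since rows beyond num_codebooks are now truncated with a warning.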