@@ -376,7 +376,7 @@ def load_model(config_name, checkpoint_path, device, precision):
     model = model.to(device=device, dtype=precision)
     logger.info("Restored model from checkpoint")

-    return model.eval()
+    return model.eval(), cfg


 def split_text(text, min_length):
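Returning `cfg` alongside the model lets the caller consult model-level flags later on (see the last hunk, which reads `use_delay_pattern` off it). A small stand-alone illustration, using a `SimpleNamespace` stand-in rather than the real config object, of why the `getattr(..., True)` default keeps older configs on the previous behaviour:

```python
from types import SimpleNamespace

old_cfg = SimpleNamespace()                          # config written before the flag existed
new_cfg = SimpleNamespace(use_delay_pattern=False)   # config that opts out of the delay pattern

assert getattr(old_cfg, "use_delay_pattern", True) is True   # old configs keep the old behaviour
assert getattr(new_cfg, "use_delay_pattern", True) is False  # new configs can skip the un-delay step
```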
@@ -451,7 +451,7 @@ def main(

     logger.info("Loading model ...")
     t0 = time.time()
-    model = load_model(config_name, checkpoint_path, device, precision)
+    model, cfg = load_model(config_name, checkpoint_path, device, precision)
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)

     torch.cuda.synchronize()
@@ -466,12 +466,13 @@ def main(

     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
-    for text in split_text(text, 20):
+    texts = split_text(text, 20) if iterative_prompt else [text]
+    for idx, text in enumerate(texts):
         encoded.append(
             encode_tokens(
                 tokenizer,
-                text,
-                bos=False,
+                string=text,
+                bos=idx == 0 and not use_prompt,
                 device=device,
                 use_g2p=use_g2p,
                 speaker=None,
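The new `bos` expression means only the first chunk of a split text opens the sequence, and even that is suppressed when a reference prompt already sits at the front. A hedged, self-contained sketch of that rule (the helper name is illustrative; only the boolean expression comes from the hunk):

```python
def wants_bos(idx: int, use_prompt: bool) -> bool:
    # mirrors `bos=idx == 0 and not use_prompt` from the hunk above
    return idx == 0 and not use_prompt

chunks = ["First chunk.", "Second chunk.", "Third chunk."]
print([wants_bos(i, use_prompt=False) for i in range(len(chunks))])  # [True, False, False]
print([wants_bos(i, use_prompt=True) for i in range(len(chunks))])   # [False, False, False]
```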
@@ -553,13 +554,16 @@ def main(

         # Put the generated tokens
         codes = y[1:, prompt_length:-1].clone()
-        new_codes = []
-        for j, code in enumerate(codes):
-            new_codes.append(
-                code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
-            )
-
-        codes = torch.stack(new_codes, dim=0)
+        if getattr(cfg, "use_delay_pattern", True):
+            new_codes = []
+            for j, code in enumerate(codes):
+                new_codes.append(
+                    code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
+                )
+
+            codes = torch.stack(new_codes, dim=0)
+
         codes = codes - 2
         if not (codes >= 0).all():
             global_encoded.pop()
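The block now guarded by `use_delay_pattern` undoes a delay pattern: codebook `j` is emitted `j` steps later than codebook 0, so dropping `j` leading and `num_codebooks - j - 1` trailing positions realigns every row to the same length before the token offset is subtracted. A minimal illustration of that slicing with made-up shapes (only the slice expression itself is taken from the hunk):

```python
import torch

num_codebooks, length = 3, 8
codes = torch.arange(num_codebooks * length).reshape(num_codebooks, length)

# Undo the per-codebook delay: row j keeps positions j .. length - (num_codebooks - j - 1).
realigned = torch.stack(
    [row[j : length - (num_codebooks - j - 1)] for j, row in enumerate(codes)],
    dim=0,
)
print(realigned.shape)  # torch.Size([3, 6]) -- all rows end up with the same length
```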