|
|
@@ -254,18 +254,26 @@ def generate(
|
|
|
|
|
|
|
|
|
def encode_tokens(
|
|
|
- tokenizer, string, bos=True, device="cuda", prompt_string=None, prompt_tokens=None
|
|
|
+ tokenizer,
|
|
|
+ string,
|
|
|
+ bos=True,
|
|
|
+ device="cuda",
|
|
|
+ prompt_string=None,
|
|
|
+ prompt_tokens=None,
|
|
|
+ use_g2p=False,
|
|
|
):
|
|
|
if prompt_string is not None:
|
|
|
string = prompt_string + " " + string
|
|
|
|
|
|
- prompt = g2p(string)
|
|
|
- prompt = [
|
|
|
- (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
|
|
|
- for _, i in prompt
|
|
|
- ]
|
|
|
- prompt = " ".join(prompt)
|
|
|
- string = f"[INST] {prompt} [/INST]"
|
|
|
+ if use_g2p:
|
|
|
+            prompt = g2p(string)
|
|
|
+ prompt = [
|
|
|
+ (f"<p:{i}>" if i not in pu_symbols and i != pad_symbol else i)
|
|
|
+ for _, i in prompt
|
|
|
+ ]
|
|
|
+ string = " ".join(prompt)
|
|
|
+
|
|
|
+ string = f"[INST] {string} [/INST]"
|
|
|
|
|
|
tokens = tokenizer.encode(
|
|
|
string,
|
|
|
@@ -359,6 +367,7 @@ def load_model(config_name, checkpoint_path, device, precision):
|
|
|
@click.option("--config-name", type=str, default="text2semantic_finetune")
|
|
|
@click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
|
|
|
@click.option("--compile/--no-compile", default=False)
|
|
|
+@click.option("--use-g2p/--no-g2p", default=True)
|
|
|
@click.option("--seed", type=int, default=42)
|
|
|
def main(
|
|
|
text: str,
|
|
|
@@ -374,6 +383,7 @@ def main(
|
|
|
config_name: str,
|
|
|
tokenizer: str,
|
|
|
compile: bool,
|
|
|
+ use_g2p: bool,
|
|
|
seed: int,
|
|
|
) -> None:
|
|
|
device = "cuda"
|
|
|
@@ -400,6 +410,7 @@ def main(
|
|
|
prompt_tokens=prompt_tokens,
|
|
|
bos=True,
|
|
|
device=device,
|
|
|
+ use_g2p=use_g2p,
|
|
|
)
|
|
|
prompt_length = encoded.size(1)
|
|
|
logger.info(f"Encoded prompt shape: {encoded.shape}")
|