|
@@ -207,6 +207,7 @@ def generate(
|
|
|
prompt: torch.Tensor,
|
|
prompt: torch.Tensor,
|
|
|
max_new_tokens: int,
|
|
max_new_tokens: int,
|
|
|
eos_token_id: int = 2,
|
|
eos_token_id: int = 2,
|
|
|
|
|
+ precision: torch.dtype = torch.bfloat16,
|
|
|
**sampling_kwargs,
|
|
**sampling_kwargs,
|
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
|
"""
|
|
"""
|
|
@@ -228,7 +229,7 @@ def generate(
|
|
|
|
|
|
|
|
device, dtype = prompt.device, prompt.dtype
|
|
device, dtype = prompt.device, prompt.dtype
|
|
|
with torch.device(device):
|
|
with torch.device(device):
|
|
|
- model.setup_caches(max_batch_size=1, max_seq_len=T_new)
|
|
|
|
|
|
|
+ model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)
|
|
|
|
|
|
|
|
codebook_dim = 1 + model.config.num_codebooks
|
|
codebook_dim = 1 + model.config.num_codebooks
|
|
|
# create an empty tensor of the expected final shape and fill in the current tokens
|
|
# create an empty tensor of the expected final shape and fill in the current tokens
|
|
@@ -381,6 +382,7 @@ def load_model(config_name, checkpoint_path, device, precision):
|
|
|
@click.option("--use-g2p/--no-g2p", default=True)
|
|
@click.option("--use-g2p/--no-g2p", default=True)
|
|
|
@click.option("--seed", type=int, default=42)
|
|
@click.option("--seed", type=int, default=42)
|
|
|
@click.option("--speaker", type=str, default=None)
|
|
@click.option("--speaker", type=str, default=None)
|
|
|
|
|
+@click.option("--half/--no-half", default=False)
|
|
|
def main(
|
|
def main(
|
|
|
text: str,
|
|
text: str,
|
|
|
prompt_text: Optional[str],
|
|
prompt_text: Optional[str],
|
|
@@ -398,9 +400,11 @@ def main(
|
|
|
use_g2p: bool,
|
|
use_g2p: bool,
|
|
|
seed: int,
|
|
seed: int,
|
|
|
speaker: Optional[str],
|
|
speaker: Optional[str],
|
|
|
|
|
+ half: bool,
|
|
|
) -> None:
|
|
) -> None:
|
|
|
device = "cuda"
|
|
device = "cuda"
|
|
|
- precision = torch.bfloat16
|
|
|
|
|
|
|
+
|
|
|
|
|
+ precision = torch.half if half else torch.bfloat16
|
|
|
|
|
|
|
|
logger.info("Loading model ...")
|
|
logger.info("Loading model ...")
|
|
|
t0 = time.time()
|
|
t0 = time.time()
|
|
@@ -445,6 +449,7 @@ def main(
|
|
|
prompt=encoded,
|
|
prompt=encoded,
|
|
|
max_new_tokens=max_new_tokens,
|
|
max_new_tokens=max_new_tokens,
|
|
|
eos_token_id=tokenizer.eos_token_id,
|
|
eos_token_id=tokenizer.eos_token_id,
|
|
|
|
|
+ precision=precision,
|
|
|
temperature=temperature,
|
|
temperature=temperature,
|
|
|
top_k=top_k,
|
|
top_k=top_k,
|
|
|
top_p=top_p,
|
|
top_p=top_p,
|