import os
import time
from pathlib import Path
from typing import Optional, Tuple
import click
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer
from fish_speech.text.parser import clean_text
# Set TOKENIZERS_PARALLELISM to "false" in the process environment before any
# tokenizer is created (presumably to disable the tokenizers library's
# internal parallelism -- confirm against the tokenizers docs).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Inductor options applied once at import time.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
# Only touch fx_graph_cache when this torch build exposes the option.
if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
from fish_speech.models.text2semantic.llama import Transformer
from fish_speech.text import g2p
from fish_speech.text.symbols import pad as pad_symbol
from fish_speech.text.symbols import pu_symbols
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    """Sample one index along the last dimension of ``probs_sort``.

    The probabilities are divided element-wise by Exponential(1) noise and
    the position of the largest ratio is returned.

    Args:
        probs_sort: Tensor of probabilities; sampling happens over dim -1.

    Returns:
        An ``int32`` tensor of chosen indices with the last dimension kept
        (size 1).
    """
    noise = torch.empty_like(probs_sort)
    noise.exponential_(1)
    ratios = probs_sort / noise
    chosen = ratios.argmax(dim=-1, keepdim=True)
    return chosen.to(dtype=torch.int)
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
):
    """Turn a 1-D vector of logits into a sampling distribution.

    The steps run in this order: repetition penalty, top-p filtering,
    temperature scaling, top-k filtering, softmax.

    The caller's ``logits`` tensor is never modified; every step works on a
    new tensor.

    Args:
        logits: 1-D tensor of unnormalised scores (the penalty and top-p
            steps index along dim 0).
        previous_tokens: Optional 1-D tensor of token ids to penalise.
            Ignored when ``repetition_penalty`` is exactly 1.0.
        temperature: Divisor applied to the logits; clamped below at 1e-5.
        top_k: If given, every logit below the k-th largest is masked out.
        top_p: If given and < 1.0, tokens whose inclusive cumulative
            probability (in descending order) exceeds ``top_p`` are masked
            out; the single most likely token is always kept.
        repetition_penalty: Positive scores of penalised tokens are divided
            by this value and negative scores are multiplied by it.

    Returns:
        Tensor of probabilities with the same shape as ``logits``.
    """
    if previous_tokens is not None and repetition_penalty != 1.0:
        previous_tokens = previous_tokens.long()
        score = torch.gather(logits, dim=0, index=previous_tokens)
        score = torch.where(
            score < 0, score * repetition_penalty, score / repetition_penalty
        )
        # Out-of-place scatter: the input is usually a view into the model's
        # output, so writing into it would silently alter the caller's data.
        logits = logits.scatter(dim=0, index=previous_tokens, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(
            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_indices_to_remove = cum_probs > top_p
        sorted_indices_to_remove[0] = False  # keep at least one option
        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=0, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    # Guard against division by zero for temperature == 0.
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # Smallest of the k largest logits; anything below it is dropped.
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample one token from the last position of the first batch entry.

    Args:
        logits: Tensor indexed as ``logits[0, -1]`` to obtain the score
            vector that is sampled from.
        previous_tokens: Optional token ids forwarded to ``logits_to_probs``
            for the repetition penalty.
        **sampling_kwargs: Forwarded unchanged to ``logits_to_probs``.

    Returns:
        ``(idx_next, probs)``: the sampled index and the distribution it was
        drawn from.
    """
    last_step = logits[0, -1]
    probs = logits_to_probs(
        logits=last_step, previous_tokens=previous_tokens, **sampling_kwargs
    )
    return multinomial_sample_one_no_sync(probs), probs
def decode_one_token(
    model: Transformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Run one decoding step and sample a token for every output stream.

    Args:
        model: Model exposing ``forward_generate`` and ``config.num_codebooks``.
        x: Input for this step, passed straight to ``forward_generate``.
        input_pos: Position tensor; its last dimension must have size 1.
        previous_tokens: Optional history indexed by stream; row ``i + 1`` is
            used as the repetition-penalty history for codebook ``i``.
        **sampling_kwargs: Forwarded to ``sample``.

    Returns:
        The sampled ids stacked along dim 0: the text-token sample first,
        then one per codebook.
    """
    assert input_pos.shape[-1] == 1

    outputs = model.forward_generate(x, input_pos)

    # Repetition penalty is deliberately disabled for the token stream.
    token_id, _ = sample(
        outputs.token_logits,
        previous_tokens=None,
        **sampling_kwargs,
    )
    sampled = [token_id]

    num_codebooks = model.config.num_codebooks
    if num_codebooks != 0:
        # Forbid codebook id 0 by masking it in place. Id 1 stays available;
        # decode_n_tokens treats a codebook value of 1 as a stop signal.
        outputs.codebook_logits[:, :, :, :1] = -float("Inf")

    for book in range(num_codebooks):
        history = None if previous_tokens is None else previous_tokens[book + 1]
        book_id, _ = sample(
            outputs.codebook_logits[:, :, book],
            previous_tokens=history,
            **sampling_kwargs,
        )
        sampled.append(book_id)

    return torch.stack(sampled, dim=0)
def prefill(
    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
    """Feed the whole prompt through the model and sample the first new token.

    No repetition penalty is applied at this stage.

    Args:
        model: Model exposing ``forward_generate`` and ``config.num_codebooks``.
        x: Prompt tensor passed straight to ``forward_generate``.
        input_pos: Positions of the prompt tokens. The original note says
            ``[B, S]``, but ``generate`` in this file passes a 1-D arange
            (TODO confirm which is expected).
        **sampling_kwargs: Forwarded to ``sample``.

    Returns:
        The sampled ids stacked along dim 0: the text-token sample first,
        then one per codebook.
    """
    outputs = model.forward_generate(x, input_pos)

    first_token, _ = sample(
        outputs.token_logits,
        previous_tokens=None,
        **sampling_kwargs,
    )
    sampled = [first_token]

    num_codebooks = model.config.num_codebooks
    if num_codebooks != 0:
        # Forbid codebook ids 0 and 1 for the very first generated frame
        # (masked in place on the model output).
        outputs.codebook_logits[:, :, :, :2] = -float("Inf")

    for book in range(num_codebooks):
        book_id, _ = sample(
            outputs.codebook_logits[:, :, book],
            previous_tokens=None,
            **sampling_kwargs,
        )
        sampled.append(book_id)

    return torch.stack(sampled, dim=0)
def decode_n_tokens(
    model: Transformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    eos_token_id: int = 2,
    **sampling_kwargs,
):
    """Autoregressively decode up to ``num_new_tokens`` steps.

    Decoding stops early when the text stream emits ``eos_token_id`` or any
    codebook stream emits the value 1.

    Args:
        model: Model exposing ``config.num_codebooks`` and
            ``config.max_seq_len``.
        cur_token: Token(s) fed into the first step.
        input_pos: Position tensor for the first step. It is advanced in
            place by one after every step.
        num_new_tokens: Maximum number of steps. Zero or a negative value
            yields an empty result instead of failing.
        eos_token_id: Text-stream id that terminates decoding.
        **sampling_kwargs: Forwarded to ``decode_one_token``.

    Returns:
        Int tensor of shape ``(num_codebooks + 1, n_generated)`` holding the
        tokens produced, including the terminating one if decoding stopped
        early.
    """
    num_streams = model.config.num_codebooks + 1
    previous_tokens = torch.zeros(
        (num_streams, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    # Width of the history window used for the repetition penalty.
    win_size = 16
    # Tracked explicitly so the return value is well defined even when the
    # loop body never runs (the loop variable would be unbound in that case).
    num_generated = 0

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        if i < win_size:
            # Early steps use the first win_size columns, which still contain
            # the zero-initialised slots that have not been written yet.
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, num_streams, -1)
        previous_tokens[:, i : i + 1] = next_token.view(num_streams, -1)
        num_generated = i + 1

        # TODO: use tokenizer's eos
        if cur_token[0, 0, -1] == eos_token_id or (cur_token[0, 1:, -1] == 1).any():
            break

    return previous_tokens[:, :num_generated]
@torch.no_grad()
def generate(
    *,
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int = 2,
    precision: torch.dtype = torch.bfloat16,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Args:
        model: Model exposing ``config.max_seq_len``, ``config.num_codebooks``
            and ``setup_caches``.
        prompt: Tensor of shape ``(1 + num_codebooks, T)``.
        max_new_tokens: Upper bound on generated tokens. A falsy value means
            "fill up to ``model.config.max_seq_len``". The bound is truncated
            if prompt plus new tokens would exceed ``max_seq_len``.
        eos_token_id: Text-stream id that stops generation.
        precision: dtype handed to ``model.setup_caches``.
        **sampling_kwargs: Forwarded to ``prefill`` and ``decode_n_tokens``.

    Returns:
        Tensor of shape ``(1 + num_codebooks, T + n_generated)`` containing
        the prompt followed by the generated tokens.

    Raises:
        ValueError: If there is no room to generate at least one token.
    """
    T = prompt.size(1)

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    # Fail early and clearly: without this check the code below dies with a
    # shape mismatch after the (expensive) prefill has already run.
    if max_new_tokens < 1:
        raise ValueError(
            f"No room to generate new tokens: prompt length {T}, "
            f"max_new_tokens {max_new_tokens}, "
            f"model max_seq_len {model.config.max_seq_len}"
        )

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    seq = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    seq[:, :T] = prompt
    input_pos = torch.arange(0, T, device=device)

    # The first new token comes from processing the full prompt at once.
    next_token = prefill(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    remaining = max_new_tokens - 1
    if remaining > 0:
        input_pos = torch.tensor([T], device=device, dtype=torch.int)
        x = decode_n_tokens(
            model,
            next_token.view(1, codebook_dim, -1),
            input_pos,
            remaining,
            eos_token_id=eos_token_id,
            **sampling_kwargs,
        )
        # Trim the buffer to what was actually produced (decoding may stop early).
        seq = seq[:, : T + 1 + x.size(1)]
        seq[:, T + 1 :] = x
    # When remaining == 0 the buffer is exactly T + 1 wide and already complete.

    return seq
def encode_tokens(
    tokenizer,
    string,
    bos=True,
    device="cuda",
    prompt_text=None,
    prompt_tokens=None,
    use_g2p=False,
    speaker=None,
    order="zh,jp,en",
):
    """Build the multi-row prompt tensor for the model.

    Args:
        tokenizer: Object providing ``encode(...)`` and ``pad_token_id``.
        string: Text to synthesise.
        bos: Passed to the tokenizer as ``add_special_tokens``.
        device: Device for the tensors created here.
        prompt_text: Optional text prepended to ``string`` (space separated).
        prompt_tokens: Optional 2-D tensor of reference codes appended after
            the text columns. It is expected to have 4 rows and live on a
            compatible device already (TODO confirm against callers).
        use_g2p: Run ``g2p`` on the text instead of ``clean_text``.
        speaker: Optional speaker name, added as a ``[SPK: ...]`` prefix.
        order: Comma-separated language order handed to ``g2p``.

    Returns:
        Int tensor with 5 rows: row 0 holds text token ids, rows 1-4 the
        codebook values (zeros under the text columns).
    """
    text = string if prompt_text is None else prompt_text + " " + string

    if not use_g2p:
        text = clean_text(text)
    else:
        phonemes = g2p(text, order=order.split(","))
        # Punctuation and pad symbols pass through unchanged.
        # NOTE(review): every other symbol is replaced by the empty f-string
        # below, i.e. dropped; this looks like a template that lost its
        # content. Confirm the intended placeholder.
        pieces = [
            (sym if sym in pu_symbols or sym == pad_symbol else f"")
            for _, sym in phonemes
        ]
        text = " ".join(pieces)

    if speaker is not None:
        text = f"[SPK: {speaker}] {text}"

    text = f"[INST] {text} [/INST]"

    token_ids = tokenizer.encode(
        text,
        max_length=10**6,
        add_special_tokens=bos,
        truncation=False,
    )
    text_row = torch.tensor([token_ids], dtype=torch.int, device=device)

    # Four all-zero codebook rows sit under the text tokens (count is hard-coded).
    codebook_rows = torch.zeros((4, text_row.size(1)), dtype=torch.int, device=device)
    prompt = torch.cat((text_row, codebook_rows), dim=0)

    if prompt_tokens is None:
        return prompt

    # Append the reference codes as extra columns.
    assert prompt_tokens.ndim == 2
    shifted_codes = prompt_tokens + 2  # code values are stored offset by 2

    # Text row above the reference codes is filled with the tokenizer's pad id.
    pad_row = (
        torch.zeros((1, shifted_codes.size(1)), dtype=torch.int, device=device)
        + tokenizer.pad_token_id
    )
    reference_block = torch.cat((pad_row, shifted_codes), dim=0)

    return torch.cat((prompt, reference_block), dim=1)
def load_model(config_name, checkpoint_path, device, precision):
    """Instantiate the model from a Hydra config and load checkpoint weights.

    Args:
        config_name: Hydra config name resolved under
            ``../../fish_speech/configs``.
        checkpoint_path: Checkpoint file (``str`` or ``Path``). A path
            containing ``int8`` or ``int4`` switches on the matching
            weight-only quantization handler. For int4 the second-to-last
            dot-separated component of the file name must look like
            ``g<groupsize>``.
        device: Target device for the loaded model.
        precision: Target dtype for the loaded model.

    Returns:
        The model in eval mode on ``device`` with dtype ``precision``.
    """
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        cfg = compose(config_name=config_name)

    # Build on the meta device so no real memory is allocated before the
    # checkpoint tensors are assigned.
    with torch.device("meta"):
        model: Transformer = instantiate(cfg.model).model

    if "int8" in str(checkpoint_path):
        logger.info("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        logger.info("Using int4 quantization!")
        # Path(...) so a plain string path works as well as a Path object.
        path_comps = Path(checkpoint_path).name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    prefix = "model."
    if any(k.startswith(prefix) for k in checkpoint):
        # Strip only the leading wrapper prefix. str.replace would also eat
        # any later "model." inside the key (e.g. "model.submodel.w").
        checkpoint = {
            k[len(prefix) :]: v for k, v in checkpoint.items() if k.startswith(prefix)
        }

    model.load_state_dict(checkpoint, assign=True)

    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    return model.eval()
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None)
@click.option(
    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
)
@click.option("--num-samples", type=int, default=1)
# Dashed spelling matches every other flag; the original underscore spelling
# is kept as an alias so existing command lines keep working.
@click.option(
    "--max-new-tokens", "--max_new_tokens", "max_new_tokens", type=int, default=0
)
@click.option("--top-k", type=int, default=None)
@click.option("--top-p", type=float, default=0.5)
@click.option("--repetition-penalty", type=float, default=1.5)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="results/text2semantic_400m_finetune/step_000002000.pth",
)
@click.option("--config-name", type=str, default="text2semantic_finetune")
@click.option("--tokenizer", type=str, default="fishaudio/speech-lm-v1")
@click.option("--compile/--no-compile", default=False)
@click.option("--use-g2p/--no-g2p", default=True)
@click.option("--seed", type=int, default=42)
@click.option("--speaker", type=str, default=None)
@click.option("--order", type=str, default="zh,jp,en")
@click.option("--half/--no-half", default=False)
def main(
    text: str,
    prompt_text: Optional[str],
    prompt_tokens: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_k: Optional[int],
    top_p: float,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    config_name: str,
    tokenizer: str,
    compile: bool,
    use_g2p: bool,
    seed: int,
    speaker: Optional[str],
    order: str,
    half: bool,
) -> None:
    """Generate semantic codes for TEXT and save them as codes_<i>.npy."""
    device = "cuda"
    precision = torch.half if half else torch.bfloat16

    logger.info("Loading model ...")
    t0 = time.time()
    model = load_model(config_name, checkpoint_path, device, precision)
    # Size in bytes (element count times element size), so that the
    # "GB/s" figure logged below really is a bandwidth.
    model_size = sum(
        p.numel() * p.element_size() for p in model.parameters() if p.requires_grad
    )

    torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    # Optional reference codes are read from a .npy file and moved to the GPU.
    prompt_tokens = (
        torch.from_numpy(np.load(prompt_tokens)).to(device)
        if prompt_tokens is not None
        else None
    )
    encoded = encode_tokens(
        tokenizer,
        text,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
        bos=True,
        device=device,
        use_g2p=use_g2p,
        speaker=speaker,
        order=order,
    )
    prompt_length = encoded.size(1)
    logger.info(f"Encoded prompt shape: {encoded.shape}")

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if compile:
        # Rebind the module-level function so decode_n_tokens picks up the
        # compiled version.
        global decode_one_token
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )

    for i in range(num_samples):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        y = generate(
            model=model,
            prompt=encoded,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            precision=precision,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        # With --compile the first sample's wall time includes compilation.
        if i == 0 and compile:
            logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

        torch.cuda.synchronize()
        t = time.perf_counter() - t0

        tokens_generated = y.size(1) - prompt_length
        tokens_sec = tokens_generated / t
        logger.info(
            f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
        )
        logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
        logger.info(
            f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
        )

        # Keep only the codebook rows of the generated part and drop the last
        # generated column (presumably the stop frame; if generation ended by
        # hitting the length limit this drops a real frame -- TODO confirm).
        codes = y[1:, prompt_length:-1]
        # Undo the +2 offset that encode_tokens applies to code values.
        codes = codes - 2
        assert (codes >= 0).all(), "Codes should be >= 0"

        np.save(f"codes_{i}.npy", codes.cpu().numpy())
        logger.info(f"Saved codes to codes_{i}.npy")
# Run the click command when this file is executed as a script.
if __name__ == "__main__":
    main()