| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756 |
- import os
- import queue
- import threading
- import time
- from pathlib import Path
- from typing import Optional, Tuple, Union
- import click
- import hydra
- import numpy as np
- import torch
- import torch._dynamo.config
- import torch._inductor.config
- from hydra import compose, initialize
- from hydra.utils import instantiate
- from loguru import logger
- from tqdm import tqdm
- from transformers import AutoTokenizer
- from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
- from fish_speech.text.clean import clean_text
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
- torch._inductor.config.coordinate_descent_tuning = True
- torch._inductor.config.triton.unique_kernel_names = True
- if hasattr(torch._inductor.config, "fx_graph_cache"):
- # Experimental feature to reduce compilation times, will be on by default in future
- torch._inductor.config.fx_graph_cache = True
- from fish_speech.models.text2semantic.llama import DualARTransformer, NaiveTransformer
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    """Draw one index along the last dim of ``probs_sort`` without a host sync.

    Uses the exponential-race trick: each probability is divided by an
    independent Exp(1) sample and the argmax is taken. The winning index is
    distributed proportionally to the probabilities, and everything stays on
    the tensor's device (no ``torch.multinomial`` call).

    Returns:
        An ``int32`` tensor of indices with the last dimension kept (size 1).
    """
    race_noise = torch.empty_like(probs_sort).exponential_(1)
    winners = (probs_sort / race_noise).argmax(dim=-1, keepdim=True)
    return winners.to(dtype=torch.int)
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    """Convert a single row of logits into sampling probabilities.

    The indexing below (``dim=0`` gathers/scatters and ``[0]`` on the sort
    mask) means ``logits`` is expected to be 1-D.

    Steps, in order:
      1. Repetition penalty: for every id in ``previous_tokens``, a negative
         logit is multiplied by ``repetition_penalty`` and a non-negative one
         is divided by it. NOTE: this writes into ``logits`` in place.
      2. Top-p filtering: after sorting in descending order, tokens whose
         cumulative probability exceeds ``top_p`` are masked to ``-inf``. The
         most likely token is always kept.
      3. Temperature scaling (temperature floored at 1e-5), then softmax.
    """
    if previous_tokens is not None:
        token_ids = previous_tokens.long()
        picked = torch.gather(logits, dim=0, index=token_ids)
        penalized = torch.where(
            picked < 0, picked * repetition_penalty, picked / repetition_penalty
        )
        logits.scatter_(dim=0, index=token_ids, src=penalized)

    ranked_logits, ranked_ids = torch.sort(logits, descending=True)
    ranked_probs = torch.nn.functional.softmax(ranked_logits, dim=-1)
    drop_ranked = torch.cumsum(ranked_probs, dim=-1) > top_p
    drop_ranked[0] = False  # never drop the single most likely token
    # Map the mask from sorted order back to vocabulary order.
    drop = drop_ranked.scatter(dim=0, index=ranked_ids, src=drop_ranked)
    logits = logits.masked_fill(drop, -float("Inf"))

    scaled = logits / max(temperature, 1e-5)
    return torch.nn.functional.softmax(scaled, dim=-1)
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample the next token from the final time step of batch row 0.

    Only ``logits[0, -1]`` is used. ``previous_tokens`` and ``sampling_kwargs``
    (temperature / top_p / repetition_penalty) are forwarded to
    ``logits_to_probs``.

    Returns:
        ``(idx_next, probs)``: the sampled index tensor (int32, trailing dim of
        size 1) and the probability vector it was drawn from.
    """
    final_step = logits[0, -1]
    probs = logits_to_probs(
        logits=final_step, previous_tokens=previous_tokens, **sampling_kwargs
    )
    return multinomial_sample_one_no_sync(probs), probs
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Run one dual-AR decoding step.

    The slow transformer produces the main-token logits. The fast transformer
    is then run once per codebook, feeding back the embedding of each sampled
    codebook token.

    Args:
        model: the dual-AR model (uses ``forward_generate``,
            ``forward_generate_fast``, ``fast_layers`` and ``fast_embeddings``).
        x: input tokens for the slow transformer.
        input_pos: KV-cache positions for the slow transformer.
        previous_tokens: optional recent-token window; row ``k + 1`` is used
            as the repetition-penalty history for codebook ``k``. The main
            token is never penalized.
        **sampling_kwargs: forwarded to ``sample``.

    Returns:
        The main token followed by ``model.config.num_codebooks`` codebook
        tokens, stacked along dim 0.
    """
    slow_out = model.forward_generate(x, input_pos)
    # Repetition penalty is disabled for the main token stream.
    picked = [sample(slow_out.logits, previous_tokens=None, **sampling_kwargs)[0]]

    # The fast transformer starts from position 0 on every step, so its KV
    # cache is zeroed before use.
    for fast_layer in model.fast_layers:
        fast_layer.attention.kv_cache.k_cache.fill_(0)
        fast_layer.attention.kv_cache.v_cache.fill_(0)

    hidden = slow_out.hidden_states
    for cb in range(model.config.num_codebooks):
        fast_pos = torch.tensor([cb], device=hidden.device, dtype=torch.long)
        cb_logits = model.forward_generate_fast(hidden, fast_pos)
        cb_history = None if previous_tokens is None else previous_tokens[cb + 1]
        cb_token = sample(cb_logits, previous_tokens=cb_history, **sampling_kwargs)[0]
        picked.append(cb_token)
        hidden = model.fast_embeddings(cb_token)

    return torch.stack(picked, dim=0)
def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Run one decoding step with the single-stage transformer.

    A single forward pass yields both ``token_logits`` and per-codebook
    ``codebook_logits``; each stream is sampled independently.

    Args:
        model: the naive transformer (uses ``forward_generate``).
        x: input tokens.
        input_pos: KV-cache positions.
        previous_tokens: optional recent-token window; row ``k + 1`` is the
            repetition-penalty history for codebook ``k``. The main token is
            never penalized.
        **sampling_kwargs: forwarded to ``sample``.

    Returns:
        The main token followed by ``model.config.num_codebooks`` codebook
        tokens, stacked along dim 0.
    """
    out = model.forward_generate(x, input_pos)
    # Repetition penalty is disabled for the main token stream.
    tokens = [sample(out.token_logits, previous_tokens=None, **sampling_kwargs)[0]]
    for cb in range(model.config.num_codebooks):
        history = previous_tokens[cb + 1] if previous_tokens is not None else None
        cb_logits = out.codebook_logits[:, :, cb]
        tokens.append(sample(cb_logits, previous_tokens=history, **sampling_kwargs)[0])
    return torch.stack(tokens, dim=0)
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    win_size: int = 16,
    **sampling_kwargs,
):
    """Autoregressively decode up to ``num_new_tokens`` steps.

    Args:
        model: model passed through to ``decode_one_token``.
        cur_token: the last token, viewed as ``(1, num_codebooks + 1, -1)``.
        input_pos: current position tensor; advanced IN PLACE by 1 per step.
        num_new_tokens: maximum number of steps. A value <= 0 decodes nothing.
        eos_token_id / im_end_id: main-token ids that stop decoding.
        decode_one_token: the per-step decoder (possibly torch.compile'd).
        win_size: number of recent steps used as repetition-penalty history.
            For the first ``win_size`` steps the window is the first
            ``win_size`` columns, which still contain zero-filled,
            not-yet-generated columns.
        **sampling_kwargs: forwarded to ``decode_one_token``.

    Returns:
        A ``(num_codebooks + 1, n)`` int tensor of decoded tokens, with
        ``0 <= n <= num_new_tokens``. Decoding stops after (and includes) the
        first step whose main token is ``eos_token_id`` / ``im_end_id`` or
        whose codebook tokens contain ``CODEBOOK_EOS_TOKEN_ID``.
    """
    num_rows = model.config.num_codebooks + 1
    previous_tokens = torch.zeros(
        (num_rows, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )
    # Tracked explicitly so that num_new_tokens <= 0 returns an empty slice
    # instead of referencing an unbound loop variable.
    num_decoded = 0

    for i in tqdm(range(num_new_tokens)):
        # Windowed repetition penalty history.
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, num_rows, -1)
        previous_tokens[:, i : i + 1] = next_token.view(num_rows, -1)
        num_decoded = i + 1

        if (
            cur_token[0, 0, -1] == eos_token_id
            or cur_token[0, 0, -1] == im_end_id
            or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
        ):
            break

    return previous_tokens[:, :num_decoded]
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int = 2,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Args:
        model: the text2semantic model; its KV caches are (re)allocated here.
        prompt: ``(1 + num_codebooks, T)`` token matrix.
        max_new_tokens: maximum number of generated steps. 0 means "fill up
            to ``model.config.max_seq_len``". The value is clamped so that
            ``T + max_new_tokens <= model.config.max_seq_len``.
        eos_token_id / im_end_id: main-token ids that stop decoding.
        decode_one_token: per-step decoder used after the prefill step.
        **sampling_kwargs: forwarded to the decoders.

    Returns:
        ``(1 + num_codebooks, T + n)``: the prompt followed by ``n >= 1``
        generated steps.

    Raises:
        ValueError: if no generation step fits (prompt too long for
            ``max_seq_len``, or a negative ``max_new_tokens``).
    """
    T = prompt.size(1)
    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    # Fail early with a clear message instead of a shape error further down.
    if max_new_tokens < 1:
        raise ValueError(
            f"No room to generate: prompt length {T}, "
            f"max_seq_len {model.config.max_seq_len}, "
            f"max_new_tokens resolved to {max_new_tokens}"
        )

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
        )

    codebook_dim = 1 + model.config.num_codebooks
    # Buffer for prompt + generated tokens; trimmed to the used length below.
    seq = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
    seq[:, :T] = prompt
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )
    next_token = prefill_decode(
        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
    )
    seq[:, T : T + 1] = next_token

    if max_new_tokens == 1:
        # The prefill step already used the whole budget.
        return seq[:, : T + 1]

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        eos_token_id=eos_token_id,
        im_end_id=im_end_id,
        decode_one_token=decode_one_token,
        **sampling_kwargs,
    )
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x
    return seq
def encode_tokens(
    tokenizer,
    string,
    bos=True,
    device="cuda",
    prompt_tokens=None,
    speaker=None,
    num_codebooks=4,
):
    """Encode one chat turn (plus optional reference codes) as a token matrix.

    The result has ``1 + num_codebooks`` rows:
      - row 0 holds main (text) token ids;
      - rows 1.. hold codebook ids.
    For the text span, every codebook row is filled with
    ``CODEBOOK_PAD_TOKEN_ID``.

    Text handling:
      - ``string`` is normalized with ``clean_text``.
      - It is wrapped as ``<|im_start|>user<|im_sep|>TEXT<|im_end|>``
        followed by ``<|im_start|>SPEAKER<|im_sep|>``; ``speaker`` defaults
        to ``"assistant"``.
      - ``<|begin_of_sequence|>`` is prepended when ``bos`` is true.

    Reference codes (``prompt_tokens``, shape ``(C, S)`` or ``(1, C, S)``):
      - shifted by +2 and truncated to ``num_codebooks`` rows;
      - terminated with a ``CODEBOOK_EOS_TOKEN_ID`` column;
      - appended after the text, with a main-token row of ``<|semantic|>``
        ids whose last entry is ``<|im_end|>``.
    """
    cleaned = clean_text(string)
    speaker_name = "assistant" if speaker is None else speaker
    chat = (
        f"<|im_start|>user<|im_sep|>{cleaned}<|im_end|>"
        f"<|im_start|>{speaker_name}<|im_sep|>"
    )
    if bos:
        chat = "<|begin_of_sequence|>" + chat

    text_ids = tokenizer.encode(
        chat, add_special_tokens=False, max_length=10**6, truncation=False
    )
    text_row = torch.tensor([text_ids], dtype=torch.int, device=device)
    pad_rows = (
        torch.ones((num_codebooks, text_row.size(1)), dtype=torch.int, device=device)
        * CODEBOOK_PAD_TOKEN_ID
    )
    text_part = torch.cat((text_row, pad_rows), dim=0)

    if prompt_tokens is None:
        return text_part

    # Accept a leading batch dimension of size 1.
    if prompt_tokens.ndim == 3:
        assert (
            prompt_tokens.shape[0] == 1
        ), "3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
        prompt_tokens = prompt_tokens[0]
    assert prompt_tokens.ndim == 2

    codes = prompt_tokens + 2
    if prompt_tokens.shape[0] > num_codebooks:
        logger.warning(
            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
        )
        codes = codes[:num_codebooks]

    eos_column = (
        torch.ones((codes.size(0), 1), dtype=torch.int, device=device)
        * CODEBOOK_EOS_TOKEN_ID
    )
    codes = torch.cat((codes, eos_column), dim=1)

    # Since 1.0, the main row of an audio span uses <|semantic|>.
    semantic_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    semantic_row = (
        torch.ones((1, codes.size(1)), dtype=torch.int, device=device) * semantic_id
    )
    semantic_row[0, -1] = im_end_id
    audio_part = torch.cat((semantic_row, codes), dim=0)

    return torch.cat((text_part, audio_part), dim=1)
def load_model(
    config_name, checkpoint_path, device, precision, max_length, compile=False
):
    """Instantiate the text2semantic model, load its weights and pick a decoder.

    Args:
        config_name: Hydra config name under ``fish_speech/configs/model``.
        checkpoint_path: checkpoint file (``str`` or ``Path``).
            - If the path contains ``"int8"`` or ``"int4"``, the model is
              first converted with the matching handler from ``quantize``.
            - For int4 the file name must look like
              ``<name>.g<groupsize>.<ext>``.
        device: target device for the loaded model.
        precision: target dtype for the loaded model.
        max_length: overrides ``config.max_seq_len`` in the Hydra config.
        compile: wrap the chosen decode function with ``torch.compile``
            (``mode="reduce-overhead"``, ``fullgraph=True``).

    Returns:
        ``(model.eval(), decode_one_token)``. ``decode_one_token`` is
        ``decode_one_token_ar`` for a ``DualARTransformer`` and
        ``decode_one_token_naive`` otherwise.
    """
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
        cfg = compose(
            config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
        )

    model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)

    if "int8" in str(checkpoint_path):
        logger.info("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler

        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        logger.info("Using int4 quantization!")
        # Wrap in Path so plain-string paths work too (str has no `.name`).
        path_comps = Path(checkpoint_path).name.split(".")
        assert path_comps[-2].startswith("g")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler

        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    prefix = "model."
    if any(k.startswith(prefix) for k in checkpoint):
        # Keep only "model.*" keys and strip just the leading prefix
        # (presumably added by a training wrapper). str.replace would also
        # rewrite inner occurrences such as "x.language_model.y".
        checkpoint = {
            k[len(prefix) :]: v for k, v in checkpoint.items() if k.startswith(prefix)
        }

    model.load_state_dict(checkpoint, assign=True)
    model = model.to(device=device, dtype=precision)
    logger.info("Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
        decode_one_token = decode_one_token_ar
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = decode_one_token_naive
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token, mode="reduce-overhead", fullgraph=True
        )

    return model.eval(), decode_one_token
def split_text(text, min_length):
    """Split cleaned text into sentence-like chunks of at least ``min_length``.

    The text is first normalized with ``clean_text``. A chunk ends at a ".",
    "!" or "?" only once it has reached ``min_length`` characters; shorter
    chunks keep accumulating past the terminator. Any trailing remainder
    becomes the last chunk, even if it is shorter than ``min_length``.
    """
    cleaned = clean_text(text)
    chunks = []
    start = 0
    for pos, ch in enumerate(cleaned):
        chunk_len = pos + 1 - start
        if ch in ".!?" and chunk_len >= min_length:
            chunks.append(cleaned[start : pos + 1])
            start = pos + 1
    if start < len(cleaned):
        chunks.append(cleaned[start:])
    return chunks
def _select_context_segments(global_encoded, max_length):
    """Choose which encoded turns form the prompt for the next segment.

    ``global_encoded`` alternates text turns and generated turns and ends with
    the text turn about to be generated (this is how ``generate_long`` builds
    it).

    Selection rules:
      - The first two entries are always kept, to avoid drift.
      - After them, keep the most recent turns that fit a budget of about
        ``max_length - 1024`` tokens.
      - The number of recent turns is forced odd, so the kept tail starts on
        a text turn.
      - At least the current turn is always kept.
    """
    count = 0
    i = 0
    for i, length in enumerate(reversed([seg.size(1) for seg in global_encoded])):
        count += length
        # NOTE(review): `length` was already added to `count`, so it is
        # counted twice here; this makes the budget conservative. Kept as-is.
        if count + length > max_length - 1024:
            break

    if i != 0 and i % 2 == 0:
        i -= 1
    # With i == 0, `global_encoded[-0:]` is the WHOLE list, which would
    # duplicate the first two turns and overshoot the budget. Keep at least
    # the current (last) segment instead.
    i = max(i, 1)

    # Rotate the list, always make sure first segment is included to avoid drift
    if i < len(global_encoded) - 2:
        return global_encoded[:2] + global_encoded[-i:]
    return global_encoded


def generate_long(
    *,
    model,
    tokenizer: callable,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 30,
    speaker: Optional[str] = None,
    prompt_text: Optional[str] = None,
    prompt_tokens: Optional[torch.Tensor] = None,
    is_streaming: bool = False,
):
    """Generate semantic codes for long text, one segment at a time.

    Segmentation and context:
      - With ``iterative_prompt`` the text is split with ``split_text`` into
        chunks of at least ``chunk_length`` characters; otherwise it is used
        whole.
      - Each segment is generated with previous text/generated turns as
        context (see ``_select_context_segments``).
      - When both ``prompt_text`` and ``prompt_tokens`` are given, their
        encoding is prepended to every prompt.

    Yields:
        - Non-streaming: one tensor per sample, with all segments' codes
          concatenated along dim 1.
        - Streaming: each segment's codes as soon as they are ready, then the
          string ``"next"`` at the end of every sample.
    """
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    # Size in BYTES (not element count) so the "Bandwidth achieved" log below
    # is really in GB/s.
    model_size = sum(
        p.numel() * p.element_size() for p in model.parameters() if p.requires_grad
    )
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

    use_prompt = prompt_text is not None and prompt_tokens is not None
    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]

    if use_prompt:
        encoded_prompts = encode_tokens(
            tokenizer,
            prompt_text,
            prompt_tokens=prompt_tokens,
            bos=True,
            device=device,
            speaker=speaker,
            num_codebooks=model.config.num_codebooks,
        )

    for idx, segment_text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=segment_text,
                bos=idx == 0 and not use_prompt,
                device=device,
                speaker=speaker,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {segment_text}")

    # Move temperature, top_p, repetition_penalty to device
    # This is important so that changing params doesn't trigger recompile
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        global_encoded = []
        all_codes = []
        seg_idx = 0

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            partial_encoded = _select_context_segments(global_encoded, max_length)
            # NOTE(review): the reference prompt is not counted in the budget.
            if use_prompt:
                partial_encoded = [encoded_prompts] + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                eos_token_id=tokenizer.eos_token_id,
                im_end_id=im_end_id,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Put the generated tokens
            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
            codes = y[1:, prompt_length:-2].clone()
            # Undo the +2 shift applied to codes in encode_tokens.
            codes = codes - 2
            assert (codes >= 0).all(), f"Negative code found"

            decoded = y[:, prompt_length:-1].clone()
            if decoded[0, -1] != im_end_id:  # <im_end>
                val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
                decoded = torch.cat(
                    (decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
                )

            # But for global encoding, we should keep the <im_end> token
            global_encoded.append(decoded)

            if is_streaming:
                assert (codes >= 0).all(), f"Negative code found: {codes}"
                yield codes
            else:
                all_codes.append(codes)

            seg_idx += 1

        if is_streaming:
            # This indicates the end of the current sample
            yield "next"
        else:
            all_codes = torch.cat(all_codes, dim=1)
            assert (all_codes >= 0).all(), f"Negative code found: {all_codes}"
            yield all_codes
def launch_thread_safe_queue(
    config_name,
    checkpoint_path,
    device,
    precision,
    max_length,
    compile=False,
):
    """Load the model on a daemon worker thread and return its request queue.

    Request protocol:
      - Put a dict ``{"request": <generate_long kwargs>, "response_queue":
        <queue>}`` on the returned queue.
      - The worker puts every chunk yielded by ``generate_long`` on
        ``response_queue``, followed by ``"done"``.
      - On a generation error the worker sets ``item["success"] = False`` and
        ``item["response"]`` to the exception, then puts ``"done"``.
      - Put ``None`` to stop the worker.

    Raises:
        Any exception raised by ``load_model`` on the worker thread, re-raised
        here so the caller does not block forever waiting for initialization.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()
    init_errors = []  # holds the load_model exception, if any

    def worker():
        try:
            model, decode_one_token = load_model(
                config_name, checkpoint_path, device, precision, max_length, compile=compile
            )
        except Exception as e:
            # Hand the error to the waiting caller instead of leaving it hung.
            init_errors.append(e)
            init_event.set()
            return
        init_event.set()

        while True:
            item = input_queue.get()
            if item is None:
                break

            kwargs = item["request"]
            response_queue = item["response_queue"]

            try:
                item["success"] = True
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(chunk)

                response_queue.put("done")
            except Exception as e:
                logger.exception("Generation request failed")
                item["success"] = False
                item["response"] = e

                response_queue.put("done")

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()
    if init_errors:
        raise init_errors[0]

    return input_queue
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None)
@click.option(
    "--prompt-tokens", type=click.Path(path_type=Path, exists=True), default=None
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.5)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="results/text2semantic_400m_finetune/step_000002000.pth",
)
@click.option("--config-name", type=str, default="dual_ar_8_codebook_small")
@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--speaker", type=str, default=None)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--max-length", type=int, default=2048)
@click.option("--chunk-length", type=int, default=30)
def main(
    text: str,
    prompt_text: Optional[str],
    prompt_tokens: Optional[Path],
    num_samples: int,
    max_new_tokens: int,
    top_p: int,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    config_name: str,
    tokenizer: str,
    compile: bool,
    seed: int,
    speaker: Optional[str],
    half: bool,
    iterative_prompt: bool,
    max_length: int,
    chunk_length: int,
) -> None:
    """CLI entry point: load the model, generate codes for ``--text``, save them.

    Runs on ``cuda`` in fp16 (``--half``) or bf16. Every result yielded by
    ``generate_long`` is written to ``codes_{idx}.npy`` in the current
    directory.
    """
    device = "cuda"
    precision = torch.half if half else torch.bfloat16

    logger.info("Loading model ...")
    load_started = time.time()
    model, decode_one_token = load_model(
        config_name, checkpoint_path, device, precision, max_length, compile=compile
    )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - load_started:.02f} seconds")

    if prompt_tokens is not None:
        prompt_tokens = torch.from_numpy(np.load(prompt_tokens)).to(device)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    results = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        tokenizer=tokenizer,
        compile=compile,
        speaker=speaker,
        iterative_prompt=iterative_prompt,
        max_length=max_length,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    for idx, codes in enumerate(results):
        out_name = f"codes_{idx}.npy"
        np.save(out_name, codes.cpu().numpy())
        logger.info(f"Saved codes to {out_name}")
if __name__ == "__main__":
    # Script entry: hand control to the click command.
    main()
|