| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118 |
- import os
- import queue
- import threading
- import time
- from contextlib import nullcontext
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Literal, Optional, Tuple, Union
- import click
- import numpy as np
- import torch
- import torch._dynamo.config
- import torch._inductor.config
- from loguru import logger
- from torch.nn.attention import SDPBackend, sdpa_kernel
- from tqdm import tqdm
- from transformers import AutoTokenizer
- from fish_speech.conversation import (
- CODEBOOK_PAD_TOKEN_ID,
- Conversation,
- Message,
- TextPart,
- VQPart,
- )
- from fish_speech.models.text2semantic.llama import (
- BaseModelArgs,
- BaseTransformer,
- DualARTransformer,
- NaiveTransformer,
- )
- from fish_speech.text import clean_text, split_text
- from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
# Disable the HF tokenizers thread pool: this module runs tokenization from
# worker threads (see launch_thread_safe_queue) and the fork-related warning
# / potential deadlock is avoided by turning parallelism off.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Inductor tuning knobs for faster generated kernels when torch.compile is used.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    """Draw one categorical sample from `probs_sort` along the last dim.

    Uses the exponential-noise (Gumbel-max style) trick: dividing each
    probability by an i.i.d. Exponential(1) draw and taking the argmax is
    distributed identically to `torch.multinomial`, but never forces a
    host/device synchronization.
    """
    exp_noise = torch.empty_like(probs_sort).exponential_(1)
    winner = torch.argmax(probs_sort / exp_noise, dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    """Turn a 1-D logits vector into sampling probabilities.

    Pipeline: repetition penalty on `previous_tokens` (note: written back
    into `logits` in place), then nucleus/top-p filtering, then temperature
    scaling, then softmax.
    """
    if previous_tokens is not None:
        prev_idx = previous_tokens.long()
        prev_scores = torch.gather(logits, dim=0, index=prev_idx)
        # Push repeated tokens down: positive scores shrink, negative grow
        # more negative.
        prev_scores = torch.where(
            prev_scores < 0,
            prev_scores * repetition_penalty,
            prev_scores / repetition_penalty,
        )
        logits.scatter_(dim=0, index=prev_idx, src=prev_scores)

    # Nucleus (top-p) filtering over the sorted distribution.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cdf = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    drop_sorted = cdf > top_p
    drop_sorted[0] = False  # keep at least one option
    drop_mask = drop_sorted.scatter(dim=0, index=sorted_indices, src=drop_sorted)
    filtered = logits.masked_fill(drop_mask, -float("Inf"))

    filtered = filtered / max(temperature, 1e-5)
    return torch.nn.functional.softmax(filtered, dim=-1)
def multinomial_sample_one_no_sync_agent(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    """Sample one index per batch row from `probs_sort`, sync-free.

    Equivalent to `torch.multinomial(probs_sort, 1)`: scaling probabilities
    by independent Exponential(1) noise and taking the per-row argmax
    realizes the same categorical distribution entirely on-device.
    """
    noise = torch.empty_like(probs_sort).exponential_(1)
    scaled = probs_sort / noise
    return scaled.argmax(dim=-1, keepdim=True).to(dtype=torch.int)
def logits_to_probs_agent(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    """Batched variant of `logits_to_probs`, operating along the last dim.

    Applies a repetition penalty on `previous_tokens` (mutating `logits`
    in place), nucleus/top-p filtering per row, temperature scaling, and
    finally softmax.
    """
    if previous_tokens is not None:
        prev_idx = previous_tokens.long()
        prev_scores = torch.gather(logits, dim=-1, index=prev_idx)
        # Repeated tokens are penalized symmetrically around zero.
        prev_scores = torch.where(
            prev_scores < 0,
            prev_scores * repetition_penalty,
            prev_scores / repetition_penalty,
        )
        logits.scatter_(dim=-1, index=prev_idx, src=prev_scores)

    # Per-row nucleus (top-p) filtering.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cdf = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    drop_sorted = cdf > top_p
    drop_sorted[..., 0] = False  # keep at least one option per row
    drop_mask = drop_sorted.scatter(dim=-1, index=sorted_indices, src=drop_sorted)
    filtered = logits.masked_fill(drop_mask, -float("Inf"))

    filtered = filtered / max(temperature, 1e-5)
    return torch.nn.functional.softmax(filtered, dim=-1)
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample the next token from the final timestep of unbatched `logits`.

    `logits` is indexed as [0, -1] (batch of one, last position). Returns
    the sampled token id tensor and the full probability vector it was
    drawn from.
    """
    step_probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    next_token = multinomial_sample_one_no_sync(step_probs)
    return next_token, step_probs
def sample_agent(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batched counterpart of `sample`: draws one token per batch row.

    Uses the last timestep of `logits` ([:, -1]) and returns the sampled
    ids together with the probabilities they were drawn from.
    """
    step_probs = logits_to_probs_agent(
        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    next_tokens = multinomial_sample_one_no_sync_agent(step_probs)
    return next_tokens, step_probs
def decode_one_token_ar_agent(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step of the dual-AR (slow/fast) transformer in agent mode.

    The slow transformer produces the main-token logits and hidden state;
    the main token is sampled near-greedily (fixed temperature/top_p). The
    fast transformer then autoregressively decodes each residual codebook,
    feeding back its own embedding. Rows whose main token is not in
    `semantic_ids` get their codebook entries masked to CODEBOOK_PAD_TOKEN_ID.
    Returns a stacked (batch, num_codebooks + 1, 1) token tensor.
    """
    # print(x, input_pos)
    x = model.forward_generate(x, input_pos)
    logits = x.logits  # [:, -1:]
    hidden_states = x.hidden_states  # [:, -1:]

    # The main (text) codebook is sampled with fixed, near-greedy settings.
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample_agent(
            logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    # Cleanup the cache: the fast transformer restarts from scratch each step.
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    for codebook_idx in range(model.config.num_codebooks):
        # Position within the fast transformer's codebook sequence.
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample_agent(
            logits,
            previous_tokens=(
                previous_tokens[:, codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=1)
    # Pad out codebook rows for batch entries whose token is not semantic.
    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :],
        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
        CODEBOOK_PAD_TOKEN_ID,
    )

    return codebooks
def decode_one_token_naive_agent(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step of the naive transformer in agent (batched) mode.

    A single forward pass yields both `token_logits` (main codebook) and
    `codebook_logits` (residual codebooks); each is sampled independently.
    Rows whose main token is not in `semantic_ids` get their codebook
    entries masked to CODEBOOK_PAD_TOKEN_ID. Returns a
    (batch, num_codebooks + 1, 1) token tensor.
    """
    x = model.forward_generate(x, input_pos)

    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample_agent(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[:, i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    codebooks = torch.stack(codebooks, dim=1)
    # Pad out codebook rows for batch entries whose token is not semantic.
    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :],
        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
        CODEBOOK_PAD_TOKEN_ID,
    )

    return codebooks
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step of the dual-AR transformer (unbatched TTS path).

    Samples the main token from the slow transformer, primes the fast
    transformer with its hidden state, then autoregressively samples the
    remaining codebooks. Returns a (num_codebooks + 1, 1) token tensor
    (main token first, then codebook indices).
    """
    x = model.forward_generate(x, input_pos)

    sampling_kwargs_main = sampling_kwargs.copy()
    # sampling_kwargs_main["temperature"] = 0.1
    # sampling_kwargs_main["top_p"] = 0.1
    # sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.logits,
            previous_tokens=(
                previous_tokens[0] if previous_tokens is not None else None
            ),  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    hidden_states = x.hidden_states

    # Cleanup the cache: the fast transformer restarts from scratch each step.
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    # Prime the fast transformer at position 0 with the slow hidden state;
    # the returned logits are intentionally discarded here.
    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
    model.forward_generate_fast(hidden_states, input_pos)

    # Map the global token id into codebook-index space; non-semantic
    # tokens (id below semantic_begin_id) clamp to 0.
    a = codebooks[0] - model.tokenizer.semantic_begin_id
    a[a < 0] = 0
    hidden_states = model.fast_embeddings(a)
    codebooks.append(a)

    for codebook_idx in range(1, model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=0)
    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    # codebooks[1:, :] = torch.masked_fill(
    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
    # )

    # print(codebooks)
    return codebooks
def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list = None,
    previous_tokens: torch.Tensor = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step of the naive transformer (unbatched TTS path).

    Fix/consistency: accepts `semantic_ids` like `decode_one_token_ar`.
    `generate` and `decode_n_tokens` always pass `semantic_ids=...` to the
    decode function; without this parameter it fell into `**sampling_kwargs`
    and was forwarded to `logits_to_probs`, raising a TypeError. The value
    is currently unused here (the naive path does no semantic masking).

    Args:
        model: Naive transformer with combined token/codebook heads.
        x: Current token tensor, shape (1, num_codebooks + 1, seq).
        input_pos: Positions being decoded.
        semantic_ids: Accepted for interface parity; unused.
        previous_tokens: Windowed history for the repetition penalty.

    Returns:
        (num_codebooks + 1, 1) tensor of sampled token ids.
    """
    x = model.forward_generate(x, input_pos)

    # The main token codebook is sampled near-greedily with fixed settings.
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    semantic_ids: list,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    """Autoregressively decode up to `num_new_tokens` steps.

    Stops early when the main codebook emits <|im_end|>. Returns the
    (num_codebooks + 1, n_generated) history of sampled tokens, including
    the stop token's step.
    """
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            # NOTE(review): for the first `win_size` steps this window also
            # contains not-yet-written zero entries; presumably token id 0 is
            # harmless under the repetition penalty — confirm.
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with (
            sdpa_kernel(
                [
                    SDPBackend.FLASH_ATTENTION,
                    SDPBackend.EFFICIENT_ATTENTION,
                    SDPBackend.MATH,
                ]
            )
            if torch.cuda.is_available()
            else nullcontext()
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_ids=semantic_ids,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        # Stop as soon as the main codebook produces the end-of-message token.
        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
            break

    return previous_tokens[:, : i + 1]
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    `prompt` has shape (num_codebooks + 1, T); `max_new_tokens == 0` means
    "fill up to max_seq_len". Returns the full (num_codebooks + 1, T + n)
    sequence including the prompt.
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
    # Token ids of all 1024 <|semantic:i|> tokens, used for codebook masking.
    semantic_ids = [
        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
    ]

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
    )
    empty[:, :T] = prompt
    seq = empty

    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )

    # Prefill over the whole prompt, producing the first new token.
    next_token = prefill_decode(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        decode_one_token=decode_one_token,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
def decode_n_tokens_agent(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    semantic_ids: list,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive_agent,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """Batched autoregressive decoding; yields each step's tokens on CPU.

    Generator: yields a (batch, num_codebooks + 1, 1) CPU tensor per step.
    Stops when all rows have emitted `im_end_id`, or — when
    `early_stop_threshold` lies in (0, 1) — once that fraction of the batch
    has finished.
    """
    batch_size = cur_token.size(0)
    previous_tokens = torch.zeros(
        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )
    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
    # The prefill step may already have produced <im_end> for some rows.
    finished = finished | (cur_token[:, 0, -1] == im_end_id)
    start_time = time.time()

    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :, :win_size]
        else:
            window = previous_tokens[:, :, i - win_size : i]

        with sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_ids=semantic_ids,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
        previous_tokens[:, :, i : i + 1] = next_token.view(
            batch_size, model.config.num_codebooks + 1, -1
        )

        yield cur_token.cpu()

        finished = finished | (cur_token[:, 0, -1] == im_end_id)
        if finished.all() or (
            0 < early_stop_threshold < 1
            and finished.sum() >= round(batch_size * early_stop_threshold)
        ):
            break

    total_time = time.time() - start_time
    generated_tokens = i + 1
    tokens_per_second = (generated_tokens / total_time) * batch_size
    logger.info(
        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
    )
@torch.no_grad()
@torch.inference_mode()
def generate_agent(
    *,
    model: BaseTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    semantic_ids: list,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive_agent,
    num_samples: int = 1,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Generator variant: yields the prefill token first, then each decoded
    step (all as CPU tensors, shape (num_samples, num_codebooks + 1, 1)).
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    # Replicate the prompt so all samples are decoded in one batch.
    prompt = prompt[None].repeat(num_samples, 1, 1)

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive_agent
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar_agent
    )

    next_token = prefill_decode(
        model,
        prompt,
        input_pos,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    ).view(num_samples, codebook_dim, -1)
    yield next_token.cpu()

    input_pos = torch.tensor([T], device=device, dtype=torch.int)

    yield from decode_n_tokens_agent(
        model,
        next_token,
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        semantic_ids=semantic_ids,
        decode_one_token=decode_one_token,
        early_stop_threshold=early_stop_threshold,
        **sampling_kwargs,
    )
def encode_tokens(
    tokenizer,
    string,
    device="cuda",
    prompt_tokens=None,
    num_codebooks=4,
):
    """Encode one user text (plus an optional VQ voice reference) for inference.

    Builds a user message with the cleaned text, then either an assistant
    "<|voice|>" message carrying `prompt_tokens` as VQ codes (voice-cloning
    reference) or an open-ended assistant "<|voice|>" message (no <|im_end|>)
    for the model to continue. Returns the encoded tensor moved to `device`.
    """
    string = clean_text(string)

    messages = []
    messages.append(
        Message(
            role="user",
            parts=[TextPart(text=string)],
            cal_loss=False,
        )
    )

    if prompt_tokens is not None:
        # Accept (1, num_codebooks, seq_len) and squeeze the batch dim.
        if prompt_tokens.ndim == 3:
            assert (
                prompt_tokens.shape[0] == 1
            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
            prompt_tokens = prompt_tokens[0]

        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"

        if prompt_tokens.shape[0] > num_codebooks:
            logger.warning(
                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
            )
            prompt_tokens = prompt_tokens[:num_codebooks]

        vq_part = VQPart(codes=prompt_tokens.to(device))

        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>"), vq_part],
                cal_loss=False,
            )
        )
    else:
        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>")],
                cal_loss=False,
                add_im_end=False,
            )
        )

    conversation = Conversation(messages=messages)
    # conversation.visualize(tokenizer)
    encoded = conversation.encode_for_inference(
        tokenizer=tokenizer,
        num_codebooks=num_codebooks,
    )

    return encoded.to(device)
def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
    """Load a text2semantic checkpoint and select its decode function.

    Args:
        checkpoint_path: Path accepted by `BaseTransformer.from_pretrained`.
        device: Target device for the weights.
        precision: Target dtype (e.g. torch.bfloat16 or torch.half).
        compile: Wrap the decode function in `torch.compile` when True.
        is_agent: Select the batched agent decode variants when True.

    Returns:
        Tuple of (model in eval mode, decode_one_token callable).
    """
    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
        checkpoint_path, load_weights=True, is_agent=is_agent
    )

    model = model.to(device=device, dtype=precision)
    # Plain string: was an f-string with no placeholders (ruff F541).
    logger.info("Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
        decode_one_token = (
            decode_one_token_ar_agent if is_agent else decode_one_token_ar
        )
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = (
            decode_one_token_naive_agent if is_agent else decode_one_token_naive
        )
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token,
            fullgraph=True,
            # aot_eager avoids Triton requirements on CPU-only hosts.
            backend="inductor" if torch.cuda.is_available() else "aot_eager",
            mode="reduce-overhead" if torch.cuda.is_available() else None,
        )

    return model.eval(), decode_one_token
@dataclass
class GenerateResponse:
    """One item yielded by `generate_long`."""

    # "sample": `codes`/`text` carry a generated segment;
    # "next": marks the end of the current sample (no payload).
    action: Literal["sample", "next"]
    # Generated semantic codes, shape (num_codebooks, n) — only for "sample".
    codes: Optional[torch.Tensor] = None
    # The source text chunk that produced `codes` — only for "sample".
    text: Optional[str] = None
def generate_long(
    *,
    model,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: int = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 150,
    prompt_text: Optional[str | list[str]] = None,
    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
):
    """Generate semantic codes for `text`, chunk by chunk, sample by sample.

    Generator yielding `GenerateResponse(action="sample", ...)` for each
    decoded text chunk and `GenerateResponse(action="next")` after each of
    the `num_samples` samples completes. When `iterative_prompt` is set the
    text is split into ~`chunk_length` pieces that condition on previously
    generated audio; optional (prompt_text, prompt_tokens) pairs provide a
    voice-cloning reference.
    """
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    use_prompt = prompt_text is not None and prompt_tokens is not None
    if use_prompt and isinstance(prompt_text, str):
        prompt_text = [prompt_text]
        prompt_tokens = [prompt_tokens]

    assert use_prompt is False or len(prompt_text) == len(
        prompt_tokens
    ), "Prompt text and tokens must have the same length"

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tokenizer = model.tokenizer
    im_end_id = tokenizer.get_token_id("<|im_end|>")

    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]

    # Always start the context with a fixed system instruction.
    encoded_prompts = [
        Conversation(
            messages=[
                Message(
                    role="system",
                    parts=[TextPart(text="Speak out the provided text.")],
                    cal_loss=False,
                )
            ]
        )
        .encode_for_inference(
            tokenizer=tokenizer,
            num_codebooks=model.config.num_codebooks,
        )
        .to(device)
    ]

    if use_prompt:
        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
            encoded_prompts.append(
                encode_tokens(
                    tokenizer,
                    string=t,
                    device=device,
                    prompt_tokens=c,
                    num_codebooks=model.config.num_codebooks,
                )
            )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                device=device,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    # Move temperature, top_p, repetition_penalty to device
    # This is important so that changing params doesn't trigger recompile
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        global_encoded = []
        seg_idx = 0

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            # Walk history newest-first until the token budget is exhausted.
            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Pick last 2000 tokens
            count = 0
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024 - sum(
                    t.shape[1] for t in encoded_prompts
                ):
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list, always make sure first segment is included to avoid drift
            if i < len(global_encoded) - 2:
                # NOTE(review): when i == 0, global_encoded[-0:] is the whole
                # list, not zero segments — presumably guarded by the branch
                # condition in practice; confirm.
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            if use_prompt:
                partial_encoded = encoded_prompts + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            t = time.perf_counter() - t0

            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Put the generated tokens
            # since there is <im_end>, we remove last token
            codes = y[1:, prompt_length + 1 :].clone()
            assert (codes >= 0).all(), f"Negative code found"

            decoded = y[:, prompt_length:].clone()
            # But for global encoding, we should keep the <im_end> token
            global_encoded.append(decoded)
            assert (codes >= 0).all(), f"Negative code found: {codes}"
            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
            seg_idx += 1

        # This indicates the end of the current sample
        yield GenerateResponse(action="next")
@dataclass
class WrappedGenerateResponse:
    """Envelope placed on a request's response queue by the worker thread."""

    # "success" carries a GenerateResponse; "error" carries the Exception.
    status: Literal["success", "error"]
    response: Optional[GenerateResponse | Exception] = None
@dataclass
class GenerateRequest:
    """A unit of work submitted to the thread-safe generation queue."""

    # Keyword arguments forwarded to generate_long / generate_agent.
    request: dict
    # Queue the worker writes results (or errors) back onto.
    response_queue: queue.Queue
def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Start a daemon worker thread that owns the model and serves requests.

    Blocks until the model is loaded and its KV caches are set up, then
    returns the input queue. Callers put `GenerateRequest` items on it
    (or `None` to shut the worker down) and read `WrappedGenerateResponse`
    items from each request's own response queue.
    """
    input_queue = queue.Queue()
    init_event = threading.Event()

    def worker():
        # The model is created and used exclusively on this thread.
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **kwargs
                ):
                    response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
            except Exception as e:
                # Ship the exception back to the caller instead of killing
                # the worker thread.
                response_queue.put(WrappedGenerateResponse(status="error", response=e))

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue
def launch_thread_safe_queue_agent(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Agent variant of `launch_thread_safe_queue`.

    Additionally loads the HF tokenizer and model config on the calling
    thread. The worker streams raw token tensors onto each request's
    response queue, followed by the sentinel string "stop" on success or
    "error" on failure. Returns (input_queue, tokenizer, config).
    """
    input_queue = queue.Queue()
    init_event = threading.Event()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    config = BaseModelArgs.from_pretrained(checkpoint_path)

    def worker():
        # Agent mode: batched decode functions, model owned by this thread.
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile, is_agent=True
        )

        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        init_event.set()

        while True:
            item: GenerateRequest | None = input_queue.get()
            if item is None:
                break

            kwargs = item.request
            response_queue = item.response_queue

            try:
                for token in generate_agent(
                    model=model,
                    decode_one_token=decode_one_token,
                    **kwargs,
                ):
                    response_queue.put(token)

                response_queue.put("stop")
            except Exception as e:
                import traceback

                logger.exception(f"Error in worker: {traceback.format_exc()}")
                response_queue.put("error")

    threading.Thread(target=worker, daemon=True).start()
    init_event.wait()

    return input_queue, tokenizer, config
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.2)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.5",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=100)
@click.option("--output-dir", type=Path, default="temp")
def main(
    text: str,
    prompt_text: Optional[list[str]],
    prompt_tokens: Optional[list[Path]],
    num_samples: int,
    max_new_tokens: int,
    top_p: int,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
    output_dir: Path,
) -> None:
    """CLI entry point: generate semantic codes for `text` and save them
    as `codes_<idx>.npy` files (one per sample) under `output_dir`."""
    os.makedirs(output_dir, exist_ok=True)
    precision = torch.half if half else torch.bfloat16

    # Each reference text must pair with a reference token file.
    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )

    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")

    if prompt_tokens is not None:
        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]

    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )

    idx = 0
    codes = []

    # Collect "sample" chunks; flush them to disk on each "next" marker.
    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
                np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
                logger.info(f"Saved codes to {codes_npy_path}")
            logger.info(f"Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")


if __name__ == "__main__":
    main()
|