@@ -16,10 +16,8 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
-from fish_speech.conversation import (
-    CODEBOOK_PAD_TOKEN_ID,
-    Conversation,
-    Message,
+from fish_speech.content_sequence import (
+    ContentSequence,
     TextPart,
     VQPart,
 )
@@ -84,45 +82,6 @@ def logits_to_probs(
     return probs
 
 
-def multinomial_sample_one_no_sync_agent(
-    probs_sort,
-):  # Does multinomial sampling without a cuda synchronization
-    q = torch.empty_like(probs_sort).exponential_(1)
-    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-
-def logits_to_probs_agent(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    temperature: torch.Tensor = 1.0,
-    top_p: torch.Tensor = 1.0,
-    repetition_penalty: torch.Tensor = 1.0,
-) -> torch.Tensor:
-    # Apply repetition penalty
-    if previous_tokens is not None:
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=-1, index=previous_tokens)
-        score = torch.where(
-            score < 0, score * repetition_penalty, score / repetition_penalty
-        )
-        logits.scatter_(dim=-1, index=previous_tokens, src=score)
-
-    # Apply top-p sampling
-    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
-    sorted_indices_to_remove = cum_probs > top_p
-    sorted_indices_to_remove[..., 0] = False  # keep at least one option
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
-    )
-    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-
-    logits = logits / max(temperature, 1e-5)
-
-    probs = torch.nn.functional.softmax(logits, dim=-1)
-    return probs
-
-
 def sample(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
@@ -135,117 +94,6 @@ def sample(
     return idx_next, probs
 
 
-def sample_agent(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    **sampling_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    probs = logits_to_probs_agent(
-        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
-    )
-    idx_next = multinomial_sample_one_no_sync_agent(probs)
-    return idx_next, probs
-
-
-def decode_one_token_ar_agent(
-    model: DualARTransformer,
-    x: torch.Tensor,
-    input_pos: torch.Tensor,
-    semantic_ids: list,
-    previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    # print(x, input_pos)
-    x = model.forward_generate(x, input_pos)
-    logits = x.logits  # [:, -1:]
-    hidden_states = x.hidden_states  # [:, -1:]
-
-    sampling_kwargs_main = sampling_kwargs.copy()
-    sampling_kwargs_main["temperature"] = 0.1
-    sampling_kwargs_main["top_p"] = 0.1
-    sampling_kwargs_main["repetition_penalty"] = 1.0
-
-    codebooks = [
-        sample_agent(
-            logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **sampling_kwargs_main,
-        )[0]
-    ]
-
-    # Cleanup the cache
-    for layer in model.fast_layers:
-        layer.attention.kv_cache.k_cache.fill_(0)
-        layer.attention.kv_cache.v_cache.fill_(0)
-
-    for codebook_idx in range(model.config.num_codebooks):
-        input_pos = torch.tensor(
-            [codebook_idx], device=hidden_states.device, dtype=torch.long
-        )
-        logits = model.forward_generate_fast(hidden_states, input_pos)
-        a = sample_agent(
-            logits,
-            previous_tokens=(
-                previous_tokens[:, codebook_idx + 1]
-                if previous_tokens is not None
-                else None
-            ),
-            **sampling_kwargs,
-        )[0]
-        hidden_states = model.fast_embeddings(a)
-        codebooks.append(a)
-
-    codebooks = torch.stack(codebooks, dim=1)
-    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
-    codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :],
-        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
-        CODEBOOK_PAD_TOKEN_ID,
-    )
-
-    return codebooks
-
-
-def decode_one_token_naive_agent(
-    model: NaiveTransformer,
-    x: torch.Tensor,
-    input_pos: torch.Tensor,
-    semantic_ids: list,
-    previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    x = model.forward_generate(x, input_pos)
-
-    codebooks = [
-        sample(
-            x.token_logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **sampling_kwargs,
-        )[0]
-    ]
-
-    for i in range(model.config.num_codebooks):
-        codebooks.append(
-            sample_agent(
-                x.codebook_logits[:, :, i],
-                previous_tokens=(
-                    previous_tokens[:, i + 1] if previous_tokens is not None else None
-                ),
-                **sampling_kwargs,
-            )[0]
-        )
-
-    codebooks = torch.stack(codebooks, dim=1)
-    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
-    codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :],
-        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
-        CODEBOOK_PAD_TOKEN_ID,
-    )
-
-    return codebooks
-
-
 def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
@@ -290,8 +138,9 @@ def decode_one_token_ar(
             [codebook_idx], device=hidden_states.device, dtype=torch.long
         )
         logits = model.forward_generate_fast(hidden_states, input_pos)
+        chunked_logits = logits[..., :1024]
         a = sample(
-            logits,
+            chunked_logits,
             previous_tokens=(
                 previous_tokens[codebook_idx + 1]
                 if previous_tokens is not None
@@ -312,49 +161,13 @@ def decode_one_token_ar(
     return codebooks
 
 
-def decode_one_token_naive(
-    model: NaiveTransformer,
-    x: torch.Tensor,
-    input_pos: torch.Tensor,
-    previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
-) -> torch.Tensor:
-    x = model.forward_generate(x, input_pos)
-
-    sampling_kwargs_main = sampling_kwargs.copy()
-    sampling_kwargs_main["temperature"] = 0.1
-    sampling_kwargs_main["top_p"] = 0.1
-    sampling_kwargs_main["repetition_penalty"] = 1.0
-
-    codebooks = [
-        sample(
-            x.logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
-            **sampling_kwargs_main,
-        )[0]
-    ]
-
-    for i in range(model.config.num_codebooks):
-        codebooks.append(
-            sample(
-                x.codebook_logits[:, :, i],
-                previous_tokens=(
-                    previous_tokens[i + 1] if previous_tokens is not None else None
-                ),
-                **sampling_kwargs,
-            )[0]
-        )
-
-    return torch.stack(codebooks, dim=0)
-
-
 def decode_n_tokens(
     model: NaiveTransformer,
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
     semantic_ids: list,
-    decode_one_token=decode_one_token_naive,
+    decode_one_token=decode_one_token_ar,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -406,7 +219,7 @@ def generate(
     model: NaiveTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
-    decode_one_token=decode_one_token_naive,
+    decode_one_token=decode_one_token_ar,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
@@ -442,11 +255,7 @@ def generate(
     input_pos = torch.arange(0, T, device=device)
 
     # Use non-accelerated version for now, to avoid compilation overhead
-    prefill_decode = (
-        decode_one_token_naive
-        if isinstance(model, NaiveTransformer)
-        else decode_one_token_ar
-    )
+    prefill_decode = decode_one_token_ar
 
     next_token = prefill_decode(
         model,
@@ -474,222 +283,17 @@ def generate(
     return seq
 
 
-def decode_n_tokens_agent(
-    model: NaiveTransformer,
-    cur_token: torch.Tensor,
-    input_pos: torch.Tensor,
-    num_new_tokens: int,
-    semantic_ids: list,
-    im_end_id: int = 4,
-    decode_one_token=decode_one_token_naive_agent,
-    early_stop_threshold: float = 0.6,
-    **sampling_kwargs,
-):
-    batch_size = cur_token.size(0)
-    previous_tokens = torch.zeros(
-        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
-        dtype=torch.int,
-        device=cur_token.device,
-    )
-    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
-    finished = finished | (cur_token[:, 0, -1] == im_end_id)
-    start_time = time.time()
-
-    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
-        # We need to get windowed repeat penalty
-        win_size = 16
-        if i < win_size:
-            window = previous_tokens[:, :, :win_size]
-        else:
-            window = previous_tokens[:, :, i - win_size : i]
-
-        with sdpa_kernel(
-            SDPBackend.MATH
-        ):  # Actually better for Inductor to codegen attention here
-            next_token = decode_one_token(
-                model=model,
-                x=cur_token,
-                input_pos=input_pos,
-                previous_tokens=window,
-                semantic_ids=semantic_ids,
-                **sampling_kwargs,
-            )
-
-        input_pos += 1
-        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
-        previous_tokens[:, :, i : i + 1] = next_token.view(
-            batch_size, model.config.num_codebooks + 1, -1
-        )
-
-        yield cur_token.cpu()
-
-        finished = finished | (cur_token[:, 0, -1] == im_end_id)
-        if finished.all() or (
-            0 < early_stop_threshold < 1
-            and finished.sum() >= round(batch_size * early_stop_threshold)
-        ):
-            break
-
-    total_time = time.time() - start_time
-    generated_tokens = i + 1
-    tokens_per_second = (generated_tokens / total_time) * batch_size
-    logger.info(
-        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
-    )
-
-
-@torch.no_grad()
-@torch.inference_mode()
-def generate_agent(
-    *,
-    model: BaseTransformer,
-    prompt: torch.Tensor,
-    max_new_tokens: int,
-    semantic_ids: list,
-    im_end_id: int = 4,
-    decode_one_token=decode_one_token_naive_agent,
-    num_samples: int = 1,
-    early_stop_threshold: float = 0.6,
-    **sampling_kwargs,
-):
-    """
-    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
-    """
-
-    # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(1)
-    prompt = prompt[None].repeat(num_samples, 1, 1)
-
-    if T >= model.config.max_seq_len:
-        raise ValueError(
-            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
-        )
-
-    if max_new_tokens:
-        if T + max_new_tokens > model.config.max_seq_len:
-            max_new_tokens = model.config.max_seq_len - T
-            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
-
-        T_new = T + max_new_tokens
-    else:
-        T_new = model.config.max_seq_len
-        max_new_tokens = T_new - T
-
-    device, dtype = prompt.device, prompt.dtype
-
-    codebook_dim = 1 + model.config.num_codebooks
-    input_pos = torch.arange(0, T, device=device)
-
-    # Use non-accelerated version for now, to avoid compilation overhead
-    prefill_decode = (
-        decode_one_token_naive_agent
-        if isinstance(model, NaiveTransformer)
-        else decode_one_token_ar_agent
-    )
-    next_token = prefill_decode(
-        model,
-        prompt,
-        input_pos,
-        semantic_ids=semantic_ids,
-        **sampling_kwargs,
-    ).view(num_samples, codebook_dim, -1)
-    yield next_token.cpu()
-
-    input_pos = torch.tensor([T], device=device, dtype=torch.int)
-
-    yield from decode_n_tokens_agent(
-        model,
-        next_token,
-        input_pos,
-        max_new_tokens - 1,
-        im_end_id=im_end_id,
-        semantic_ids=semantic_ids,
-        decode_one_token=decode_one_token,
-        early_stop_threshold=early_stop_threshold,
-        **sampling_kwargs,
-    )
-
-
-def encode_tokens(
-    tokenizer,
-    string,
-    device="cuda",
-    prompt_tokens=None,
-    num_codebooks=4,
-):
-    string = clean_text(string)
-
-    messages = []
-    messages.append(
-        Message(
-            role="user",
-            parts=[TextPart(text=string)],
-            cal_loss=False,
-        )
-    )
-
-    if prompt_tokens is not None:
-        if prompt_tokens.ndim == 3:
-            assert (
-                prompt_tokens.shape[0] == 1
-            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
-            prompt_tokens = prompt_tokens[0]
-
-        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"
-
-        if prompt_tokens.shape[0] > num_codebooks:
-            logger.warning(
-                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
-            )
-            prompt_tokens = prompt_tokens[:num_codebooks]
-
-        vq_part = VQPart(codes=prompt_tokens.to(device))
-
-        messages.append(
-            Message(
-                role="assistant",
-                parts=[TextPart(text="<|voice|>"), vq_part],
-                cal_loss=False,
-            )
-        )
-    else:
-        messages.append(
-            Message(
-                role="assistant",
-                parts=[TextPart(text="<|voice|>")],
-                cal_loss=False,
-                add_im_end=False,
-            )
-        )
-
-    conversation = Conversation(messages=messages)
-    # conversation.visualize(tokenizer)
-    encoded = conversation.encode_for_inference(
-        tokenizer=tokenizer,
-        num_codebooks=num_codebooks,
-    )
-
-    return encoded.to(device)
-
-
-def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
-    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
-        checkpoint_path, load_weights=True, is_agent=is_agent
-    )
+def load_model(checkpoint_path, device, precision, compile=False):
+    model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
 
     model = model.to(device=device, dtype=precision)
     logger.info(f"Restored model from checkpoint")
 
     if isinstance(model, DualARTransformer):
-        decode_one_token = (
-            decode_one_token_ar_agent if is_agent else decode_one_token_ar
-        )
+        decode_one_token = decode_one_token_ar
         logger.info("Using DualARTransformer")
     else:
-        decode_one_token = (
-            decode_one_token_naive_agent if is_agent else decode_one_token_naive
-        )
-        logger.info("Using NaiveTransformer")
+        raise ValueError("Model is not a DualARTransformer")
 
     if compile:
         logger.info("Compiling function...")
@@ -723,7 +327,6 @@ def generate_long(
     temperature: float = 0.7,
     compile: bool = False,
     iterative_prompt: bool = True,
-    max_length: int = 2048,
    chunk_length: int = 150,
     prompt_text: Optional[str | list[str]] = None,
     prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
@@ -743,46 +346,36 @@ def generate_long(
 
     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     tokenizer = model.tokenizer
-    im_end_id = tokenizer.get_token_id("<|im_end|>")
+    base_content_sequence = ContentSequence(modality="interleave")
 
-    encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
-    encoded_prompts = [
-        Conversation(
-            messages=[
-                Message(
-                    role="system",
-                    parts=[TextPart(text="Speak out the provided text.")],
-                    cal_loss=False,
-                )
-            ]
-        )
-        .encode_for_inference(
-            tokenizer=tokenizer,
-            num_codebooks=model.config.num_codebooks,
-        )
-        .to(device)
-    ]
+    max_length = model.config.max_seq_len
 
     if use_prompt:
-        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
-            encoded_prompts.append(
-                encode_tokens(
-                    tokenizer,
-                    string=t,
-                    device=device,
-                    prompt_tokens=c,
-                    num_codebooks=model.config.num_codebooks,
-                )
+        for t, c in zip(prompt_text, prompt_tokens):
+            base_content_sequence.append(
+                [
+                    TextPart(text=t),
+                    VQPart(codes=c),
+                ],
+                add_end=True,
             )
 
-    for idx, text in enumerate(texts):
+    encoded_prompts = base_content_sequence.encode_for_inference(
+        tokenizer, num_codebooks=model.config.num_codebooks
+    )
+    if encoded_prompts.size(1) > max_length - 2048:
+        raise ValueError(
+            f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}"
+        )
+
+    encoded = []
+    for text in texts:
+        content_sequence = ContentSequence(modality=None)
+        content_sequence.append(TextPart(text=text))
         encoded.append(
-            encode_tokens(
-                tokenizer,
-                string=text,
-                device=device,
-                num_codebooks=model.config.num_codebooks,
+            content_sequence.encode_for_inference(
+                tokenizer, num_codebooks=model.config.num_codebooks
             )
         )
         logger.info(f"Encoded text: {text}")
@@ -810,30 +403,28 @@ def generate_long(
         seg = encoded[seg_idx]
         global_encoded.append(seg)
 
-        lengths = reversed([seg.size(1) for seg in global_encoded])
-
-        # Pick last 2000 tokens
-        count = 0
-        for i, length in enumerate(lengths):
-            count += length
-            if count + length > max_length - 1024 - sum(
-                t.shape[1] for t in encoded_prompts
-            ):
-                break
+        # Do not use previous segments to generate current segment for now
+        # lengths = reversed([seg.size(1) for seg in global_encoded])
 
-        if i != 0 and i % 2 == 0:
-            i -= 1
+        # # Pick last 2000 tokens
+        # count = 0
+        # for i, length in enumerate(lengths):
+        #     count += length
+        #     if count + length > max_length - 2048 - encoded_prompts.size(1):
+        #         break
 
-        # Rotate the list, always make sure first segment is included to avoid drift
-        if i < len(global_encoded) - 2:
-            partial_encoded = global_encoded[:2] + global_encoded[-i:]
-        else:
-            partial_encoded = global_encoded
+        # if i != 0 and i % 2 == 0:
+        #     i -= 1
 
-        if use_prompt:
-            partial_encoded = encoded_prompts + partial_encoded
+        # # Rotate the list, always make sure first segment is included to avoid drift
+        # if i < len(global_encoded) - 2:
+        #     partial_encoded = global_encoded[:2] + global_encoded[-i:]
+        # else:
+        #     partial_encoded = global_encoded
 
-        cat_encoded = torch.cat(partial_encoded, dim=1)
+        # cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1)
+        cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
+        cat_encoded = cat_encoded.to(device=device)
         prompt_length = cat_encoded.size(1)
 
         t0 = time.perf_counter()
@@ -871,13 +462,13 @@ def generate_long(
 
         # Put the generated tokens
         # since there is <im_end>, we remove last token
-        codes = y[1:, prompt_length + 1 :].clone()
+        codes = y[1:, prompt_length:-1].clone()
         assert (codes >= 0).all(), f"Negative code found"
 
         decoded = y[:, prompt_length:].clone()
         # But for global encoding, we should keep the <im_end> token
 
-        global_encoded.append(decoded)
+        global_encoded.append(decoded.cpu())
         assert (codes >= 0).all(), f"Negative code found: {codes}"
         yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
         seg_idx += 1
@@ -1012,20 +603,20 @@ def launch_thread_safe_queue_agent(
 )
 @click.option("--num-samples", type=int, default=1)
 @click.option("--max-new-tokens", type=int, default=0)
-@click.option("--top-p", type=float, default=0.7)
-@click.option("--repetition-penalty", type=float, default=1.2)
-@click.option("--temperature", type=float, default=0.7)
+@click.option("--top-p", type=float, default=0.8)
+@click.option("--repetition-penalty", type=float, default=1.1)
+@click.option("--temperature", type=float, default=0.8)
 @click.option(
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
-    default="checkpoints/fish-speech-1.5",
+    default="checkpoints/openaudio-s1-mini",
 )
 @click.option("--device", type=str, default="cuda")
 @click.option("--compile/--no-compile", default=False)
 @click.option("--seed", type=int, default=42)
 @click.option("--half/--no-half", default=False)
 @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
-@click.option("--chunk-length", type=int, default=100)
+@click.option("--chunk-length", type=int, default=300)
 @click.option("--output-dir", type=Path, default="temp")
 def main(
     text: str,
@@ -1070,7 +661,7 @@ def main(
     logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
 
     if prompt_tokens is not None:
-        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
+        prompt_tokens = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
 
     torch.manual_seed(seed)
 