@@ -2,6 +2,7 @@ import os
 import queue
 import threading
 import time
+import traceback
 from contextlib import nullcontext
 from dataclasses import dataclass
 from pathlib import Path
@@ -35,6 +36,7 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
 from torch.nn.attention import SDPBackend, sdpa_kernel

 from fish_speech.models.text2semantic.llama import (
+    BaseTransformer,
     DualARTransformer,
     NaiveTransformer,
 )
@@ -49,19 +51,19 @@ def multinomial_sample_one_no_sync(

 def logits_to_probs(
     logits,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
     previous_tokens: Optional[torch.Tensor] = None,
-    temperature: torch.Tensor = 1.0,
-    top_p: torch.Tensor = 1.0,
-    repetition_penalty: torch.Tensor = 1.0,
 ) -> torch.Tensor:
     # Apply repetition penalty
     if previous_tokens is not None:
         previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
+        score = torch.gather(logits, dim=-1, index=previous_tokens)
         score = torch.where(
             score < 0, score * repetition_penalty, score / repetition_penalty
         )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
+        logits.scatter_(dim=-1, index=previous_tokens, src=score)

     # Apply top-p sampling
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -69,11 +71,10 @@ def logits_to_probs(
     sorted_indices_to_remove = cum_probs > top_p
     sorted_indices_to_remove[0] = False  # keep at least one option
     indices_to_remove = sorted_indices_to_remove.scatter(
-        dim=0, index=sorted_indices, src=sorted_indices_to_remove
+        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
     logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-
-    logits = logits / max(temperature, 1e-5)
+    logits = logits / torch.clip(temperature, min=1e-5)

     probs = torch.nn.functional.softmax(logits, dim=-1)
     return probs
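A standalone sketch (not part of the patch; all values are made up) of what the reworked sampling path does: gather/scatter over `dim=-1` addresses the vocab axis whether the logits are 1-D or batched, and clamping a tensor-valued temperature with `torch.clip` replaces the Python-level `max()` call.

```python
import torch

# Illustrative values only; shapes and numbers are not taken from the patch.
logits = torch.randn(8)                 # logits for one position (vocab=8)
previous_tokens = torch.tensor([1, 3])  # tokens already emitted
temperature = torch.tensor(0.7)
top_p = torch.tensor(0.9)
repetition_penalty = torch.tensor(1.2)

# Repetition penalty: dim=-1 works for 1-D logits as well as batched (B, vocab) logits.
score = torch.gather(logits, dim=-1, index=previous_tokens)
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
logits.scatter_(dim=-1, index=previous_tokens, src=score)

# Top-p mask, then temperature handled as a tensor (torch.clip instead of Python max()).
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False  # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
    dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
probs = torch.softmax(logits / torch.clip(temperature, min=1e-5), dim=-1)
print(probs.sum())  # ~1.0
```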
@@ -81,11 +82,17 @@ def logits_to_probs(

 def sample(
     logits,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
     previous_tokens: Optional[torch.Tensor] = None,
-    **sampling_kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     probs = logits_to_probs(
-        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+        logits=logits[0, -1],
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        previous_tokens=previous_tokens,
     )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
@@ -95,40 +102,35 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
+    audio_masks: torch.Tensor,
+    audio_parts: torch.Tensor,
     previous_tokens: torch.Tensor = None,
-    **sampling_kwargs,
 ) -> torch.Tensor:
-    """
-    Generate one token using dual autoregressive transformer for text-to-speech.
-
-    First generates semantic tokens, then generates acoustic codebook tokens sequentially.
-
-    Args:
-        x: Input token tensor (1, num_codebooks+1, seq_len)
-        input_pos: Position indices for input tokens (seq_len,)
-        temperature/top_p/repetition_penalty: Sampling parameters (1, 1)
-        previous_tokens: Previous tokens for repetition penalty (1, num_codebooks+1, history_seq_len)
-        audio_masks/audio_parts: Audio conditioning tensors (num_codebooks, seq_len)
-
-    Returns:
-        Generated tokens tensor (num_codebooks+1, 1) - one token per codebook
-    """
-    x = model.forward_generate(x, input_pos)
-
-    sampling_kwargs_main = sampling_kwargs.copy()
+    # print(x, torch.count_nonzero(vq_masks))
+    x = model.forward_generate(
+        x,
+        input_pos,
+        audio_masks=audio_masks,
+        audio_parts=audio_parts,
+    )
+    logits = x.logits  # [:, -1:]
+    hidden_states = x.hidden_states  # [:, -1:]

     codebooks = [
         sample(
-            x.logits,
+            logits,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
             previous_tokens=(
-                previous_tokens[0] if previous_tokens is not None else None
-            ),  # Disable repetition penalty for the token codebook
-            **sampling_kwargs_main,
+                previous_tokens[:, 0] if previous_tokens is not None else None
+            ),
         )[0]
     ]

-    hidden_states = x.hidden_states
-
     # Cleanup the cache
     for layer in model.fast_layers:
         layer.attention.kv_cache.k_cache.fill_(0)
@@ -146,22 +148,27 @@ def decode_one_token_ar(
             [codebook_idx], device=hidden_states.device, dtype=torch.long
         )
         logits = model.forward_generate_fast(hidden_states, input_pos)
-        chunked_logits = logits[..., :1024]
+
+        short_logits = logits[:, :, :1024]
+
+        # Convert logits to probs
         a = sample(
-            chunked_logits,
+            short_logits,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
             previous_tokens=(
                 previous_tokens[codebook_idx + 1]
                 if previous_tokens is not None
                 else None
             ),
-            **sampling_kwargs,
         )[0]
+
         hidden_states = model.fast_embeddings(a)
         codebooks.append(a)

-    codebooks = torch.stack(codebooks, dim=0)
-
-    return codebooks
+    codebooks = torch.stack(codebooks, dim=1)
+    return codebooks.T


 def decode_n_tokens(
@@ -169,24 +176,13 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
+    temperature: torch.Tensor,
+    top_p: torch.Tensor,
+    repetition_penalty: torch.Tensor,
+    audio_masks: torch.Tensor,
+    audio_parts: torch.Tensor,
     decode_one_token=decode_one_token_ar,
-    **sampling_kwargs,
 ):
-    """
-    Generate n tokens iteratively using the model.
-
-    Args:
-        model: The transformer model
-        cur_token: Current token tensor of shape (1, num_codebooks+1, seq_len)
-        input_pos: Current input position tensor
-        num_new_tokens: Number of new tokens to generate
-        semantic_ids: List of semantic token IDs
-        decode_one_token: Function to decode one token
-        **sampling_kwargs: Additional sampling parameters
-
-    Returns:
-        Generated tokens tensor of shape (num_codebooks+1, generated_len)
-    """
     previous_tokens = torch.zeros(
         (model.config.num_codebooks + 1, model.config.max_seq_len),
         dtype=torch.int,
@@ -201,13 +197,19 @@ def decode_n_tokens(
         else:
             window = previous_tokens[:, i - win_size : i]

-        with sdpa_kernel(SDPBackend.MATH):
+        with sdpa_kernel(
+            SDPBackend.MATH
+        ):  # Actually better for Inductor to codegen attention here
             next_token = decode_one_token(
                 model=model,
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
-                **sampling_kwargs,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                audio_masks=audio_masks,
+                audio_parts=audio_parts,
             ).clone()

         input_pos += 1
@@ -226,33 +228,31 @@ def decode_n_tokens(
 @torch.inference_mode()
 def generate(
     *,
-    model: NaiveTransformer,
+    model: BaseTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    audio_masks: torch.Tensor,
+    audio_parts: torch.Tensor,
     decode_one_token=decode_one_token_ar,
+    num_samples: int = 1,
     **sampling_kwargs,
-) -> torch.Tensor:
+):
     """
-    Generate tokens from text prompt using the transformer model.
-
-    Args:
-        model: The transformer model for generation
-        prompt: Input token tensor of shape (num_codebooks+1, seq_len)
-        max_new_tokens: Maximum number of new tokens to generate
-        decode_one_token: Function to decode one token at a time
-        **sampling_kwargs: Additional sampling parameters (temperature, top_p, repetition_penalty)
-
-    Returns:
-        Generated sequence tensor of shape (num_codebooks+1, total_seq_len)
-        where total_seq_len = original_seq_len + generated_tokens_len
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """

+    # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
+    prompt = prompt[None].repeat(num_samples, 1, 1)
+
+    if T >= model.config.max_seq_len:
+        raise ValueError(
+            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
+        )

     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
             max_new_tokens = model.config.max_seq_len - T
-            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

         T_new = T + max_new_tokens
     else:
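A quick shape sketch (illustrative numbers, not from the patch) for the new `num_samples` handling: the 2-D prompt gains a leading batch dimension so it lines up with the KV caches that `generate()` now allocates with `max_batch_size=num_samples`.

```python
import torch

# Hypothetical sizes: 8 codebooks + 1 text row, prompt length T=5, num_samples=3.
prompt = torch.zeros(9, 5, dtype=torch.int)  # (num_codebooks + 1, T)
batched = prompt[None].repeat(3, 1, 1)       # (num_samples, num_codebooks + 1, T)
print(batched.shape)                         # torch.Size([3, 9, 5])
```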
@@ -260,23 +260,40 @@ def generate(
     max_new_tokens = T_new - T

     device, dtype = prompt.device, prompt.dtype
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=num_samples,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
+        )

     codebook_dim = 1 + model.config.num_codebooks
+    input_pos = torch.arange(0, T, device=device)
     empty = torch.empty(
         (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
     )
     empty[:, :T] = prompt
     seq = empty
-    input_pos = torch.arange(0, T, device=device)

-    # Use non-accelerated version for now, to avoid compilation overhead
+    temperature = torch.tensor(
+        sampling_kwargs["temperature"], device=device, dtype=torch.bfloat16
+    )
+    top_p = torch.tensor(sampling_kwargs["top_p"], device=device, dtype=torch.bfloat16)
+    repetition_penalty = torch.tensor(
+        sampling_kwargs["repetition_penalty"], device=device, dtype=torch.bfloat16
+    )
+
     prefill_decode = decode_one_token_ar

     first_token = prefill_decode(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
-        **sampling_kwargs,
+        temperature,
+        top_p,
+        repetition_penalty,
+        audio_masks,
+        audio_parts,
     )
     seq[:, T : T + 1] = first_token

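This mirrors the comment elsewhere in the file that moving the sampling knobs to device tensors keeps parameter changes from triggering recompiles. A toy sketch, not the project's API: with a 0-dim tensor argument the compiled function sees the same input signature even as the value changes, whereas a Python scalar baked into the trace can force a retrace.

```python
import torch

# Toy example (not the project's code): pass the temperature as a tensor so the
# compiled function only sees changing tensor contents, not changing constants.
@torch.compile
def scale(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    return logits / torch.clip(temperature, min=1e-5)

logits = torch.randn(4)
for value in (0.2, 0.7, 1.3):  # user changes the sampling temperature
    print(scale(logits, torch.tensor(value)))
```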
@@ -286,12 +303,15 @@ def generate(
         first_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        audio_masks=audio_masks,
+        audio_parts=audio_parts,
         decode_one_token=decode_one_token,
-        **sampling_kwargs,
     )
     seq = seq[:, : T + 1 + x.size(1)]
     seq[:, T + 1 :] = x
-
     return seq


@@ -303,17 +323,26 @@ def init_model(checkpoint_path, device, precision, compile=False):

     if isinstance(model, DualARTransformer):
         decode_one_token = decode_one_token_ar
+        prefill_n_tokens = decode_one_token_ar
         logger.info("Using DualARTransformer")
     else:
-        raise ValueError("Model is not a DualARTransformer")
+        raise ValueError("Unsupported model type")
+
+    # Initialize cache
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
+        )

     if compile:
         logger.info("Compiling function...")
         decode_one_token = torch.compile(
             decode_one_token,
+            # mode="max-autotune-no-cudagraphs",
+            mode="reduce-overhead",
             fullgraph=True,
-            backend="inductor" if torch.cuda.is_available() else "aot_eager",
-            mode="reduce-overhead" if torch.cuda.is_available() else None,
         )

     return model.eval(), decode_one_token
@@ -362,9 +391,7 @@ def generate_long(
     tokenizer = model.tokenizer
     base_content_sequence = ContentSequence(modality="interleave")

-    texts = split_text(text, chunk_length) if iterative_prompt else [text]
     max_length = model.config.max_seq_len
-
     if use_prompt:
         for t, c in zip(prompt_text, prompt_tokens):
             base_content_sequence.append(
@@ -373,26 +400,24 @@ def generate_long(
                     VQPart(codes=c),
                 ],
                 add_end=True,
+                speaker=0,
             )
+    base_content_sequence.append(
+        [
+            TextPart(text=text),
+        ],
+        add_end=False,
+        speaker=0,
+    )

-    encoded_prompts = base_content_sequence.encode_for_inference(
+    encoded, audio_masks, audio_parts = base_content_sequence.encode_for_inference(
         tokenizer, num_codebooks=model.config.num_codebooks
     )
-    if encoded_prompts.size(1) > max_length - 2048:
-        raise ValueError(
-            f"Prompt is too long: {encoded_prompts.size(1)} > {max_length - 2048}"
-        )
+    if encoded.size(1) > max_length - 2048:
+        raise ValueError(f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}")

-    encoded = []
-    for text in texts:
-        content_sequence = ContentSequence(modality="text")
-        content_sequence.append(TextPart(text=text))
-        encoded.append(
-            content_sequence.encode_for_inference(
-                tokenizer, num_codebooks=model.config.num_codebooks
-            )
-        )
-        logger.info(f"Encoded text: {text}")
+    encoded = encoded.to(device=device)
+    logger.info(f"Encoded text: {text}")

     # Move temperature, top_p, repetition_penalty to device
     # This is important so that changing params doesn't trigger recompile
@@ -408,70 +433,53 @@ def generate_long(

         global_encoded = []
         seg_idx = 0
+        prompt_length = encoded.size(1)
+
+        t0 = time.perf_counter()
+        y = generate(
+            model=model,
+            prompt=encoded,
+            max_new_tokens=max_new_tokens,
+            audio_masks=audio_masks,
+            audio_parts=audio_parts,
+            decode_one_token=decode_one_token,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+        )

-        while seg_idx < len(encoded):
-            logger.info(
-                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
-            )
-
-            seg = encoded[seg_idx]
-            global_encoded.append(seg)
-
-            if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
-                cat_encoded = torch.cat(
-                    [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
-                )
-            else:
-                cat_encoded = torch.cat([encoded_prompts, seg], dim=1)
-
-            cat_encoded = cat_encoded.to(device=device)
-            prompt_length = cat_encoded.size(1)
-
-            t0 = time.perf_counter()
-            y = generate(
-                model=model,
-                prompt=cat_encoded,
-                max_new_tokens=max_new_tokens,
-                decode_one_token=decode_one_token,
-                temperature=temperature,
-                top_p=top_p,
-                repetition_penalty=repetition_penalty,
-            )
+        if sample_idx == 0 and seg_idx == 0 and compile:
+            logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

-            if sample_idx == 0 and seg_idx == 0 and compile:
-                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()

-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
+        t = time.perf_counter() - t0

-            t = time.perf_counter() - t0
+        tokens_generated = y.size(1) - prompt_length
+        tokens_sec = tokens_generated / t
+        logger.info(
+            f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+        )
+        logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")

-            tokens_generated = y.size(1) - prompt_length
-            tokens_sec = tokens_generated / t
-            logger.info(
-                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
-            )
+        if torch.cuda.is_available():
             logger.info(
-                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
+                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
             )

-            if torch.cuda.is_available():
-                logger.info(
-                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
-                )
-
-            # Put the generated tokens
-            # since there is <im_end>, we remove last token
-            codes = y[1:, prompt_length:-1].clone()
-            assert (codes >= 0).all(), f"Negative code found"
+        # Put the generated tokens
+        # since there is <im_end>, we remove last token
+        codes = y[1:, prompt_length:-1].clone()
+        assert (codes >= 0).all(), f"Negative code found"

-            decoded = y[:, prompt_length:].clone()
-            # But for global encoding, we should keep the <im_end> token
+        decoded = y[:, prompt_length:].clone()
+        # But for global encoding, we should keep the <im_end> token

-            global_encoded.append(decoded.cpu())
-            assert (codes >= 0).all(), f"Negative code found: {codes}"
-            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
-            seg_idx += 1
+        global_encoded.append(decoded.cpu())
+        assert (codes >= 0).all(), f"Negative code found: {codes}"
+        yield GenerateResponse(action="sample", codes=codes, text=text)
+        seg_idx += 1

         # This indicates the end of the current sample
         yield GenerateResponse(action="next")
@@ -526,6 +534,7 @@ def launch_thread_safe_queue(
                     WrappedGenerateResponse(status="success", response=chunk)
                 )
             except Exception as e:
+                logger.error(traceback.format_exc())
                 response_queue.put(WrappedGenerateResponse(status="error", response=e))

     threading.Thread(target=worker, daemon=True).start()
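A minimal sketch of the worker-queue pattern this last hunk touches (generic names, not the project's classes): logging `traceback.format_exc()` inside the worker records the full stack trace at the point of failure, since the consumer on the other side of the queue typically only surfaces the exception object itself.

```python
import queue
import threading
import traceback

def worker(jobs: queue.Queue, results: queue.Queue) -> None:
    while True:
        item = jobs.get()
        if item is None:
            break
        try:
            results.put(("success", 1 / item))  # stand-in for the real generation call
        except Exception as e:
            print(traceback.format_exc())       # stand-in for logger.error(...)
            results.put(("error", e))

jobs, results = queue.Queue(), queue.Queue()
threading.Thread(target=worker, args=(jobs, results), daemon=True).start()
jobs.put(0)           # triggers ZeroDivisionError inside the worker
print(results.get())  # ('error', ZeroDivisionError(...)) plus the logged traceback
jobs.put(None)        # shut the worker down
```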