|
@@ -1,8 +1,10 @@
|
|
|
import os
|
|
import os
|
|
|
import queue
|
|
import queue
|
|
|
|
|
+import re
|
|
|
import threading
|
|
import threading
|
|
|
import time
|
|
import time
|
|
|
import traceback
|
|
import traceback
|
|
|
|
|
+from copy import deepcopy
|
|
|
from dataclasses import dataclass
|
|
from dataclasses import dataclass
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from typing import Callable, Literal, Optional, Tuple, Union
|
|
from typing import Callable, Literal, Optional, Tuple, Union
|
|
@@ -13,13 +15,12 @@ import torch
|
|
|
import torch._inductor.config
|
|
import torch._inductor.config
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
|
-from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
|
|
from fish_speech.content_sequence import (
|
|
from fish_speech.content_sequence import (
|
|
|
- ContentSequence,
|
|
|
|
|
TextPart,
|
|
TextPart,
|
|
|
VQPart,
|
|
VQPart,
|
|
|
)
|
|
)
|
|
|
|
|
+from fish_speech.conversation import Conversation, Message
|
|
|
from fish_speech.tokenizer import IM_END_TOKEN
|
|
from fish_speech.tokenizer import IM_END_TOKEN
|
|
|
|
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
@@ -27,7 +28,6 @@ torch._inductor.config.coordinate_descent_tuning = True
|
|
|
torch._inductor.config.triton.unique_kernel_names = True
|
|
torch._inductor.config.triton.unique_kernel_names = True
|
|
|
|
|
|
|
|
if hasattr(torch._inductor.config, "fx_graph_cache"):
|
|
if hasattr(torch._inductor.config, "fx_graph_cache"):
|
|
|
- # Experimental feature to reduce compilation times, will be on by default in future
|
|
|
|
|
torch._inductor.config.fx_graph_cache = True
|
|
torch._inductor.config.fx_graph_cache = True
|
|
|
|
|
|
|
|
|
|
|
|
@@ -47,26 +47,23 @@ def multinomial_sample_one_no_sync(
|
|
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+RAS_WIN_SIZE = 10 # window for Repetition Aware Sampling
|
|
|
|
|
+RAS_HIGH_TEMP = 1.0
|
|
|
|
|
+RAS_HIGH_TOP_P = 0.9
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def logits_to_probs(
|
|
def logits_to_probs(
|
|
|
logits,
|
|
logits,
|
|
|
temperature: torch.Tensor,
|
|
temperature: torch.Tensor,
|
|
|
top_p: torch.Tensor,
|
|
top_p: torch.Tensor,
|
|
|
- repetition_penalty: torch.Tensor,
|
|
|
|
|
- previous_tokens: Optional[torch.Tensor] = None,
|
|
|
|
|
|
|
+ top_k: torch.Tensor,
|
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
|
- # Apply repetition penalty
|
|
|
|
|
- if previous_tokens is not None:
|
|
|
|
|
- previous_tokens = previous_tokens.long()
|
|
|
|
|
- score = torch.gather(logits, dim=-1, index=previous_tokens)
|
|
|
|
|
- score = torch.where(
|
|
|
|
|
- score < 0, score * repetition_penalty, score / repetition_penalty
|
|
|
|
|
- )
|
|
|
|
|
- logits.scatter_(dim=-1, index=previous_tokens, src=score)
|
|
|
|
|
-
|
|
|
|
|
- # Apply top-p sampling
|
|
|
|
|
|
|
+ # Sort and compute top-p mask
|
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
|
|
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
|
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
|
|
sorted_indices_to_remove = cum_probs > top_p
|
|
sorted_indices_to_remove = cum_probs > top_p
|
|
|
|
|
+ # top-k mask
|
|
|
|
|
+ sorted_indices_to_remove[top_k:] = True
|
|
|
sorted_indices_to_remove[0] = False # keep at least one option
|
|
sorted_indices_to_remove[0] = False # keep at least one option
|
|
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
|
|
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
|
|
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
|
|
@@ -82,15 +79,13 @@ def sample(
|
|
|
logits,
|
|
logits,
|
|
|
temperature: torch.Tensor,
|
|
temperature: torch.Tensor,
|
|
|
top_p: torch.Tensor,
|
|
top_p: torch.Tensor,
|
|
|
- repetition_penalty: torch.Tensor,
|
|
|
|
|
- previous_tokens: Optional[torch.Tensor] = None,
|
|
|
|
|
|
|
+ top_k: int,
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
probs = logits_to_probs(
|
|
probs = logits_to_probs(
|
|
|
logits=logits[0, -1],
|
|
logits=logits[0, -1],
|
|
|
temperature=temperature,
|
|
temperature=temperature,
|
|
|
top_p=top_p,
|
|
top_p=top_p,
|
|
|
- repetition_penalty=repetition_penalty,
|
|
|
|
|
- previous_tokens=previous_tokens,
|
|
|
|
|
|
|
+ top_k=top_k,
|
|
|
)
|
|
)
|
|
|
idx_next = multinomial_sample_one_no_sync(probs)
|
|
idx_next = multinomial_sample_one_no_sync(probs)
|
|
|
return idx_next, probs
|
|
return idx_next, probs
|
|
@@ -102,32 +97,44 @@ def decode_one_token_ar(
|
|
|
input_pos: torch.Tensor,
|
|
input_pos: torch.Tensor,
|
|
|
temperature: torch.Tensor,
|
|
temperature: torch.Tensor,
|
|
|
top_p: torch.Tensor,
|
|
top_p: torch.Tensor,
|
|
|
- repetition_penalty: torch.Tensor,
|
|
|
|
|
|
|
+ top_k: int,
|
|
|
|
|
+ semantic_logit_bias: torch.Tensor,
|
|
|
audio_masks: torch.Tensor,
|
|
audio_masks: torch.Tensor,
|
|
|
audio_parts: torch.Tensor,
|
|
audio_parts: torch.Tensor,
|
|
|
previous_tokens: Optional[torch.Tensor] = None,
|
|
previous_tokens: Optional[torch.Tensor] = None,
|
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
|
- # print(x, torch.count_nonzero(vq_masks))
|
|
|
|
|
forward_result = model.forward_generate(
|
|
forward_result = model.forward_generate(
|
|
|
x,
|
|
x,
|
|
|
input_pos,
|
|
input_pos,
|
|
|
audio_masks=audio_masks,
|
|
audio_masks=audio_masks,
|
|
|
audio_parts=audio_parts,
|
|
audio_parts=audio_parts,
|
|
|
)
|
|
)
|
|
|
- logits = forward_result.logits # [:, -1:]
|
|
|
|
|
- hidden_states = forward_result.hidden_states # [:, -1:]
|
|
|
|
|
|
|
+ logits = forward_result.logits # (1, 1, vocab_size)
|
|
|
|
|
+ hidden_states = forward_result.hidden_states
|
|
|
|
|
|
|
|
- codebooks = [
|
|
|
|
|
- sample(
|
|
|
|
|
- logits,
|
|
|
|
|
- temperature=temperature,
|
|
|
|
|
- top_p=top_p,
|
|
|
|
|
- repetition_penalty=repetition_penalty,
|
|
|
|
|
- previous_tokens=(
|
|
|
|
|
- previous_tokens[:, 0] if previous_tokens is not None else None
|
|
|
|
|
- ),
|
|
|
|
|
- )[0]
|
|
|
|
|
- ]
|
|
|
|
|
|
|
+ # Apply constrained decoding: only allow semantic tokens + im_end
|
|
|
|
|
+ biased_logits = logits + semantic_logit_bias
|
|
|
|
|
+
|
|
|
|
|
+ # Normal sample
|
|
|
|
|
+ main_token_normal = sample(biased_logits, temperature=temperature, top_p=top_p, top_k=top_k)[0]
|
|
|
|
|
+
|
|
|
|
|
+ # RAS: also sample with high temp to use as fallback if token repeats
|
|
|
|
|
+ high_temp = torch.tensor(RAS_HIGH_TEMP, device=temperature.device, dtype=temperature.dtype)
|
|
|
|
|
+ high_top_p = torch.tensor(RAS_HIGH_TOP_P, device=top_p.device, dtype=top_p.dtype)
|
|
|
|
|
+ main_token_high = sample(biased_logits, temperature=high_temp, top_p=high_top_p, top_k=top_k)[0]
|
|
|
|
|
+
|
|
|
|
|
+ # Use high-temp sample if: token is semantic AND token is in previous window
|
|
|
|
|
+ if previous_tokens is not None:
|
|
|
|
|
+ in_window = (previous_tokens[0] == main_token_normal).any()
|
|
|
|
|
+ # Use tensor ops (&, torch.where) instead of Python (and, if) — torch.compile requires no data-dependent branching
|
|
|
|
|
+ is_semantic = (
|
|
|
|
|
+ (main_token_normal >= model.config.semantic_begin_id)
|
|
|
|
|
+ & (main_token_normal <= model.config.semantic_end_id)
|
|
|
|
|
+ )
|
|
|
|
|
+ should_use_high = in_window & is_semantic
|
|
|
|
|
+ main_token_normal = torch.where(should_use_high, main_token_high, main_token_normal)
|
|
|
|
|
+
|
|
|
|
|
+ codebooks = [main_token_normal]
|
|
|
|
|
|
|
|
# Only clear cache for fast_layers, avoid clearing main model cache
|
|
# Only clear cache for fast_layers, avoid clearing main model cache
|
|
|
for layer in model.fast_layers:
|
|
for layer in model.fast_layers:
|
|
@@ -137,8 +144,11 @@ def decode_one_token_ar(
|
|
|
|
|
|
|
|
input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
|
|
input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
|
|
|
model.forward_generate_fast(hidden_states, input_pos)
|
|
model.forward_generate_fast(hidden_states, input_pos)
|
|
|
- a = codebooks[0] - model.tokenizer.semantic_begin_id
|
|
|
|
|
|
|
+
|
|
|
|
|
+ # [MODIFIED] Access config instead of tokenizer
|
|
|
|
|
+ a = codebooks[0] - model.config.semantic_begin_id
|
|
|
a[a < 0] = 0
|
|
a[a < 0] = 0
|
|
|
|
|
+ a[a >= model.config.codebook_size] = 0
|
|
|
hidden_states = model.fast_embeddings(a)
|
|
hidden_states = model.fast_embeddings(a)
|
|
|
codebooks.append(a)
|
|
codebooks.append(a)
|
|
|
|
|
|
|
@@ -148,19 +158,14 @@ def decode_one_token_ar(
|
|
|
)
|
|
)
|
|
|
logits = model.forward_generate_fast(hidden_states, input_pos)
|
|
logits = model.forward_generate_fast(hidden_states, input_pos)
|
|
|
|
|
|
|
|
- short_logits = logits[:, :, :1024]
|
|
|
|
|
|
|
+ short_logits = logits # DualAR predicts config.codebook_size number of tokens
|
|
|
|
|
|
|
|
- # Convert logits to probs
|
|
|
|
|
|
|
+ # Convert logits to probs (no constrain for fast codebooks)
|
|
|
a = sample(
|
|
a = sample(
|
|
|
short_logits,
|
|
short_logits,
|
|
|
temperature=temperature,
|
|
temperature=temperature,
|
|
|
top_p=top_p,
|
|
top_p=top_p,
|
|
|
- repetition_penalty=repetition_penalty,
|
|
|
|
|
- previous_tokens=(
|
|
|
|
|
- previous_tokens[codebook_idx + 1]
|
|
|
|
|
- if previous_tokens is not None
|
|
|
|
|
- else None
|
|
|
|
|
- ),
|
|
|
|
|
|
|
+ top_k=top_k,
|
|
|
)[0]
|
|
)[0]
|
|
|
|
|
|
|
|
hidden_states = model.fast_embeddings(a)
|
|
hidden_states = model.fast_embeddings(a)
|
|
@@ -181,53 +186,52 @@ def decode_n_tokens(
|
|
|
num_new_tokens: int,
|
|
num_new_tokens: int,
|
|
|
temperature: torch.Tensor,
|
|
temperature: torch.Tensor,
|
|
|
top_p: torch.Tensor,
|
|
top_p: torch.Tensor,
|
|
|
- repetition_penalty: torch.Tensor,
|
|
|
|
|
|
|
+ top_k: int,
|
|
|
|
|
+ semantic_logit_bias: torch.Tensor,
|
|
|
audio_masks: torch.Tensor,
|
|
audio_masks: torch.Tensor,
|
|
|
audio_parts: torch.Tensor,
|
|
audio_parts: torch.Tensor,
|
|
|
decode_one_token=decode_one_token_ar,
|
|
decode_one_token=decode_one_token_ar,
|
|
|
):
|
|
):
|
|
|
|
|
+ # Rolling window for RAS (Repetition Aware Sampling)
|
|
|
previous_tokens = torch.zeros(
|
|
previous_tokens = torch.zeros(
|
|
|
- (model.config.num_codebooks + 1, model.config.max_seq_len),
|
|
|
|
|
|
|
+ (model.config.num_codebooks + 1, RAS_WIN_SIZE),
|
|
|
dtype=torch.int,
|
|
dtype=torch.int,
|
|
|
device=cur_token.device,
|
|
device=cur_token.device,
|
|
|
)
|
|
)
|
|
|
|
|
+ # Accumulate all generated tokens (the actual output)
|
|
|
|
|
+ new_tokens = []
|
|
|
|
|
+
|
|
|
|
|
+ # [MODIFIED] Pre-fetch ID for efficiency loop
|
|
|
|
|
+ im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
|
|
|
|
|
|
|
|
for i in tqdm(range(num_new_tokens)):
|
|
for i in tqdm(range(num_new_tokens)):
|
|
|
- # We need to get windowed repeat penalty
|
|
|
|
|
- win_size = 16
|
|
|
|
|
- if i < win_size:
|
|
|
|
|
- window = previous_tokens[:, :win_size]
|
|
|
|
|
- else:
|
|
|
|
|
- window = previous_tokens[:, i - win_size : i]
|
|
|
|
|
-
|
|
|
|
|
- with sdpa_kernel(
|
|
|
|
|
- SDPBackend.MATH
|
|
|
|
|
- ): # Actually better for Inductor to codegen attention here
|
|
|
|
|
|
|
+ with sdpa_kernel(SDPBackend.MATH):
|
|
|
next_token = decode_one_token(
|
|
next_token = decode_one_token(
|
|
|
model=model,
|
|
model=model,
|
|
|
x=cur_token,
|
|
x=cur_token,
|
|
|
input_pos=input_pos,
|
|
input_pos=input_pos,
|
|
|
- previous_tokens=window,
|
|
|
|
|
|
|
+ previous_tokens=previous_tokens,
|
|
|
temperature=temperature,
|
|
temperature=temperature,
|
|
|
top_p=top_p,
|
|
top_p=top_p,
|
|
|
- repetition_penalty=repetition_penalty,
|
|
|
|
|
|
|
+ top_k=top_k,
|
|
|
|
|
+ semantic_logit_bias=semantic_logit_bias,
|
|
|
audio_masks=audio_masks,
|
|
audio_masks=audio_masks,
|
|
|
audio_parts=audio_parts,
|
|
audio_parts=audio_parts,
|
|
|
).clone()
|
|
).clone()
|
|
|
|
|
|
|
|
input_pos += 1
|
|
input_pos += 1
|
|
|
cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
|
|
cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
|
|
|
- previous_tokens[:, i : i + 1] = next_token.view(
|
|
|
|
|
- model.config.num_codebooks + 1, -1
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # Roll RAS window left and insert new token at end
|
|
|
|
|
+ previous_tokens = previous_tokens.roll(-1, dims=1)
|
|
|
|
|
+ previous_tokens[:, -1] = next_token.view(model.config.num_codebooks + 1, -1)[:, 0]
|
|
|
|
|
+ new_tokens.append(next_token)
|
|
|
|
|
|
|
|
- if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
|
|
|
|
|
|
|
+ if cur_token[0, 0, -1] == im_end_id:
|
|
|
break
|
|
break
|
|
|
|
|
|
|
|
- # Only clean up the large tensor
|
|
|
|
|
del cur_token
|
|
del cur_token
|
|
|
|
|
|
|
|
- return previous_tokens[:, : i + 1]
|
|
|
|
|
|
|
+ return torch.cat(new_tokens, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
@torch.no_grad()
|
|
@@ -265,7 +269,8 @@ def generate(
|
|
|
T_new = model.config.max_seq_len
|
|
T_new = model.config.max_seq_len
|
|
|
max_new_tokens = T_new - T
|
|
max_new_tokens = T_new - T
|
|
|
|
|
|
|
|
- device, dtype = prompt.device, prompt.dtype
|
|
|
|
|
|
|
+ device = prompt.device
|
|
|
|
|
+ dtype = next(model.parameters()).dtype # model weight dtype (bfloat16), NOT prompt dtype (int32)
|
|
|
|
|
|
|
|
# Critical fix: Only set up cache on first run or when necessary
|
|
# Critical fix: Only set up cache on first run or when necessary
|
|
|
if not hasattr(model, "_cache_setup_done") or not model._cache_setup_done:
|
|
if not hasattr(model, "_cache_setup_done") or not model._cache_setup_done:
|
|
@@ -282,35 +287,31 @@ def generate(
|
|
|
# Create new tensor each time, but try to reuse memory
|
|
# Create new tensor each time, but try to reuse memory
|
|
|
input_pos = torch.arange(0, T, device=device, dtype=torch.long)
|
|
input_pos = torch.arange(0, T, device=device, dtype=torch.long)
|
|
|
empty = torch.empty(
|
|
empty = torch.empty(
|
|
|
- (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
|
|
|
|
|
|
|
+ (codebook_dim, model.config.max_seq_len), dtype=prompt.dtype, device=device
|
|
|
)
|
|
)
|
|
|
empty[:, :T] = prompt
|
|
empty[:, :T] = prompt
|
|
|
seq = empty
|
|
seq = empty
|
|
|
|
|
|
|
|
- # Use pre-created fixed parameter tensors
|
|
|
|
|
- temperature = getattr(
|
|
|
|
|
- model, "fixed_temperature", torch.tensor(0.8, device=device, dtype=torch.float)
|
|
|
|
|
- )
|
|
|
|
|
- top_p = getattr(
|
|
|
|
|
- model, "fixed_top_p", torch.tensor(0.8, device=device, dtype=torch.float)
|
|
|
|
|
- )
|
|
|
|
|
- repetition_penalty = getattr(
|
|
|
|
|
- model,
|
|
|
|
|
- "fixed_repetition_penalty",
|
|
|
|
|
- torch.tensor(1.1, device=device, dtype=torch.float),
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ temp_val = sampling_kwargs.get("temperature", 1.0)
|
|
|
|
|
+ top_p_val = sampling_kwargs.get("top_p", 0.9)
|
|
|
|
|
+ top_k_val = sampling_kwargs.get("top_k", 30)
|
|
|
|
|
|
|
|
- # If different parameter values are needed, directly modify existing tensors
|
|
|
|
|
- temp_val = sampling_kwargs.get("temperature", 0.7)
|
|
|
|
|
- top_p_val = sampling_kwargs.get("top_p", 0.7)
|
|
|
|
|
- rep_val = sampling_kwargs.get("repetition_penalty", 1.5)
|
|
|
|
|
|
|
+ temperature = torch.tensor(temp_val, device=device, dtype=dtype)
|
|
|
|
|
+ top_p = torch.tensor(top_p_val, device=device, dtype=dtype)
|
|
|
|
|
|
|
|
- if abs(temperature.item() - temp_val) > 1e-6:
|
|
|
|
|
- temperature.fill_(temp_val)
|
|
|
|
|
- if abs(top_p.item() - top_p_val) > 1e-6:
|
|
|
|
|
- top_p.fill_(top_p_val)
|
|
|
|
|
- if abs(repetition_penalty.item() - rep_val) > 1e-6:
|
|
|
|
|
- repetition_penalty.fill_(rep_val)
|
|
|
|
|
|
|
+ # Build semantic logit bias: 0 for semantic tokens + im_end, -inf for all others
|
|
|
|
|
+ vocab_size = model.config.vocab_size
|
|
|
|
|
+ semantic_logit_bias = torch.full(
|
|
|
|
|
+ (1, 1, vocab_size), float("-inf"), device=device, dtype=dtype
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # [MODIFIED] Use config for semantic range
|
|
|
|
|
+ semantic_logit_bias[
|
|
|
|
|
+ 0, 0, model.config.semantic_begin_id : model.config.semantic_end_id + 1
|
|
|
|
|
+ ] = 0.0
|
|
|
|
|
+
|
|
|
|
|
+ # [MODIFIED] Use tokenizer.get_token_id (Wrapper method)
|
|
|
|
|
+ semantic_logit_bias[0, 0, model.tokenizer.get_token_id(IM_END_TOKEN)] = 0.0
|
|
|
|
|
|
|
|
prefill_decode = decode_one_token_ar
|
|
prefill_decode = decode_one_token_ar
|
|
|
|
|
|
|
@@ -320,7 +321,8 @@ def generate(
|
|
|
input_pos,
|
|
input_pos,
|
|
|
temperature,
|
|
temperature,
|
|
|
top_p,
|
|
top_p,
|
|
|
- repetition_penalty,
|
|
|
|
|
|
|
+ top_k_val,
|
|
|
|
|
+ semantic_logit_bias,
|
|
|
audio_masks,
|
|
audio_masks,
|
|
|
audio_parts,
|
|
audio_parts,
|
|
|
)
|
|
)
|
|
@@ -336,7 +338,8 @@ def generate(
|
|
|
max_new_tokens - 1,
|
|
max_new_tokens - 1,
|
|
|
temperature=temperature,
|
|
temperature=temperature,
|
|
|
top_p=top_p,
|
|
top_p=top_p,
|
|
|
- repetition_penalty=repetition_penalty,
|
|
|
|
|
|
|
+ top_k=top_k_val,
|
|
|
|
|
+ semantic_logit_bias=semantic_logit_bias,
|
|
|
audio_masks=audio_masks,
|
|
audio_masks=audio_masks,
|
|
|
audio_parts=audio_parts,
|
|
audio_parts=audio_parts,
|
|
|
decode_one_token=decode_one_token,
|
|
decode_one_token=decode_one_token,
|
|
@@ -358,7 +361,7 @@ def init_model(checkpoint_path, device, precision, compile=False):
|
|
|
|
|
|
|
|
if isinstance(model, DualARTransformer):
|
|
if isinstance(model, DualARTransformer):
|
|
|
decode_one_token = decode_one_token_ar
|
|
decode_one_token = decode_one_token_ar
|
|
|
- prefill_n_tokens = decode_one_token_ar
|
|
|
|
|
|
|
+ # prefill_n_tokens = decode_one_token_ar
|
|
|
logger.info("Using DualARTransformer")
|
|
logger.info("Using DualARTransformer")
|
|
|
else:
|
|
else:
|
|
|
raise ValueError("Unsupported model type")
|
|
raise ValueError("Unsupported model type")
|
|
@@ -383,6 +386,60 @@ def init_model(checkpoint_path, device, precision, compile=False):
|
|
|
return model.eval(), decode_one_token
|
|
return model.eval(), decode_one_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
def load_codec_model(codec_checkpoint_path, device, precision=torch.bfloat16):
    """Load the DAC codec model for audio encoding/decoding.

    Args:
        codec_checkpoint_path: Checkpoint location, handed to ``torch.load``.
        device: Device the codec is moved to after loading.
        precision: dtype the codec is cast to (defaults to ``torch.bfloat16``).

    Returns:
        The instantiated codec, in eval mode, on ``device`` with dtype
        ``precision``.
    """
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # The codec architecture comes from a YAML config resolved relative to
    # this file (``parent.parent.parent / "configs"``).
    config_path = Path(__file__).parent.parent.parent / "configs" / "modded_dac_vq.yaml"
    codec = instantiate(OmegaConf.load(str(config_path)))

    # NOTE(review): torch.load unpickles the file; depending on the installed
    # torch version this may execute arbitrary code, so only load checkpoints
    # from a trusted source.
    state_dict = torch.load(codec_checkpoint_path, map_location="cpu")

    # Some checkpoints nest the weights under a "state_dict" entry.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    # When generator weights are present, keep only those entries and drop the
    # "generator." component from their names.
    if any("generator" in k for k in state_dict):
        state_dict = {
            k.replace("generator.", ""): v
            for k, v in state_dict.items()
            if "generator." in k
        }

    # strict=False keeps loading best-effort, but a mismatch is now reported
    # instead of being silently ignored.
    load_result = codec.load_state_dict(state_dict, strict=False)
    missing_keys = list(getattr(load_result, "missing_keys", None) or [])
    unexpected_keys = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing_keys:
        logger.warning(
            f"Codec checkpoint is missing {len(missing_keys)} keys "
            f"(first few: {missing_keys[:5]})"
        )
    if unexpected_keys:
        logger.warning(
            f"Codec checkpoint has {len(unexpected_keys)} unexpected keys "
            f"(first few: {unexpected_keys[:5]})"
        )

    codec.eval()
    codec.to(device=device, dtype=precision)
    return codec
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
@torch.inference_mode()
def encode_audio(audio_path, codec, device):
    """Encode an audio file to VQ codes.

    Args:
        audio_path: Location of the audio file; converted to ``str`` for
            ``torchaudio.load``.
        codec: Codec exposing ``sample_rate``, ``parameters()`` and
            ``encode(audios, audio_lengths)``.
        device: Device on which resampling is performed and the length tensor
            is created.

    Returns:
        The codec indices of the first batch item, trimmed to the first
        reported feature length: ``(num_codebooks, T)``.
    """
    import torchaudio

    waveform, source_rate = torchaudio.load(str(audio_path))

    # Average the channels when the file has more than one.
    channel_count = waveform.shape[0]
    if channel_count > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample to the codec's rate, then keep the first (only) channel.
    resampled = torchaudio.functional.resample(
        waveform.to(device), source_rate, codec.sample_rate
    )
    mono = resampled[0]

    # Match codec model dtype (e.g. bfloat16)
    codec_dtype = next(codec.parameters()).dtype
    batch = mono[None, None].to(dtype=codec_dtype)  # (1, 1, T)
    lengths = torch.tensor([len(mono)], device=device, dtype=torch.long)

    indices, feature_lengths = codec.encode(batch, lengths)
    valid_frames = feature_lengths[0]
    return indices[0, :, :valid_frames]  # (num_codebooks, T)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
@torch.inference_mode()
def decode_to_audio(codes, codec):
    """Decode VQ codes to an audio waveform.

    Args:
        codes: VQ codes shaped ``(num_codebooks, T)``.
        codec: Codec exposing ``from_indices``.

    Returns:
        Element ``[0, 0]`` of the codec output, i.e. a ``(T,)`` mono waveform.
    """
    # Add a leading batch axis: (num_codebooks, T) -> (1, num_codebooks, T)
    batched_codes = codes[None]
    waveform = codec.from_indices(batched_codes)
    # Take the first batch item and first channel.
    return waveform[0, 0]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
@dataclass
|
|
@dataclass
|
|
|
class GenerateResponse:
|
|
class GenerateResponse:
|
|
|
action: Literal["sample", "next"]
|
|
action: Literal["sample", "next"]
|
|
@@ -390,6 +447,75 @@ class GenerateResponse:
|
|
|
text: Optional[str] = None
|
|
text: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_text_by_speaker(text: str) -> list[str]:
    """
    Split text into turns based on <|speaker:X|> tags.

    Each turn is a speaker tag joined with the text that follows it (up to the
    next tag), with surrounding whitespace stripped. Text that appears before
    the first speaker tag is not part of any turn and is dropped; text with no
    tags at all yields an empty list.

    Args:
        text: The full text with speaker tags

    Returns:
        List of speaker turns, each starting with <|speaker:X|>
    """
    tag_pattern = r"(<\|speaker:\d+\|>)"

    # With a single capture group, re.split always returns an odd-length list:
    # [preamble, tag, body, tag, body, ...]. Every tag therefore has a body
    # (possibly empty), so tags and bodies can be paired directly.
    pieces = re.split(tag_pattern, text)
    tags = pieces[1::2]
    bodies = pieces[2::2]

    return [(tag + body).strip() for tag, body in zip(tags, bodies)]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def group_turns_into_batches(
    turns: list[str], max_speakers: int = 3, max_bytes: int = 300
) -> list[str]:
    """
    Group turns into batches based on speaker count or byte limit.

    Turns are packed greedily, in order, and joined with a newline. A batch is
    closed before adding a turn when it already holds ``max_speakers`` turns or
    when the joined text (including the newline separators) would exceed
    ``max_bytes``. A batch always holds at least one turn, so a single turn
    larger than ``max_bytes`` becomes a batch on its own and no empty batch is
    ever produced.

    Args:
        turns: List of speaker turns
        max_speakers: Maximum number of turns per batch (default 3)
        max_bytes: Maximum UTF-8 bytes per batch (default 300)

    Returns:
        List of batched text strings
    """
    separator = "\n"
    separator_bytes = len(separator.encode("utf-8"))

    batches: list[str] = []
    current_batch: list[str] = []
    current_bytes = 0

    for turn in turns:
        turn_bytes = len(turn.encode("utf-8"))

        if not current_batch:
            # First turn of a batch is always accepted, whatever its size.
            current_batch = [turn]
            current_bytes = turn_bytes
            continue

        # Size of the joined batch if this turn were appended.
        projected_bytes = current_bytes + separator_bytes + turn_bytes
        batch_is_full = len(current_batch) >= max_speakers
        too_many_bytes = projected_bytes > max_bytes

        if batch_is_full or too_many_bytes:
            batches.append(separator.join(current_batch))
            current_batch = [turn]
            current_bytes = turn_bytes
        else:
            current_batch.append(turn)
            current_bytes = projected_bytes

    if current_batch:
        batches.append(separator.join(current_batch))

    return batches
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def generate_long(
|
|
def generate_long(
|
|
|
*,
|
|
*,
|
|
|
model,
|
|
model,
|
|
@@ -398,9 +524,10 @@ def generate_long(
|
|
|
text: str,
|
|
text: str,
|
|
|
num_samples: int = 1,
|
|
num_samples: int = 1,
|
|
|
max_new_tokens: int = 0,
|
|
max_new_tokens: int = 0,
|
|
|
- top_p: float = 0.8,
|
|
|
|
|
|
|
+ top_p: float = 0.9,
|
|
|
|
|
+ top_k: int = 30,
|
|
|
repetition_penalty: float = 1.1,
|
|
repetition_penalty: float = 1.1,
|
|
|
- temperature: float = 0.8,
|
|
|
|
|
|
|
+ temperature: float = 1.0,
|
|
|
compile: bool = False,
|
|
compile: bool = False,
|
|
|
iterative_prompt: bool = True,
|
|
iterative_prompt: bool = True,
|
|
|
chunk_length: int = 512,
|
|
chunk_length: int = 512,
|
|
@@ -408,10 +535,9 @@ def generate_long(
|
|
|
prompt_tokens: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
|
|
prompt_tokens: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
|
|
|
):
|
|
):
|
|
|
assert 0 < top_p <= 1, "top_p must be in (0, 1]"
|
|
assert 0 < top_p <= 1, "top_p must be in (0, 1]"
|
|
|
- assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
|
|
|
|
|
assert 0 < temperature < 2, "temperature must be in (0, 2)"
|
|
assert 0 < temperature < 2, "temperature must be in (0, 2)"
|
|
|
|
|
|
|
|
- use_prompt = prompt_text is not None and prompt_tokens is not None
|
|
|
|
|
|
|
+ use_prompt = bool(prompt_text) and bool(prompt_tokens)
|
|
|
if use_prompt and isinstance(prompt_text, str):
|
|
if use_prompt and isinstance(prompt_text, str):
|
|
|
prompt_text = [prompt_text]
|
|
prompt_text = [prompt_text]
|
|
|
prompt_tokens = [prompt_tokens]
|
|
prompt_tokens = [prompt_tokens]
|
|
@@ -426,91 +552,188 @@ def generate_long(
|
|
|
|
|
|
|
|
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
tokenizer = model.tokenizer
|
|
tokenizer = model.tokenizer
|
|
|
- base_content_sequence = ContentSequence(modality="interleave")
|
|
|
|
|
-
|
|
|
|
|
max_length = model.config.max_seq_len
|
|
max_length = model.config.max_seq_len
|
|
|
|
|
+
|
|
|
|
|
+ # Build base conversation with system message
|
|
|
|
|
+ base_conversation = Conversation()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
if use_prompt:
|
|
if use_prompt:
|
|
|
- for t, c in zip(prompt_text, prompt_tokens):
|
|
|
|
|
- base_content_sequence.append(
|
|
|
|
|
- [
|
|
|
|
|
- TextPart(text=t),
|
|
|
|
|
- VQPart(codes=c),
|
|
|
|
|
- ],
|
|
|
|
|
- add_end=True,
|
|
|
|
|
- speaker=0,
|
|
|
|
|
- )
|
|
|
|
|
- base_content_sequence.append(
|
|
|
|
|
- [
|
|
|
|
|
- TextPart(text=text),
|
|
|
|
|
- ],
|
|
|
|
|
- add_end=False,
|
|
|
|
|
- speaker=0,
|
|
|
|
|
|
|
+ # Auto-add speaker tags to prompt texts that don't have them
|
|
|
|
|
+ tagged_prompt_text = []
|
|
|
|
|
+ for i, t in enumerate(prompt_text):
|
|
|
|
|
+ if not re.search(r"<\|speaker:\d+\|>", t):
|
|
|
|
|
+ tagged_prompt_text.append(f"<|speaker:{i}|>{t}")
|
|
|
|
|
+ else:
|
|
|
|
|
+ tagged_prompt_text.append(t)
|
|
|
|
|
+
|
|
|
|
|
+ system_parts = [
|
|
|
|
|
+ TextPart(
|
|
|
|
|
+ text="convert the provided text to speech reference to the following:\n\nText:\n",
|
|
|
|
|
+ cal_loss=False,
|
|
|
|
|
+ ),
|
|
|
|
|
+ ]
|
|
|
|
|
+ reference_text = "\n".join(tagged_prompt_text)
|
|
|
|
|
+ system_parts.append(TextPart(text=reference_text, cal_loss=False))
|
|
|
|
|
+ system_parts.append(TextPart(text="\n\nSpeech:\n", cal_loss=False))
|
|
|
|
|
+ all_codes = torch.cat([c for c in prompt_tokens], dim=1)
|
|
|
|
|
+ system_parts.append(VQPart(codes=all_codes, cal_loss=False))
|
|
|
|
|
+ # torch.save(all_codes, "debug_vq_codes.pt")
|
|
|
|
|
+ else:
|
|
|
|
|
+ system_parts = [
|
|
|
|
|
+ TextPart(text="convert the provided text to speech", cal_loss=False)
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ base_conversation.append(
|
|
|
|
|
+ Message(
|
|
|
|
|
+ role="system",
|
|
|
|
|
+ parts=system_parts,
|
|
|
|
|
+ cal_loss=False,
|
|
|
|
|
+ add_im_start=True,
|
|
|
|
|
+ add_im_end=True,
|
|
|
|
|
+ )
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- encoded, audio_masks, audio_parts = base_content_sequence.encode_for_inference(
|
|
|
|
|
- tokenizer, num_codebooks=model.config.num_codebooks
|
|
|
|
|
- )
|
|
|
|
|
- if encoded.size(1) > max_length - 2048:
|
|
|
|
|
- raise ValueError(f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}")
|
|
|
|
|
|
|
+ # Split text by speaker and group into batches
|
|
|
|
|
+ turns = split_text_by_speaker(text)
|
|
|
|
|
+ if turns:
|
|
|
|
|
+ batches = group_turns_into_batches(
|
|
|
|
|
+ turns, max_speakers=5, max_bytes=chunk_length
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ batches = [text]
|
|
|
|
|
|
|
|
- encoded = encoded.to(device=device)
|
|
|
|
|
- logger.info(f"Encoded text: {text}")
|
|
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"Split into {len(turns)} turns, grouped into {len(batches)} batches"
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
for sample_idx in range(num_samples):
|
|
for sample_idx in range(num_samples):
|
|
|
if torch.cuda.is_available():
|
|
if torch.cuda.is_available():
|
|
|
torch.cuda.synchronize()
|
|
torch.cuda.synchronize()
|
|
|
|
|
|
|
|
- global_encoded = []
|
|
|
|
|
- seg_idx = 0
|
|
|
|
|
- prompt_length = encoded.size(1)
|
|
|
|
|
-
|
|
|
|
|
t0 = time.perf_counter()
|
|
t0 = time.perf_counter()
|
|
|
|
|
|
|
|
- y = generate(
|
|
|
|
|
- model=model,
|
|
|
|
|
- prompt=encoded,
|
|
|
|
|
- max_new_tokens=max_new_tokens,
|
|
|
|
|
- audio_masks=audio_masks,
|
|
|
|
|
- audio_parts=audio_parts,
|
|
|
|
|
- decode_one_token=decode_one_token,
|
|
|
|
|
- temperature=temperature,
|
|
|
|
|
- top_p=top_p,
|
|
|
|
|
- repetition_penalty=repetition_penalty,
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # Deep copy base conversation for this sample
|
|
|
|
|
+ conversation = deepcopy(base_conversation)
|
|
|
|
|
|
|
|
- if sample_idx == 0 and seg_idx == 0 and compile:
|
|
|
|
|
- logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
|
|
|
|
|
|
|
+ for batch_idx, batch_text in enumerate(batches):
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"--- Sample {sample_idx}, Batch {batch_idx} "
|
|
|
|
|
+ f"({len(batch_text.encode('utf-8'))} bytes) ---"
|
|
|
|
|
+ )
|
|
|
|
|
+ logger.info(f"Batch text: {batch_text}")
|
|
|
|
|
+
|
|
|
|
|
+ # Add user message
|
|
|
|
|
+ conversation.append(
|
|
|
|
|
+ Message(
|
|
|
|
|
+ role="user",
|
|
|
|
|
+ parts=[TextPart(text=batch_text, cal_loss=False)],
|
|
|
|
|
+ cal_loss=False,
|
|
|
|
|
+ add_im_start=True,
|
|
|
|
|
+ add_im_end=True,
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- if torch.cuda.is_available():
|
|
|
|
|
- torch.cuda.synchronize()
|
|
|
|
|
|
|
+ # Deep copy for generation (don't pollute original conversation)
|
|
|
|
|
+ conversation_gen = deepcopy(conversation)
|
|
|
|
|
+ conversation_gen.append(
|
|
|
|
|
+ Message(
|
|
|
|
|
+ role="assistant",
|
|
|
|
|
+ parts=[],
|
|
|
|
|
+ cal_loss=False,
|
|
|
|
|
+ modality="voice",
|
|
|
|
|
+ add_im_start=True,
|
|
|
|
|
+ add_im_end=False,
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- t = time.perf_counter() - t0
|
|
|
|
|
|
|
+ logger.info("Visualizing prompt structure:")
|
|
|
|
|
+ conversation_gen.visualize(
|
|
|
|
|
+ tokenizer,
|
|
|
|
|
+ merge_audio_tokens=True,
|
|
|
|
|
+ merge_semantic_tokens=True,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- tokens_generated = y.size(1) - prompt_length
|
|
|
|
|
- tokens_sec = tokens_generated / t
|
|
|
|
|
- logger.info(
|
|
|
|
|
- f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
|
|
|
|
|
- )
|
|
|
|
|
- logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
|
|
|
|
|
|
|
+ encoded, audio_masks, audio_parts = (
|
|
|
|
|
+ conversation_gen.encode_for_inference(
|
|
|
|
|
+ tokenizer, num_codebooks=model.config.num_codebooks
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- if torch.cuda.is_available():
|
|
|
|
|
|
|
+ logger.info(f"Encoded prompt shape: {encoded.shape}")
|
|
|
|
|
+ if audio_parts is not None:
|
|
|
|
|
+ logger.info(f"Audio parts shape: {audio_parts.shape}")
|
|
|
|
|
+ if audio_masks is not None:
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"Audio masks non-zero count: {torch.count_nonzero(audio_masks)}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if encoded.size(1) > max_length - 2048:
|
|
|
|
|
+ raise ValueError(
|
|
|
|
|
+ f"Prompt is too long: {encoded.size(1)} > {max_length - 2048}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ encoded = encoded.to(device=device)
|
|
|
|
|
+ prompt_length = encoded.size(1)
|
|
|
|
|
+
|
|
|
|
|
+ y = generate(
|
|
|
|
|
+ model=model,
|
|
|
|
|
+ prompt=encoded,
|
|
|
|
|
+ max_new_tokens=max_new_tokens,
|
|
|
|
|
+ audio_masks=audio_masks,
|
|
|
|
|
+ audio_parts=audio_parts,
|
|
|
|
|
+ decode_one_token=decode_one_token,
|
|
|
|
|
+ temperature=temperature,
|
|
|
|
|
+ top_p=top_p,
|
|
|
|
|
+ top_k=top_k,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if sample_idx == 0 and batch_idx == 0 and compile:
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"Compilation time: {time.perf_counter() - t0:.2f} seconds"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if torch.cuda.is_available():
|
|
|
|
|
+ torch.cuda.synchronize()
|
|
|
|
|
+
|
|
|
|
|
+ t_batch = time.perf_counter() - t0
|
|
|
|
|
+ tokens_generated = y.size(1) - prompt_length
|
|
|
|
|
+ tokens_sec = tokens_generated / t_batch if t_batch > 0 else 0
|
|
|
logger.info(
|
|
logger.info(
|
|
|
- f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
|
|
|
|
|
|
|
+ f"Batch {batch_idx}: Generated {tokens_generated} tokens in "
|
|
|
|
|
+ f"{t_batch:.02f} seconds, {tokens_sec:.02f} tokens/sec"
|
|
|
|
|
+ )
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # Put the generated tokens
|
|
|
|
|
- codes = y[1:, prompt_length:-1].clone()
|
|
|
|
|
- assert (codes >= 0).all(), f"Negative code found"
|
|
|
|
|
|
|
+ # Extract generated codes
|
|
|
|
|
+ codes = y[1:, prompt_length:-1].clone()
|
|
|
|
|
+ assert (codes >= 0).all(), f"Negative code found: {codes}"
|
|
|
|
|
+
|
|
|
|
|
+ # Add assistant message with generated codes back to conversation
|
|
|
|
|
+ conversation.append(
|
|
|
|
|
+ Message(
|
|
|
|
|
+ role="assistant",
|
|
|
|
|
+ parts=[VQPart(codes=codes.cpu(), cal_loss=False)],
|
|
|
|
|
+ cal_loss=False,
|
|
|
|
|
+ modality="voice",
|
|
|
|
|
+ add_im_start=True,
|
|
|
|
|
+ add_im_end=True,
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- decoded = y[:, prompt_length:].clone()
|
|
|
|
|
- global_encoded.append(decoded.cpu())
|
|
|
|
|
- assert (codes >= 0).all(), f"Negative code found: {codes}"
|
|
|
|
|
|
|
+ yield GenerateResponse(
|
|
|
|
|
+ action="sample", codes=codes, text=batch_text
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- yield GenerateResponse(action="sample", codes=codes, text=text)
|
|
|
|
|
- seg_idx += 1
|
|
|
|
|
|
|
+ # Cleanup
|
|
|
|
|
+ del y, encoded
|
|
|
|
|
|
|
|
- # Force GPU memory cleanup
|
|
|
|
|
- del y, decoded, codes
|
|
|
|
|
|
|
+ if torch.cuda.is_available():
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
yield GenerateResponse(action="next")
|
|
yield GenerateResponse(action="next")
|
|
|
|
|
|
|
@@ -585,7 +808,7 @@ def launch_thread_safe_queue(
|
|
|
@click.option(
|
|
@click.option(
|
|
|
"--text",
|
|
"--text",
|
|
|
type=str,
|
|
type=str,
|
|
|
- default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
|
|
|
|
|
|
+ default="<|speaker:0|>你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
|
|
)
|
|
)
|
|
|
@click.option("--prompt-text", type=str, default=None, multiple=True)
|
|
@click.option("--prompt-text", type=str, default=None, multiple=True)
|
|
|
@click.option(
|
|
@click.option(
|
|
@@ -594,15 +817,22 @@ def launch_thread_safe_queue(
|
|
|
default=None,
|
|
default=None,
|
|
|
multiple=True,
|
|
multiple=True,
|
|
|
)
|
|
)
|
|
|
|
|
+@click.option(
|
|
|
|
|
+ "--prompt-audio",
|
|
|
|
|
+ type=click.Path(path_type=Path, exists=True),
|
|
|
|
|
+ default=None,
|
|
|
|
|
+ multiple=True,
|
|
|
|
|
+)
|
|
|
|
|
+@click.option("--output", type=click.Path(path_type=Path), default=None)
|
|
|
@click.option("--num-samples", type=int, default=1)
|
|
@click.option("--num-samples", type=int, default=1)
|
|
|
@click.option("--max-new-tokens", type=int, default=0)
|
|
@click.option("--max-new-tokens", type=int, default=0)
|
|
|
-@click.option("--top-p", type=float, default=0.8)
|
|
|
|
|
-@click.option("--repetition-penalty", type=float, default=1.1)
|
|
|
|
|
-@click.option("--temperature", type=float, default=0.8)
|
|
|
|
|
|
|
+@click.option("--top-p", type=float, default=0.9)
|
|
|
|
|
+@click.option("--top-k", type=int, default=30)
|
|
|
|
|
+@click.option("--temperature", type=float, default=1.0)
|
|
|
@click.option(
|
|
@click.option(
|
|
|
"--checkpoint-path",
|
|
"--checkpoint-path",
|
|
|
type=click.Path(path_type=Path, exists=True),
|
|
type=click.Path(path_type=Path, exists=True),
|
|
|
- default="checkpoints/openaudio-s1-mini",
|
|
|
|
|
|
|
+ default="checkpoints/s2-pro",
|
|
|
)
|
|
)
|
|
|
@click.option("--device", type=str, default="cuda")
|
|
@click.option("--device", type=str, default="cuda")
|
|
|
@click.option("--compile/--no-compile", default=False)
|
|
@click.option("--compile/--no-compile", default=False)
|
|
@@ -610,15 +840,17 @@ def launch_thread_safe_queue(
|
|
|
@click.option("--half/--no-half", default=False)
|
|
@click.option("--half/--no-half", default=False)
|
|
|
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
|
|
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
|
|
|
@click.option("--chunk-length", type=int, default=300)
|
|
@click.option("--chunk-length", type=int, default=300)
|
|
|
-@click.option("--output-dir", type=Path, default="temp")
|
|
|
|
|
|
|
+@click.option("--output-dir", type=Path, default="output")
|
|
|
def main(
|
|
def main(
|
|
|
text: str,
|
|
text: str,
|
|
|
prompt_text: Optional[tuple[str, ...]],
|
|
prompt_text: Optional[tuple[str, ...]],
|
|
|
prompt_tokens: Optional[tuple[Path, ...]],
|
|
prompt_tokens: Optional[tuple[Path, ...]],
|
|
|
|
|
+ prompt_audio: Optional[tuple[Path, ...]],
|
|
|
|
|
+ output: Optional[Path],
|
|
|
num_samples: int,
|
|
num_samples: int,
|
|
|
max_new_tokens: int,
|
|
max_new_tokens: int,
|
|
|
- top_p: int,
|
|
|
|
|
- repetition_penalty: float,
|
|
|
|
|
|
|
+ top_p: float,
|
|
|
|
|
+ top_k: int,
|
|
|
temperature: float,
|
|
temperature: float,
|
|
|
checkpoint_path: Path,
|
|
checkpoint_path: Path,
|
|
|
device: str,
|
|
device: str,
|
|
@@ -632,14 +864,26 @@ def main(
|
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
|
precision = torch.half if half else torch.bfloat16
|
|
precision = torch.half if half else torch.bfloat16
|
|
|
|
|
|
|
|
|
|
+ if prompt_text and not prompt_audio and not prompt_tokens:
|
|
|
|
|
+ raise ValueError(
|
|
|
|
|
+ "--prompt-text requires either --prompt-audio or --prompt-tokens"
|
|
|
|
|
+ )
|
|
|
if (
|
|
if (
|
|
|
- prompt_text is not None
|
|
|
|
|
- and prompt_tokens is not None
|
|
|
|
|
|
|
+ prompt_text
|
|
|
|
|
+ and prompt_tokens
|
|
|
and len(prompt_text) != len(prompt_tokens)
|
|
and len(prompt_text) != len(prompt_tokens)
|
|
|
):
|
|
):
|
|
|
raise ValueError(
|
|
raise ValueError(
|
|
|
f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
|
|
f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
|
|
|
)
|
|
)
|
|
|
|
|
+ if (
|
|
|
|
|
+ prompt_text
|
|
|
|
|
+ and prompt_audio
|
|
|
|
|
+ and len(prompt_text) != len(prompt_audio)
|
|
|
|
|
+ ):
|
|
|
|
|
+ raise ValueError(
|
|
|
|
|
+ f"Number of prompt text ({len(prompt_text)}) and prompt audio ({len(prompt_audio)}) should be the same"
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
logger.info("Loading model ...")
|
|
logger.info("Loading model ...")
|
|
|
t0 = time.time()
|
|
t0 = time.time()
|
|
@@ -657,8 +901,21 @@ def main(
|
|
|
|
|
|
|
|
logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
|
|
logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
|
|
|
|
|
|
|
|
|
|
+ codec = None
|
|
|
|
|
+ codec_checkpoint = checkpoint_path / "codec.pth"
|
|
|
|
|
+
|
|
|
|
|
+ # Handle prompt: --prompt-audio takes priority over --prompt-tokens
|
|
|
prompt_tokens_list = None
|
|
prompt_tokens_list = None
|
|
|
- if prompt_tokens is not None:
|
|
|
|
|
|
|
+ if prompt_audio:
|
|
|
|
|
+ logger.info("Loading codec model for audio encoding...")
|
|
|
|
|
+ codec = load_codec_model(codec_checkpoint, device, precision)
|
|
|
|
|
+ prompt_tokens_list = [
|
|
|
|
|
+ encode_audio(p, codec, device).cpu() for p in prompt_audio
|
|
|
|
|
+ ]
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"Encoded {len(prompt_audio)} audio file(s) to VQ codes"
|
|
|
|
|
+ )
|
|
|
|
|
+ elif prompt_tokens is not None:
|
|
|
prompt_tokens_list = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
|
|
prompt_tokens_list = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
|
|
|
|
|
|
|
|
torch.manual_seed(seed)
|
|
torch.manual_seed(seed)
|
|
@@ -674,7 +931,7 @@ def main(
|
|
|
num_samples=num_samples,
|
|
num_samples=num_samples,
|
|
|
max_new_tokens=max_new_tokens,
|
|
max_new_tokens=max_new_tokens,
|
|
|
top_p=top_p,
|
|
top_p=top_p,
|
|
|
- repetition_penalty=repetition_penalty,
|
|
|
|
|
|
|
+ top_k=top_k,
|
|
|
temperature=temperature,
|
|
temperature=temperature,
|
|
|
compile=compile,
|
|
compile=compile,
|
|
|
iterative_prompt=iterative_prompt,
|
|
iterative_prompt=iterative_prompt,
|
|
@@ -692,9 +949,29 @@ def main(
|
|
|
logger.info(f"Sampled text: {response.text}")
|
|
logger.info(f"Sampled text: {response.text}")
|
|
|
elif response.action == "next":
|
|
elif response.action == "next":
|
|
|
if codes:
|
|
if codes:
|
|
|
|
|
+ merged_codes = torch.cat(codes, dim=1)
|
|
|
codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
|
|
codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
|
|
|
- np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
|
|
|
|
|
|
|
+ np.save(codes_npy_path, merged_codes.cpu().numpy())
|
|
|
logger.info(f"Saved codes to {codes_npy_path}")
|
|
logger.info(f"Saved codes to {codes_npy_path}")
|
|
|
|
|
+
|
|
|
|
|
+ # Decode to wav if --output is specified
|
|
|
|
|
+ if output:
|
|
|
|
|
+ if codec is None:
|
|
|
|
|
+ logger.info("Loading codec model for audio decoding...")
|
|
|
|
|
+ codec = load_codec_model(
|
|
|
|
|
+ codec_checkpoint, device, precision
|
|
|
|
|
+ )
|
|
|
|
|
+ audio = decode_to_audio(merged_codes.to(device), codec)
|
|
|
|
|
+ import soundfile as sf
|
|
|
|
|
+
|
|
|
|
|
+ out_path = (
|
|
|
|
|
+ str(output)
|
|
|
|
|
+ if num_samples == 1
|
|
|
|
|
+ else str(output.with_stem(f"{output.stem}_{idx}"))
|
|
|
|
|
+ )
|
|
|
|
|
+ sf.write(out_path, audio.cpu().float().numpy(), codec.sample_rate)
|
|
|
|
|
+ logger.info(f"Saved audio to {out_path}")
|
|
|
|
|
+
|
|
|
logger.info(f"Next sample")
|
|
logger.info(f"Next sample")
|
|
|
codes = []
|
|
codes = []
|
|
|
idx += 1
|
|
idx += 1
|
|
@@ -703,4 +980,4 @@ def main(
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
- main()
|
|
|
|
|
|
|
+ main()
|