|
@@ -47,7 +47,7 @@ def multinomial_sample_one_no_sync(
|
|
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
|
|
|
|
|
|
|
|
|
|
|
|
-RAS_WIN_SIZE = 10 # window for Repetition Aware Sampling
|
|
|
|
|
|
|
+RAS_WIN_SIZE = 10 # window for Repetition Aware Sampling
|
|
|
RAS_HIGH_TEMP = 1.0
|
|
RAS_HIGH_TEMP = 1.0
|
|
|
RAS_HIGH_TOP_P = 0.9
|
|
RAS_HIGH_TOP_P = 0.9
|
|
|
|
|
|
|
@@ -116,23 +116,30 @@ def decode_one_token_ar(
|
|
|
biased_logits = logits + semantic_logit_bias
|
|
biased_logits = logits + semantic_logit_bias
|
|
|
|
|
|
|
|
# Normal sample
|
|
# Normal sample
|
|
|
- main_token_normal = sample(biased_logits, temperature=temperature, top_p=top_p, top_k=top_k)[0]
|
|
|
|
|
|
|
+ main_token_normal = sample(
|
|
|
|
|
+ biased_logits, temperature=temperature, top_p=top_p, top_k=top_k
|
|
|
|
|
+ )[0]
|
|
|
|
|
|
|
|
# RAS: also sample with high temp to use as fallback if token repeats
|
|
# RAS: also sample with high temp to use as fallback if token repeats
|
|
|
- high_temp = torch.tensor(RAS_HIGH_TEMP, device=temperature.device, dtype=temperature.dtype)
|
|
|
|
|
|
|
+ high_temp = torch.tensor(
|
|
|
|
|
+ RAS_HIGH_TEMP, device=temperature.device, dtype=temperature.dtype
|
|
|
|
|
+ )
|
|
|
high_top_p = torch.tensor(RAS_HIGH_TOP_P, device=top_p.device, dtype=top_p.dtype)
|
|
high_top_p = torch.tensor(RAS_HIGH_TOP_P, device=top_p.device, dtype=top_p.dtype)
|
|
|
- main_token_high = sample(biased_logits, temperature=high_temp, top_p=high_top_p, top_k=top_k)[0]
|
|
|
|
|
|
|
+ main_token_high = sample(
|
|
|
|
|
+ biased_logits, temperature=high_temp, top_p=high_top_p, top_k=top_k
|
|
|
|
|
+ )[0]
|
|
|
|
|
|
|
|
# Use high-temp sample if: token is semantic AND token is in previous window
|
|
# Use high-temp sample if: token is semantic AND token is in previous window
|
|
|
if previous_tokens is not None:
|
|
if previous_tokens is not None:
|
|
|
in_window = (previous_tokens[0] == main_token_normal).any()
|
|
in_window = (previous_tokens[0] == main_token_normal).any()
|
|
|
# Use tensor ops (&, torch.where) instead of Python (and, if) — torch.compile requires no data-dependent branching
|
|
# Use tensor ops (&, torch.where) instead of Python (and, if) — torch.compile requires no data-dependent branching
|
|
|
- is_semantic = (
|
|
|
|
|
- (main_token_normal >= model.config.semantic_begin_id)
|
|
|
|
|
- & (main_token_normal <= model.config.semantic_end_id)
|
|
|
|
|
|
|
+ is_semantic = (main_token_normal >= model.config.semantic_begin_id) & (
|
|
|
|
|
+ main_token_normal <= model.config.semantic_end_id
|
|
|
)
|
|
)
|
|
|
should_use_high = in_window & is_semantic
|
|
should_use_high = in_window & is_semantic
|
|
|
- main_token_normal = torch.where(should_use_high, main_token_high, main_token_normal)
|
|
|
|
|
|
|
+ main_token_normal = torch.where(
|
|
|
|
|
+ should_use_high, main_token_high, main_token_normal
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
codebooks = [main_token_normal]
|
|
codebooks = [main_token_normal]
|
|
|
|
|
|
|
@@ -144,7 +151,7 @@ def decode_one_token_ar(
|
|
|
|
|
|
|
|
input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
|
|
input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
|
|
|
model.forward_generate_fast(hidden_states, input_pos)
|
|
model.forward_generate_fast(hidden_states, input_pos)
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
# [MODIFIED] Access config instead of tokenizer
|
|
# [MODIFIED] Access config instead of tokenizer
|
|
|
a = codebooks[0] - model.config.semantic_begin_id
|
|
a = codebooks[0] - model.config.semantic_begin_id
|
|
|
a[a < 0] = 0
|
|
a[a < 0] = 0
|
|
@@ -158,7 +165,7 @@ def decode_one_token_ar(
|
|
|
)
|
|
)
|
|
|
logits = model.forward_generate_fast(hidden_states, input_pos)
|
|
logits = model.forward_generate_fast(hidden_states, input_pos)
|
|
|
|
|
|
|
|
- short_logits = logits # DualAR predicts config.codebook_size number of tokens
|
|
|
|
|
|
|
+ short_logits = logits # DualAR predicts config.codebook_size number of tokens
|
|
|
|
|
|
|
|
# Convert logits to probs (no constraint for fast codebooks)
|
|
# Convert logits to probs (no constraint for fast codebooks)
|
|
|
a = sample(
|
|
a = sample(
|
|
@@ -200,7 +207,7 @@ def decode_n_tokens(
|
|
|
)
|
|
)
|
|
|
# Accumulate all generated tokens (the actual output)
|
|
# Accumulate all generated tokens (the actual output)
|
|
|
new_tokens = []
|
|
new_tokens = []
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
# [MODIFIED] Pre-fetch the token ID for efficiency in the loop
|
|
# [MODIFIED] Pre-fetch the token ID for efficiency in the loop
|
|
|
im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
|
|
im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
|
|
|
|
|
|
|
@@ -223,7 +230,9 @@ def decode_n_tokens(
|
|
|
cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
|
|
cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
|
|
|
# Roll RAS window left and insert new token at end
|
|
# Roll RAS window left and insert new token at end
|
|
|
previous_tokens = previous_tokens.roll(-1, dims=1)
|
|
previous_tokens = previous_tokens.roll(-1, dims=1)
|
|
|
- previous_tokens[:, -1] = next_token.view(model.config.num_codebooks + 1, -1)[:, 0]
|
|
|
|
|
|
|
+ previous_tokens[:, -1] = next_token.view(model.config.num_codebooks + 1, -1)[
|
|
|
|
|
+ :, 0
|
|
|
|
|
+ ]
|
|
|
new_tokens.append(next_token)
|
|
new_tokens.append(next_token)
|
|
|
|
|
|
|
|
if cur_token[0, 0, -1] == im_end_id:
|
|
if cur_token[0, 0, -1] == im_end_id:
|
|
@@ -270,7 +279,9 @@ def generate(
|
|
|
max_new_tokens = T_new - T
|
|
max_new_tokens = T_new - T
|
|
|
|
|
|
|
|
device = prompt.device
|
|
device = prompt.device
|
|
|
- dtype = next(model.parameters()).dtype # model weight dtype (bfloat16), NOT prompt dtype (int32)
|
|
|
|
|
|
|
+ dtype = next(
|
|
|
|
|
+ model.parameters()
|
|
|
|
|
+ ).dtype # model weight dtype (bfloat16), NOT prompt dtype (int32)
|
|
|
|
|
|
|
|
# Critical fix: Only set up cache on first run or when necessary
|
|
# Critical fix: Only set up cache on first run or when necessary
|
|
|
if not hasattr(model, "_cache_setup_done") or not model._cache_setup_done:
|
|
if not hasattr(model, "_cache_setup_done") or not model._cache_setup_done:
|
|
@@ -304,12 +315,12 @@ def generate(
|
|
|
semantic_logit_bias = torch.full(
|
|
semantic_logit_bias = torch.full(
|
|
|
(1, 1, vocab_size), float("-inf"), device=device, dtype=dtype
|
|
(1, 1, vocab_size), float("-inf"), device=device, dtype=dtype
|
|
|
)
|
|
)
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
# [MODIFIED] Use config for semantic range
|
|
# [MODIFIED] Use config for semantic range
|
|
|
semantic_logit_bias[
|
|
semantic_logit_bias[
|
|
|
0, 0, model.config.semantic_begin_id : model.config.semantic_end_id + 1
|
|
0, 0, model.config.semantic_begin_id : model.config.semantic_end_id + 1
|
|
|
] = 0.0
|
|
] = 0.0
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
# [MODIFIED] Use tokenizer.get_token_id (Wrapper method)
|
|
# [MODIFIED] Use tokenizer.get_token_id (Wrapper method)
|
|
|
semantic_logit_bias[0, 0, model.tokenizer.get_token_id(IM_END_TOKEN)] = 0.0
|
|
semantic_logit_bias[0, 0, model.tokenizer.get_token_id(IM_END_TOKEN)] = 0.0
|
|
|
|
|
|
|
@@ -419,9 +430,7 @@ def encode_audio(audio_path, codec, device):
|
|
|
wav, sr = torchaudio.load(str(audio_path))
|
|
wav, sr = torchaudio.load(str(audio_path))
|
|
|
if wav.shape[0] > 1:
|
|
if wav.shape[0] > 1:
|
|
|
wav = wav.mean(dim=0, keepdim=True)
|
|
wav = wav.mean(dim=0, keepdim=True)
|
|
|
- wav = torchaudio.functional.resample(
|
|
|
|
|
- wav.to(device), sr, codec.sample_rate
|
|
|
|
|
- )[0]
|
|
|
|
|
|
|
+ wav = torchaudio.functional.resample(wav.to(device), sr, codec.sample_rate)[0]
|
|
|
|
|
|
|
|
# Match codec model dtype (e.g. bfloat16)
|
|
# Match codec model dtype (e.g. bfloat16)
|
|
|
model_dtype = next(codec.parameters()).dtype
|
|
model_dtype = next(codec.parameters()).dtype
|
|
@@ -557,7 +566,6 @@ def generate_long(
|
|
|
# Build base conversation with system message
|
|
# Build base conversation with system message
|
|
|
base_conversation = Conversation()
|
|
base_conversation = Conversation()
|
|
|
|
|
|
|
|
-
|
|
|
|
|
if use_prompt:
|
|
if use_prompt:
|
|
|
# Auto-add speaker tags to prompt texts that don't have them
|
|
# Auto-add speaker tags to prompt texts that don't have them
|
|
|
tagged_prompt_text = []
|
|
tagged_prompt_text = []
|
|
@@ -603,9 +611,7 @@ def generate_long(
|
|
|
else:
|
|
else:
|
|
|
batches = [text]
|
|
batches = [text]
|
|
|
|
|
|
|
|
- logger.info(
|
|
|
|
|
- f"Split into {len(turns)} turns, grouped into {len(batches)} batches"
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ logger.info(f"Split into {len(turns)} turns, grouped into {len(batches)} batches")
|
|
|
|
|
|
|
|
for sample_idx in range(num_samples):
|
|
for sample_idx in range(num_samples):
|
|
|
if torch.cuda.is_available():
|
|
if torch.cuda.is_available():
|
|
@@ -654,10 +660,8 @@ def generate_long(
|
|
|
merge_semantic_tokens=True,
|
|
merge_semantic_tokens=True,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- encoded, audio_masks, audio_parts = (
|
|
|
|
|
- conversation_gen.encode_for_inference(
|
|
|
|
|
- tokenizer, num_codebooks=model.config.num_codebooks
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ encoded, audio_masks, audio_parts = conversation_gen.encode_for_inference(
|
|
|
|
|
+ tokenizer, num_codebooks=model.config.num_codebooks
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
logger.info(f"Encoded prompt shape: {encoded.shape}")
|
|
logger.info(f"Encoded prompt shape: {encoded.shape}")
|
|
@@ -689,9 +693,7 @@ def generate_long(
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
if sample_idx == 0 and batch_idx == 0 and compile:
|
|
if sample_idx == 0 and batch_idx == 0 and compile:
|
|
|
- logger.info(
|
|
|
|
|
- f"Compilation time: {time.perf_counter() - t0:.2f} seconds"
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
|
|
|
|
|
|
|
|
if torch.cuda.is_available():
|
|
if torch.cuda.is_available():
|
|
|
torch.cuda.synchronize()
|
|
torch.cuda.synchronize()
|
|
@@ -723,9 +725,7 @@ def generate_long(
|
|
|
)
|
|
)
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- yield GenerateResponse(
|
|
|
|
|
- action="sample", codes=codes, text=batch_text
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ yield GenerateResponse(action="sample", codes=codes, text=batch_text)
|
|
|
|
|
|
|
|
# Cleanup
|
|
# Cleanup
|
|
|
del y, encoded
|
|
del y, encoded
|
|
@@ -868,19 +868,11 @@ def main(
|
|
|
raise ValueError(
|
|
raise ValueError(
|
|
|
"--prompt-text requires either --prompt-audio or --prompt-tokens"
|
|
"--prompt-text requires either --prompt-audio or --prompt-tokens"
|
|
|
)
|
|
)
|
|
|
- if (
|
|
|
|
|
- prompt_text
|
|
|
|
|
- and prompt_tokens
|
|
|
|
|
- and len(prompt_text) != len(prompt_tokens)
|
|
|
|
|
- ):
|
|
|
|
|
|
|
+ if prompt_text and prompt_tokens and len(prompt_text) != len(prompt_tokens):
|
|
|
raise ValueError(
|
|
raise ValueError(
|
|
|
f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
|
|
f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
|
|
|
)
|
|
)
|
|
|
- if (
|
|
|
|
|
- prompt_text
|
|
|
|
|
- and prompt_audio
|
|
|
|
|
- and len(prompt_text) != len(prompt_audio)
|
|
|
|
|
- ):
|
|
|
|
|
|
|
+ if prompt_text and prompt_audio and len(prompt_text) != len(prompt_audio):
|
|
|
raise ValueError(
|
|
raise ValueError(
|
|
|
f"Number of prompt text ({len(prompt_text)}) and prompt audio ({len(prompt_audio)}) should be the same"
|
|
f"Number of prompt text ({len(prompt_text)}) and prompt audio ({len(prompt_audio)}) should be the same"
|
|
|
)
|
|
)
|
|
@@ -912,9 +904,7 @@ def main(
|
|
|
prompt_tokens_list = [
|
|
prompt_tokens_list = [
|
|
|
encode_audio(p, codec, device).cpu() for p in prompt_audio
|
|
encode_audio(p, codec, device).cpu() for p in prompt_audio
|
|
|
]
|
|
]
|
|
|
- logger.info(
|
|
|
|
|
- f"Encoded {len(prompt_audio)} audio file(s) to VQ codes"
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ logger.info(f"Encoded {len(prompt_audio)} audio file(s) to VQ codes")
|
|
|
elif prompt_tokens is not None:
|
|
elif prompt_tokens is not None:
|
|
|
prompt_tokens_list = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
|
|
prompt_tokens_list = [torch.from_numpy(np.load(p)) for p in prompt_tokens]
|
|
|
|
|
|
|
@@ -958,9 +948,7 @@ def main(
|
|
|
if output:
|
|
if output:
|
|
|
if codec is None:
|
|
if codec is None:
|
|
|
logger.info("Loading codec model for audio decoding...")
|
|
logger.info("Loading codec model for audio decoding...")
|
|
|
- codec = load_codec_model(
|
|
|
|
|
- codec_checkpoint, device, precision
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ codec = load_codec_model(codec_checkpoint, device, precision)
|
|
|
audio = decode_to_audio(merged_codes.to(device), codec)
|
|
audio = decode_to_audio(merged_codes.to(device), codec)
|
|
|
import soundfile as sf
|
|
import soundfile as sf
|
|
|
|
|
|
|
@@ -980,4 +968,4 @@ def main(
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
- main()
|
|
|
|
|
|
|
+ main()
|