|
@@ -254,8 +254,6 @@ def decode_n_tokens_optimized(
|
|
|
semantic_logit_bias: torch.Tensor,
|
|
semantic_logit_bias: torch.Tensor,
|
|
|
audio_masks: torch.Tensor,
|
|
audio_masks: torch.Tensor,
|
|
|
audio_parts: torch.Tensor,
|
|
audio_parts: torch.Tensor,
|
|
|
- previous_tokens: torch.Tensor,
|
|
|
|
|
- im_end_id: Any,
|
|
|
|
|
decode_one_token=decode_one_token_ar,
|
|
decode_one_token=decode_one_token_ar,
|
|
|
):
|
|
):
|
|
|
"""
|
|
"""
|
|
@@ -271,6 +269,11 @@ def decode_n_tokens_optimized(
|
|
|
# =========================
|
|
# =========================
|
|
|
# 1. ring buffer index (替代 roll)
|
|
# 1. ring buffer index (替代 roll)
|
|
|
# =========================
|
|
# =========================
|
|
|
|
|
+ previous_tokens = torch.zeros(
|
|
|
|
|
+ (model.config.num_codebooks + 1, RAS_WIN_SIZE),
|
|
|
|
|
+ dtype=torch.int,
|
|
|
|
|
+ device=cur_token.device,
|
|
|
|
|
+ )
|
|
|
history_len = previous_tokens.size(1)
|
|
history_len = previous_tokens.size(1)
|
|
|
write_idx = history_len - 1
|
|
write_idx = history_len - 1
|
|
|
|
|
|
|
@@ -280,6 +283,7 @@ def decode_n_tokens_optimized(
|
|
|
# 2. precompute reshape shape
|
|
# 2. precompute reshape shape
|
|
|
# =========================
|
|
# =========================
|
|
|
batch = 1
|
|
batch = 1
|
|
|
|
|
+ im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
|
|
|
|
|
|
|
|
# =========================
|
|
# =========================
|
|
|
# 3. main loop
|
|
# 3. main loop
|
|
@@ -483,8 +487,6 @@ def generate(
|
|
|
semantic_logit_bias=semantic_logit_bias,
|
|
semantic_logit_bias=semantic_logit_bias,
|
|
|
audio_masks=audio_masks,
|
|
audio_masks=audio_masks,
|
|
|
audio_parts=audio_parts,
|
|
audio_parts=audio_parts,
|
|
|
- im_end_id=im_end_id,
|
|
|
|
|
- previous_tokens=previous_tokens,
|
|
|
|
|
decode_one_token=decode_one_token,
|
|
decode_one_token=decode_one_token,
|
|
|
)
|
|
)
|
|
|
seq = seq[:, : T + 1 + x.size(1)]
|
|
seq = seq[:, : T + 1 + x.size(1)]
|