|
@@ -208,7 +208,7 @@ def decode_n_tokens(
|
|
|
|
|
|
|
|
for i in tqdm(range(num_new_tokens)):
|
|
for i in tqdm(range(num_new_tokens)):
|
|
|
f_start = time.perf_counter()
|
|
f_start = time.perf_counter()
|
|
|
- with sdpa_kernel(SDPBackend.MATH):
|
|
|
|
|
|
|
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
|
|
next_token = decode_one_token(
|
|
next_token = decode_one_token(
|
|
|
model=model,
|
|
model=model,
|
|
|
x=cur_token,
|
|
x=cur_token,
|
|
@@ -242,101 +242,6 @@ def decode_n_tokens(
|
|
|
|
|
|
|
|
return torch.cat(new_tokens, dim=1)
|
|
return torch.cat(new_tokens, dim=1)
|
|
|
|
|
|
|
|
-
|
|
|
|
|
-def decode_n_tokens_optimized(
|
|
|
|
|
- model: DualARTransformer,
|
|
|
|
|
- cur_token: torch.Tensor,
|
|
|
|
|
- input_pos: torch.Tensor,
|
|
|
|
|
- num_new_tokens: int,
|
|
|
|
|
- temperature: torch.Tensor,
|
|
|
|
|
- top_p: torch.Tensor,
|
|
|
|
|
- top_k: int,
|
|
|
|
|
- semantic_logit_bias: torch.Tensor,
|
|
|
|
|
- audio_masks: torch.Tensor,
|
|
|
|
|
- audio_parts: torch.Tensor,
|
|
|
|
|
- decode_one_token=decode_one_token_ar,
|
|
|
|
|
-):
|
|
|
|
|
- """
|
|
|
|
|
- Optimized version:
|
|
|
|
|
- - no roll (ring buffer)
|
|
|
|
|
- - flash attention
|
|
|
|
|
- - reduced view/reshape
|
|
|
|
|
- """
|
|
|
|
|
-
|
|
|
|
|
- device = cur_token.device
|
|
|
|
|
- num_streams = model.config.num_codebooks + 1
|
|
|
|
|
-
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 1. ring buffer index (替代 roll)
|
|
|
|
|
- # =========================
|
|
|
|
|
- previous_tokens = torch.zeros(
|
|
|
|
|
- (model.config.num_codebooks + 1, RAS_WIN_SIZE),
|
|
|
|
|
- dtype=torch.int,
|
|
|
|
|
- device=cur_token.device,
|
|
|
|
|
- )
|
|
|
|
|
- history_len = previous_tokens.size(1)
|
|
|
|
|
- write_idx = history_len - 1
|
|
|
|
|
-
|
|
|
|
|
- new_tokens = []
|
|
|
|
|
-
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 2. precompute reshape shape
|
|
|
|
|
- # =========================
|
|
|
|
|
- batch = 1
|
|
|
|
|
- im_end_id = model.tokenizer.get_token_id(IM_END_TOKEN)
|
|
|
|
|
-
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 3. main loop
|
|
|
|
|
- # =========================
|
|
|
|
|
- for i in range(num_new_tokens):
|
|
|
|
|
-
|
|
|
|
|
- # ⚡ use flash attention (重要优化)
|
|
|
|
|
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
|
|
|
|
- next_token = decode_one_token(
|
|
|
|
|
- model=model,
|
|
|
|
|
- x=cur_token,
|
|
|
|
|
- input_pos=input_pos,
|
|
|
|
|
- previous_tokens=previous_tokens,
|
|
|
|
|
- temperature=temperature,
|
|
|
|
|
- top_p=top_p,
|
|
|
|
|
- top_k=top_k,
|
|
|
|
|
- semantic_logit_bias=semantic_logit_bias,
|
|
|
|
|
- audio_masks=audio_masks,
|
|
|
|
|
- audio_parts=audio_parts,
|
|
|
|
|
- ).clone()
|
|
|
|
|
-
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 4. update position
|
|
|
|
|
- # =========================
|
|
|
|
|
- input_pos += 1
|
|
|
|
|
-
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 5. reshape once (reuse view logic)
|
|
|
|
|
- # =========================
|
|
|
|
|
- next_token_2d = next_token.view(num_streams, -1)
|
|
|
|
|
-
|
|
|
|
|
- cur_token = next_token_2d.unsqueeze(0)
|
|
|
|
|
-
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 6. ring buffer update (NO roll)
|
|
|
|
|
- # =========================
|
|
|
|
|
- previous_tokens[:, write_idx] = next_token_2d[:, 0]
|
|
|
|
|
- write_idx = (write_idx + 1) % history_len
|
|
|
|
|
-
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 7. store output
|
|
|
|
|
- # =========================
|
|
|
|
|
- new_tokens.append(next_token)
|
|
|
|
|
-
|
|
|
|
|
- # =========================
|
|
|
|
|
- # 8. EOS check
|
|
|
|
|
- # =========================
|
|
|
|
|
- if cur_token[0, 0, -1] == im_end_id:
|
|
|
|
|
- break
|
|
|
|
|
-
|
|
|
|
|
- return new_tokens
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
@torch.no_grad()
|
|
@torch.no_grad()
|
|
|
@torch.inference_mode()
|
|
@torch.inference_mode()
|
|
|
def generate(
|
|
def generate(
|
|
@@ -476,7 +381,7 @@ def generate(
|
|
|
# =========================
|
|
# =========================
|
|
|
previous_tokens[:, -1, :] = first_token.view(codebook_dim)
|
|
previous_tokens[:, -1, :] = first_token.view(codebook_dim)
|
|
|
|
|
|
|
|
- x = decode_n_tokens_optimized(
|
|
|
|
|
|
|
+ x = decode_n_tokens(
|
|
|
model,
|
|
model,
|
|
|
first_token.view(1, codebook_dim, -1),
|
|
first_token.view(1, codebook_dim, -1),
|
|
|
input_pos,
|
|
input_pos,
|