|
|
@@ -888,11 +888,11 @@ class Attention(nn.Module):
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
|
- x: Tensor,
|
|
|
- freqs_cis: Tensor,
|
|
|
- mask: Tensor,
|
|
|
- input_pos: Optional[Tensor] = None,
|
|
|
- ) -> Tensor:
|
|
|
+ x: torch.Tensor,
|
|
|
+ freqs_cis: torch.Tensor,
|
|
|
+ mask: Optional[torch.Tensor],
|
|
|
+ input_pos: Optional[torch.Tensor] = None,
|
|
|
+ ):
|
|
|
bsz, seqlen, _ = x.shape
|
|
|
|
|
|
q_size = self.n_head * self.head_dim
|
|
|
@@ -917,65 +917,81 @@ class Attention(nn.Module):
|
|
|
# =========================
|
|
|
# KV Cache + Sliding Window
|
|
|
# =========================
|
|
|
+ start = 0
|
|
|
+ seq_len = seqlen
|
|
|
+
|
|
|
if self.kv_cache is not None:
|
|
|
k, v = self.kv_cache.update(input_pos, k, v)
|
|
|
|
|
|
- max_context = 4096 # ⭐ 可调:4096 / 8192
|
|
|
+ max_context = 4096 # 可调
|
|
|
|
|
|
if input_pos is not None and seqlen == 1:
|
|
|
- # ===== decode 阶段 =====
|
|
|
- # input_pos: [1]
|
|
|
seq_len = int(input_pos.item()) + 1
|
|
|
else:
|
|
|
- # ===== prefill 阶段 =====
|
|
|
- seq_len = seqlen
|
|
|
+ seq_len = k.size(2)
|
|
|
|
|
|
start = max(0, seq_len - max_context)
|
|
|
|
|
|
- # window 裁剪
|
|
|
k = k[:, :, start:seq_len, :]
|
|
|
v = v[:, :, start:seq_len, :]
|
|
|
|
|
|
- # mask 同步
|
|
|
if mask is not None:
|
|
|
mask = mask[:, :, :, start:seq_len]
|
|
|
|
|
|
+ # Actual current length of k along dim 2 (most important)
|
|
|
+ T = k.size(2)
|
|
|
+
|
|
|
# =========================
|
|
|
- # RoPE(关键:分支处理)
|
|
|
+ # RoPE (strictly aligned with the window)
|
|
|
# =========================
|
|
|
if input_pos is not None and seqlen == 1:
|
|
|
- # ===== decode =====
|
|
|
- # 必须 index_select(避免 advanced indexing bug)
|
|
|
- freqs = torch.index_select(freqs_cis, 0, input_pos.long())
|
|
|
+ # decode
|
|
|
+ pos = input_pos.long()
|
|
|
+
|
|
|
+ if self.kv_cache is not None:
|
|
|
+ pos = pos.clamp(min=start)
|
|
|
+
|
|
|
+ freqs = torch.index_select(freqs_cis, 0, pos)
|
|
|
else:
|
|
|
- # ===== prefill =====
|
|
|
- freqs = freqs_cis[:seqlen]
|
|
|
+ # prefill
|
|
|
+ if self.kv_cache is not None:
|
|
|
+ freqs = freqs_cis[start:seq_len]
|
|
|
+ else:
|
|
|
+ freqs = freqs_cis[:seqlen]
|
|
|
+
|
|
|
+ # Defensive alignment (critical): truncate freqs to at most T entries
|
|
|
+ freqs = freqs[:T]
|
|
|
|
|
|
# =========================
|
|
|
# Apply RoPE
|
|
|
# =========================
|
|
|
+
|
|
|
+ assert q.size(2) == k.size(2), f"QK mismatch: {q.shape} vs {k.shape}"
|
|
|
+ assert k.size(2) == freqs.size(0), f"RoPE mismatch: {k.shape} vs {freqs.shape}"
|
|
|
+
|
|
|
q = apply_rotary_emb(q, freqs)
|
|
|
k = apply_rotary_emb(k, freqs)
|
|
|
|
|
|
# =========================
|
|
|
# GQA expand
|
|
|
# =========================
|
|
|
- k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
- v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
+ if self.n_head != self.n_local_heads:
|
|
|
+ repeat = self.n_head // self.n_local_heads
|
|
|
+ k = k.repeat_interleave(repeat, dim=1)
|
|
|
+ v = v.repeat_interleave(repeat, dim=1)
|
|
|
|
|
|
# =========================
|
|
|
# Attention
|
|
|
# =========================
|
|
|
if self.use_sdpa:
|
|
|
if mask is None:
|
|
|
- with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
|
|
- y = F.scaled_dot_product_attention(
|
|
|
- q,
|
|
|
- k,
|
|
|
- v,
|
|
|
- dropout_p=self.dropout if self.training else 0.0,
|
|
|
- is_causal=True,
|
|
|
- )
|
|
|
+ y = F.scaled_dot_product_attention(
|
|
|
+ q,
|
|
|
+ k,
|
|
|
+ v,
|
|
|
+ dropout_p=self.dropout if self.training else 0.0,
|
|
|
+ is_causal=True,
|
|
|
+ )
|
|
|
else:
|
|
|
y = F.scaled_dot_product_attention(
|
|
|
q,
|
|
|
@@ -999,7 +1015,6 @@ class Attention(nn.Module):
|
|
|
y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
|
|
|
|
|
|
return self.wo(y)
|
|
|
-
|
|
|
def eq_scaled_dot_product_attention(
|
|
|
self,
|
|
|
query,
|