|
|
@@ -887,16 +887,20 @@ class Attention(nn.Module):
|
|
|
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
|
|
|
|
|
|
def forward(
|
|
|
- self,
|
|
|
- x: Tensor,
|
|
|
- freqs_cis: Tensor,
|
|
|
- mask: Tensor,
|
|
|
- input_pos: Optional[Tensor] = None,
|
|
|
+ self,
|
|
|
+ x: Tensor,
|
|
|
+ freqs_cis: Tensor,
|
|
|
+ mask: Tensor,
|
|
|
+ input_pos: Optional[Tensor] = None,
|
|
|
) -> Tensor:
|
|
|
bsz, seqlen, _ = x.shape
|
|
|
|
|
|
q_size = self.n_head * self.head_dim
|
|
|
kv_size = self.n_local_heads * self.head_dim
|
|
|
+
|
|
|
+ # =========================
|
|
|
+ # QKV projection
|
|
|
+ # =========================
|
|
|
q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
|
|
|
|
|
|
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
|
|
@@ -907,49 +911,61 @@ class Attention(nn.Module):
|
|
|
q = self.q_norm(q)
|
|
|
k = self.k_norm(k)
|
|
|
|
|
|
- # q = apply_rotary_emb(q, freqs_cis)
|
|
|
- # k = apply_rotary_emb(k, freqs_cis)
|
|
|
-
|
|
|
+ # [B, H, T, D]
|
|
|
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
|
|
|
|
|
+ # =========================
|
|
|
+ # KV Cache + Sliding Window
|
|
|
+ # =========================
|
|
|
if self.kv_cache is not None:
|
|
|
k, v = self.kv_cache.update(input_pos, k, v)
|
|
|
|
|
|
- if input_pos is not None:
|
|
|
- # =========================
|
|
|
- # 🔥 KV cache window 裁剪(核心优化)
|
|
|
- # =========================
|
|
|
- max_context = 4096 # ⭐ 推荐 4K 或 8K
|
|
|
-
|
|
|
- # 当前有效长度
|
|
|
- seq_len = int(input_pos.max().item()) + 1
|
|
|
-
|
|
|
- # window 起点
|
|
|
- start = max(0, seq_len - max_context)
|
|
|
-
|
|
|
- # 裁剪 KV
|
|
|
- k = k[:, :, start:seq_len, :]
|
|
|
- v = v[:, :, start:seq_len, :]
|
|
|
-
|
|
|
- # =========================
|
|
|
- # 🔥 同步裁剪 mask(如果有)
|
|
|
- # =========================
|
|
|
- if mask is not None:
|
|
|
- mask = mask[:, :, :, start:seq_len]
|
|
|
-
|
|
|
- # =========================
|
|
|
- # 🔥 同步裁剪 RoPE(关键,不然会炸)
|
|
|
- # =========================
|
|
|
- print(f"input_pos.dtype {input_pos.dtype}")
|
|
|
- assert input_pos.dtype == torch.long
|
|
|
- freqs_cis = torch.index_select(freqs_cis, 0, input_pos.long())
|
|
|
-
|
|
|
- q = apply_rotary_emb(q, freqs_cis)
|
|
|
- k = apply_rotary_emb(k, freqs_cis)
|
|
|
+ max_context = 4096 # ⭐ 可调:4096 / 8192
|
|
|
|
|
|
+ if input_pos is not None and seqlen == 1:
|
|
|
+ # ===== decode 阶段 =====
|
|
|
+ # input_pos: [1]
|
|
|
+ seq_len = int(input_pos.item()) + 1
|
|
|
+ else:
|
|
|
+ # ===== prefill 阶段 =====
|
|
|
+ seq_len = seqlen
|
|
|
+
|
|
|
+ start = max(0, seq_len - max_context)
|
|
|
+
|
|
|
+ # window 裁剪
|
|
|
+ k = k[:, :, start:seq_len, :]
|
|
|
+ v = v[:, :, start:seq_len, :]
|
|
|
+
|
|
|
+ # mask 同步
|
|
|
+ if mask is not None:
|
|
|
+ mask = mask[:, :, :, start:seq_len]
|
|
|
+
|
|
|
+ # =========================
|
|
|
+ # RoPE(关键:分支处理)
|
|
|
+ # =========================
|
|
|
+ if input_pos is not None and seqlen == 1:
|
|
|
+ # ===== decode =====
|
|
|
+ # 必须 index_select(避免 advanced indexing bug)
|
|
|
+ freqs = torch.index_select(freqs_cis, 0, input_pos.long())
|
|
|
+ else:
|
|
|
+ # ===== prefill =====
|
|
|
+ freqs = freqs_cis[:seqlen]
|
|
|
+
|
|
|
+ # =========================
|
|
|
+ # Apply RoPE
|
|
|
+ # =========================
|
|
|
+ q = apply_rotary_emb(q, freqs)
|
|
|
+ k = apply_rotary_emb(k, freqs)
|
|
|
+
|
|
|
+ # =========================
|
|
|
+ # GQA expand
|
|
|
+ # =========================
|
|
|
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
|
|
|
|
|
+ # =========================
|
|
|
+ # Attention
|
|
|
+ # =========================
|
|
|
if self.use_sdpa:
|
|
|
if mask is None:
|
|
|
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
|
|
@@ -959,7 +975,6 @@ class Attention(nn.Module):
|
|
|
v,
|
|
|
dropout_p=self.dropout if self.training else 0.0,
|
|
|
is_causal=True,
|
|
|
- # No third party attn_mask here to use flash_attention
|
|
|
)
|
|
|
else:
|
|
|
y = F.scaled_dot_product_attention(
|
|
|
@@ -978,6 +993,9 @@ class Attention(nn.Module):
|
|
|
dropout_p=self.dropout if self.training else 0.0,
|
|
|
)
|
|
|
|
|
|
+ # =========================
|
|
|
+ # Output
|
|
|
+ # =========================
|
|
|
y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
|
|
|
|
|
|
return self.wo(y)
|