|
|
@@ -918,61 +918,58 @@ class Attention(nn.Module):
|
|
|
# KV Cache + Sliding Window
|
|
|
# =========================
|
|
|
start = 0
|
|
|
- seq_len = seqlen
|
|
|
|
|
|
if self.kv_cache is not None:
|
|
|
+ # update cache
|
|
|
k, v = self.kv_cache.update(input_pos, k, v)
|
|
|
|
|
|
max_context = 4096 # 可调
|
|
|
|
|
|
if input_pos is not None and seqlen == 1:
|
|
|
+ # decode
|
|
|
seq_len = int(input_pos.item()) + 1
|
|
|
else:
|
|
|
+ # prefill
|
|
|
seq_len = k.size(2)
|
|
|
|
|
|
start = max(0, seq_len - max_context)
|
|
|
|
|
|
+ # crop K/V to the window
|
|
|
k = k[:, :, start:seq_len, :]
|
|
|
v = v[:, :, start:seq_len, :]
|
|
|
|
|
|
if mask is not None:
|
|
|
mask = mask[:, :, :, start:seq_len]
|
|
|
-
|
|
|
- # ⭐ 当前真实长度(最重要)
|
|
|
- T = k.size(2)
|
|
|
+ else:
|
|
|
+ seq_len = seqlen
|
|
|
|
|
|
# =========================
|
|
|
- # RoPE(严格对齐 window)
|
|
|
+ # RoPE (core idea: use K as the reference)
|
|
|
# =========================
|
|
|
- if input_pos is not None and seqlen == 1:
|
|
|
- # decode
|
|
|
- pos = input_pos.long()
|
|
|
-
|
|
|
- if self.kv_cache is not None:
|
|
|
- pos = pos.clamp(min=start)
|
|
|
-
|
|
|
- freqs = torch.index_select(freqs_cis, 0, pos)
|
|
|
- else:
|
|
|
- # prefill
|
|
|
- if self.kv_cache is not None:
|
|
|
- freqs = freqs_cis[start:seq_len]
|
|
|
- else:
|
|
|
- freqs = freqs_cis[:seqlen]
|
|
|
-
|
|
|
T_k = k.size(2)
|
|
|
+
|
|
|
if self.kv_cache is not None:
|
|
|
freqs = freqs_cis[start: start + T_k]
|
|
|
else:
|
|
|
freqs = freqs_cis[:T_k]
|
|
|
|
|
|
# =========================
|
|
|
- # Apply RoPE
|
|
|
+ # Apply RoPE (Q and K aligned separately)
|
|
|
# =========================
|
|
|
+ T_q = q.size(2)
|
|
|
+
|
|
|
+ print("Q:", q.shape)
|
|
|
+ print("K:", k.shape)
|
|
|
+ print("freqs:", freqs.shape)
|
|
|
+ print("start:", start)
|
|
|
|
|
|
- # Q 对应当前 token(长度 Tq)
|
|
|
- q = apply_rotary_emb(q, freqs[-q.size(2):])
|
|
|
+ assert k.size(2) == freqs.size(0), f"K vs freqs mismatch"
|
|
|
+ assert q.size(2) <= k.size(2), f"Q longer than K??"
|
|
|
|
|
|
- # K 对应整个 window(长度 Tk)
|
|
|
+ # Q uses the last T_q positions
|
|
|
+ q = apply_rotary_emb(q, freqs[-T_q:])
|
|
|
+
|
|
|
+ # K uses the full window
|
|
|
k = apply_rotary_emb(k, freqs)
|
|
|
|
|
|
# =========================
|
|
|
@@ -988,7 +985,7 @@ class Attention(nn.Module):
|
|
|
# =========================
|
|
|
if self.use_sdpa:
|
|
|
if mask is None:
|
|
|
- y = F.scaled_dot_product_attention(
|
|
|
+ y = torch.nn.functional.scaled_dot_product_attention(
|
|
|
q,
|
|
|
k,
|
|
|
v,
|
|
|
@@ -996,7 +993,7 @@ class Attention(nn.Module):
|
|
|
is_causal=True,
|
|
|
)
|
|
|
else:
|
|
|
- y = F.scaled_dot_product_attention(
|
|
|
+ y = torch.nn.functional.scaled_dot_product_attention(
|
|
|
q,
|
|
|
k,
|
|
|
v,
|
|
|
@@ -1018,6 +1015,8 @@ class Attention(nn.Module):
|
|
|
y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
|
|
|
|
|
|
return self.wo(y)
|
|
|
+
|
|
|
+
|
|
|
def eq_scaled_dot_product_attention(
|
|
|
self,
|
|
|
query,
|