|
|
@@ -959,17 +959,20 @@ class Attention(nn.Module):
|
|
|
else:
|
|
|
freqs = freqs_cis[:seqlen]
|
|
|
|
|
|
- # ⭐ 防御性对齐(关键)
|
|
|
- freqs = freqs[:T]
|
|
|
+ T_k = k.size(2)
|
|
|
+ if self.kv_cache is not None:
|
|
|
+ freqs = freqs_cis[start: start + T_k]
|
|
|
+ else:
|
|
|
+ freqs = freqs_cis[:T_k]
|
|
|
|
|
|
# =========================
|
|
|
# Apply RoPE
|
|
|
# =========================
|
|
|
|
|
|
- assert q.size(2) == k.size(2), f"QK mismatch: {q.shape} vs {k.shape}"
|
|
|
- assert k.size(2) == freqs.size(0), f"RoPE mismatch: {k.shape} vs {freqs.shape}"
|
|
|
+ # Q 对应当前 token(长度 Tq)
|
|
|
+ q = apply_rotary_emb(q, freqs[-q.size(2):])
|
|
|
|
|
|
- q = apply_rotary_emb(q, freqs)
|
|
|
+ # K 对应整个 window(长度 Tk)
|
|
|
k = apply_rotary_emb(k, freqs)
|
|
|
|
|
|
# =========================
|