Procházet zdrojové kódy

feat:修改裁剪freqs_cis的逻辑

zhaohaipeng před 1 měsícem
rodič
revize
99ce2e862d
1 změnil soubory, kde provedl 8 přidání a 5 odebrání
  1. 8 5
      fish_speech/models/text2semantic/llama.py

+ 8 - 5
fish_speech/models/text2semantic/llama.py

@@ -959,17 +959,20 @@ class Attention(nn.Module):
             else:
                 freqs = freqs_cis[:seqlen]
 
-        # ⭐ 防御性对齐(关键)
-        freqs = freqs[:T]
+        T_k = k.size(2)
+        if self.kv_cache is not None:
+            freqs = freqs_cis[start: start + T_k]
+        else:
+            freqs = freqs_cis[:T_k]
 
         # =========================
         # Apply RoPE
         # =========================
 
-        assert q.size(2) == k.size(2), f"QK mismatch: {q.shape} vs {k.shape}"
-        assert k.size(2) == freqs.size(0), f"RoPE mismatch: {k.shape} vs {freqs.shape}"
+        # Q 对应当前 token(长度 Tq)
+        q = apply_rotary_emb(q, freqs[-q.size(2):])
 
-        q = apply_rotary_emb(q, freqs)
+        # K 对应整个 window(长度 Tk)
         k = apply_rotary_emb(k, freqs)
 
         # =========================