Browse Source

feat: 修改裁剪 freqs_cis 的逻辑

zhaohaipeng 1 month ago
parent
commit
7366912bf5
1 changed file with 44 additions and 29 deletions
  1. 44 29
      fish_speech/models/text2semantic/llama.py

+ 44 - 29
fish_speech/models/text2semantic/llama.py

@@ -888,11 +888,11 @@ class Attention(nn.Module):
 
     def forward(
             self,
-            x: Tensor,
-            freqs_cis: Tensor,
-            mask: Tensor,
-            input_pos: Optional[Tensor] = None,
-    ) -> Tensor:
+            x: torch.Tensor,
+            freqs_cis: torch.Tensor,
+            mask: Optional[torch.Tensor],
+            input_pos: Optional[torch.Tensor] = None,
+    ):
         bsz, seqlen, _ = x.shape
 
         q_size = self.n_head * self.head_dim
@@ -917,65 +917,81 @@ class Attention(nn.Module):
         # =========================
         # KV Cache + Sliding Window
         # =========================
+        start = 0
+        seq_len = seqlen
+
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
 
-            max_context = 4096  # 可调:4096 / 8192
+            max_context = 4096  # 可调
 
             if input_pos is not None and seqlen == 1:
-                # ===== decode 阶段 =====
-                # input_pos: [1]
                 seq_len = int(input_pos.item()) + 1
             else:
-                # ===== prefill 阶段 =====
-                seq_len = seqlen
+                seq_len = k.size(2)
 
             start = max(0, seq_len - max_context)
 
-            # window 裁剪
             k = k[:, :, start:seq_len, :]
             v = v[:, :, start:seq_len, :]
 
-            # mask 同步
             if mask is not None:
                 mask = mask[:, :, :, start:seq_len]
 
+        # ⭐ 当前真实长度(最重要)
+        T = k.size(2)
+
         # =========================
-        # RoPE(关键:分支处理)
+        # RoPE(严格对齐 window)
         # =========================
         if input_pos is not None and seqlen == 1:
-            # ===== decode =====
-            # 必须 index_select(避免 advanced indexing bug)
-            freqs = torch.index_select(freqs_cis, 0, input_pos.long())
+            # decode
+            pos = input_pos.long()
+
+            if self.kv_cache is not None:
+                pos = pos.clamp(min=start)
+
+            freqs = torch.index_select(freqs_cis, 0, pos)
         else:
-            # ===== prefill =====
-            freqs = freqs_cis[:seqlen]
+            # prefill
+            if self.kv_cache is not None:
+                freqs = freqs_cis[start:seq_len]
+            else:
+                freqs = freqs_cis[:seqlen]
+
+        # ⭐ 防御性对齐(关键)
+        freqs = freqs[:T]
 
         # =========================
         # Apply RoPE
         # =========================
+
+        assert q.size(2) == k.size(2), f"QK mismatch: {q.shape} vs {k.shape}"
+        assert k.size(2) == freqs.size(0), f"RoPE mismatch: {k.shape} vs {freqs.shape}"
+
         q = apply_rotary_emb(q, freqs)
         k = apply_rotary_emb(k, freqs)
 
         # =========================
         # GQA expand
         # =========================
-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        if self.n_head != self.n_local_heads:
+            repeat = self.n_head // self.n_local_heads
+            k = k.repeat_interleave(repeat, dim=1)
+            v = v.repeat_interleave(repeat, dim=1)
 
         # =========================
         # Attention
         # =========================
         if self.use_sdpa:
             if mask is None:
-                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
-                    y = F.scaled_dot_product_attention(
-                        q,
-                        k,
-                        v,
-                        dropout_p=self.dropout if self.training else 0.0,
-                        is_causal=True,
-                    )
+                y = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    dropout_p=self.dropout if self.training else 0.0,
+                    is_causal=True,
+                )
             else:
                 y = F.scaled_dot_product_attention(
                     q,
@@ -999,7 +1015,6 @@ class Attention(nn.Module):
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
 
         return self.wo(y)
-
     def eq_scaled_dot_product_attention(
         self,
         query,