Преглед изворног кода

feat: 修改裁剪 freqs_cis 的逻辑

zhaohaipeng пре 1 месец
родитељ
комит
884d8d2fd8
1 измењен фајл са 58 додатих и 40 уклоњених редова
  1. 58 40
      fish_speech/models/text2semantic/llama.py

+ 58 - 40
fish_speech/models/text2semantic/llama.py

@@ -887,16 +887,20 @@ class Attention(nn.Module):
             state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
 
     def forward(
-        self,
-        x: Tensor,
-        freqs_cis: Tensor,
-        mask: Tensor,
-        input_pos: Optional[Tensor] = None,
+            self,
+            x: Tensor,
+            freqs_cis: Tensor,
+            mask: Tensor,
+            input_pos: Optional[Tensor] = None,
     ) -> Tensor:
         bsz, seqlen, _ = x.shape
 
         q_size = self.n_head * self.head_dim
         kv_size = self.n_local_heads * self.head_dim
+
+        # =========================
+        # QKV projection
+        # =========================
         q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
 
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
@@ -907,49 +911,61 @@ class Attention(nn.Module):
             q = self.q_norm(q)
             k = self.k_norm(k)
 
-        # q = apply_rotary_emb(q, freqs_cis)
-        # k = apply_rotary_emb(k, freqs_cis)
-
+        # [B, H, T, D]
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
+        # =========================
+        # KV Cache + Sliding Window
+        # =========================
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
 
-            if input_pos is not None:
-                # =========================
-                # 🔥 KV cache window 裁剪(核心优化)
-                # =========================
-                max_context = 4096  # ⭐ 推荐 4K 或 8K
-
-                # 当前有效长度
-                seq_len = int(input_pos.max().item()) + 1
-
-                # window 起点
-                start = max(0, seq_len - max_context)
-
-                # 裁剪 KV
-                k = k[:, :, start:seq_len, :]
-                v = v[:, :, start:seq_len, :]
-
-                # =========================
-                # 🔥 同步裁剪 mask(如果有)
-                # =========================
-                if mask is not None:
-                    mask = mask[:, :, :, start:seq_len]
-
-                # =========================
-                # 🔥 同步裁剪 RoPE(关键,不然会炸)
-                # =========================
-                print(f"input_pos.dtype {input_pos.dtype}")
-                assert input_pos.dtype == torch.long
-                freqs_cis = torch.index_select(freqs_cis, 0, input_pos.long())
-
-        q = apply_rotary_emb(q, freqs_cis)
-        k = apply_rotary_emb(k, freqs_cis)
+            max_context = 4096  # ⭐ 可调:4096 / 8192
 
+            if input_pos is not None and seqlen == 1:
+                # ===== decode 阶段 =====
+                # input_pos: [1]
+                seq_len = int(input_pos.item()) + 1
+            else:
+                # ===== prefill 阶段 =====
+                seq_len = seqlen
+
+            start = max(0, seq_len - max_context)
+
+            # window 裁剪
+            k = k[:, :, start:seq_len, :]
+            v = v[:, :, start:seq_len, :]
+
+            # mask 同步
+            if mask is not None:
+                mask = mask[:, :, :, start:seq_len]
+
+        # =========================
+        # RoPE(关键:分支处理)
+        # =========================
+        if input_pos is not None and seqlen == 1:
+            # ===== decode =====
+            # 必须 index_select(避免 advanced indexing bug)
+            freqs = torch.index_select(freqs_cis, 0, input_pos.long())
+        else:
+            # ===== prefill =====
+            freqs = freqs_cis[:seqlen]
+
+        # =========================
+        # Apply RoPE
+        # =========================
+        q = apply_rotary_emb(q, freqs)
+        k = apply_rotary_emb(k, freqs)
+
+        # =========================
+        # GQA expand
+        # =========================
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
+        # =========================
+        # Attention
+        # =========================
         if self.use_sdpa:
             if mask is None:
                 with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
@@ -959,7 +975,6 @@ class Attention(nn.Module):
                         v,
                         dropout_p=self.dropout if self.training else 0.0,
                         is_causal=True,
-                        # No third party attn_mask here to use flash_attention
                     )
             else:
                 y = F.scaled_dot_product_attention(
@@ -978,6 +993,9 @@ class Attention(nn.Module):
                 dropout_p=self.dropout if self.training else 0.0,
             )
 
+        # =========================
+        # Output
+        # =========================
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
 
         return self.wo(y)