Просмотр исходного кода

feat:修改裁减freqs的逻辑

zhaohaipeng 1 месяц назад
Родитель
Сommit
cd3af01dab
1 измененных файлов с 25 добавлено и 26 удалено
  1. 25 26
      fish_speech/models/text2semantic/llama.py

+ 25 - 26
fish_speech/models/text2semantic/llama.py

@@ -918,61 +918,58 @@ class Attention(nn.Module):
         # KV Cache + Sliding Window
         # =========================
         start = 0
-        seq_len = seqlen
 
         if self.kv_cache is not None:
+            # update cache
             k, v = self.kv_cache.update(input_pos, k, v)
 
             max_context = 4096  # 可调
 
             if input_pos is not None and seqlen == 1:
+                # decode
                 seq_len = int(input_pos.item()) + 1
             else:
+                # prefill
                 seq_len = k.size(2)
 
             start = max(0, seq_len - max_context)
 
+            # window 裁剪
             k = k[:, :, start:seq_len, :]
             v = v[:, :, start:seq_len, :]
 
             if mask is not None:
                 mask = mask[:, :, :, start:seq_len]
-
-        # ⭐ 当前真实长度(最重要)
-        T = k.size(2)
+        else:
+            seq_len = seqlen
 
         # =========================
-        # RoPE(严格对齐 window
+        # RoPE(核心:以 K 为基准
         # =========================
-        if input_pos is not None and seqlen == 1:
-            # decode
-            pos = input_pos.long()
-
-            if self.kv_cache is not None:
-                pos = pos.clamp(min=start)
-
-            freqs = torch.index_select(freqs_cis, 0, pos)
-        else:
-            # prefill
-            if self.kv_cache is not None:
-                freqs = freqs_cis[start:seq_len]
-            else:
-                freqs = freqs_cis[:seqlen]
-
         T_k = k.size(2)
+
         if self.kv_cache is not None:
             freqs = freqs_cis[start: start + T_k]
         else:
             freqs = freqs_cis[:T_k]
 
         # =========================
-        # Apply RoPE
+        # Apply RoPE(Q/K 分开对齐)
         # =========================
+        T_q = q.size(2)
+
+        print("Q:", q.shape)
+        print("K:", k.shape)
+        print("freqs:", freqs.shape)
+        print("start:", start)
 
-        # Q 对应当前 token(长度 Tq)
-        q = apply_rotary_emb(q, freqs[-q.size(2):])
+        assert k.size(2) == freqs.size(0), f"K vs freqs mismatch"
+        assert q.size(2) <= k.size(2), f"Q longer than K??"
 
-        # K 对应整个 window(长度 Tk)
+        # Q 用最后 T_q 个位置
+        q = apply_rotary_emb(q, freqs[-T_q:])
+
+        # K 用完整 window
         k = apply_rotary_emb(k, freqs)
 
         # =========================
@@ -988,7 +985,7 @@ class Attention(nn.Module):
         # =========================
         if self.use_sdpa:
             if mask is None:
-                y = F.scaled_dot_product_attention(
+                y = torch.nn.functional.scaled_dot_product_attention(
                     q,
                     k,
                     v,
@@ -996,7 +993,7 @@ class Attention(nn.Module):
                     is_causal=True,
                 )
             else:
-                y = F.scaled_dot_product_attention(
+                y = torch.nn.functional.scaled_dot_product_attention(
                     q,
                     k,
                     v,
@@ -1018,6 +1015,8 @@ class Attention(nn.Module):
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
 
         return self.wo(y)
+
+
     def eq_scaled_dot_product_attention(
         self,
         query,