Sfoglia il codice sorgente

feat:修改裁减freqs_cis的逻辑

zhaohaipeng 1 mese fa
parent
commit
7cfc69b7fa

+ 1 - 1
fish_speech/models/dac/modded_dac.py

@@ -293,7 +293,7 @@ class Attention(nn.Module):
                 # =========================
                 # 🔥 同步裁剪 RoPE(关键,不然会炸)
                 # =========================
-                freqs_cis = freqs_cis[start:seq_len]
+                freqs_cis = freqs_cis[input_pos]
 
         if self.pos_embed_type == "rope":
             q = apply_rotary_emb(q, freqs_cis)

+ 1 - 1
fish_speech/models/text2semantic/llama.py

@@ -940,7 +940,7 @@ class Attention(nn.Module):
                 # =========================
                 # 🔥 同步裁剪 RoPE(关键,不然会炸)
                 # =========================
-                freqs_cis = freqs_cis[start:seq_len]
+                freqs_cis = freqs_cis[input_pos]
 
         q = apply_rotary_emb(q, freqs_cis)
         k = apply_rotary_emb(k, freqs_cis)