Преглед изворни кода

feat:修改裁减freqs_cis的逻辑

zhaohaipeng пре 1 месец
родитељ
комит
7cfc69b7fa
2 измењених фајлова са 2 додато и 2 уклоњено
  1. 1 1
      fish_speech/models/dac/modded_dac.py
  2. 1 1
      fish_speech/models/text2semantic/llama.py

+ 1 - 1
fish_speech/models/dac/modded_dac.py

@@ -293,7 +293,7 @@ class Attention(nn.Module):
                 # =========================
                 # 🔥 同步裁剪 RoPE(关键,不然会炸)
                 # =========================
-                freqs_cis = freqs_cis[start:seq_len]
+                freqs_cis = freqs_cis[input_pos]
 
         if self.pos_embed_type == "rope":
             q = apply_rotary_emb(q, freqs_cis)

+ 1 - 1
fish_speech/models/text2semantic/llama.py

@@ -940,7 +940,7 @@ class Attention(nn.Module):
                 # =========================
                 # 🔥 同步裁剪 RoPE(关键,不然会炸)
                 # =========================
-                freqs_cis = freqs_cis[start:seq_len]
+                freqs_cis = freqs_cis[input_pos]
 
         q = apply_rotary_emb(q, freqs_cis)
         k = apply_rotary_emb(k, freqs_cis)