Przeglądaj źródła

feat: 修改裁剪 freqs_cis 的逻辑

zhaohaipeng 1 miesiąc temu
rodzic
commit
2d5f335ef1

+ 3 - 1
fish_speech/models/dac/modded_dac.py

@@ -293,7 +293,9 @@ class Attention(nn.Module):
                 # =========================
                 # 🔥 同步裁剪 RoPE(关键,不然会炸)
                 # =========================
-                freqs_cis = freqs_cis[input_pos]
+                print(f"input_pos.dtype {input_pos.dtype}")
+                assert input_pos.dtype == torch.long
+                freqs_cis = torch.index_select(freqs_cis, 0, input_pos.long())
 
         if self.pos_embed_type == "rope":
             q = apply_rotary_emb(q, freqs_cis)

+ 3 - 1
fish_speech/models/text2semantic/llama.py

@@ -940,7 +940,9 @@ class Attention(nn.Module):
                 # =========================
                 # 🔥 同步裁剪 RoPE(关键,不然会炸)
                 # =========================
-                freqs_cis = freqs_cis[input_pos]
+                print(f"input_pos.dtype {input_pos.dtype}")
+                assert input_pos.dtype == torch.long
+                freqs_cis = torch.index_select(freqs_cis, 0, input_pos.long())
 
         q = apply_rotary_emb(q, freqs_cis)
         k = apply_rotary_emb(k, freqs_cis)