|
|
@@ -940,7 +940,9 @@ class Attention(nn.Module):
|
|
|
# =========================
|
|
|
# 🔥 同步裁剪 RoPE(关键,不然会炸)
|
|
|
# =========================
|
|
|
- freqs_cis = freqs_cis[input_pos]
|
|
|
+ print(f"input_pos.dtype {input_pos.dtype}")
|
|
|
+ assert input_pos.dtype == torch.long
|
|
|
+ freqs_cis = torch.index_select(freqs_cis, 0, input_pos.long())
|
|
|
|
|
|
q = apply_rotary_emb(q, freqs_cis)
|
|
|
k = apply_rotary_emb(k, freqs_cis)
|