Fix inverse RoPE formula: swap signs on cross terms

This commit is contained in:
2026-05-19 03:22:10 +00:00
parent fece06f746
commit 4ed91b81d0

View File

@@ -583,9 +583,9 @@ def _apply_inv_rope_bf16(
rope = o_f32[..., nope_dim:]
y_even = rope[..., 0::2]
y_odd = rope[..., 1::2]
# Inverse: sin → -sin
# Inverse: sin → -sin (swap signs on the cross terms)
rope_out = torch.stack(
(y_even * cos - y_odd * sin, y_odd * cos + y_even * sin),
(y_even * cos + y_odd * sin, y_odd * cos - y_even * sin),
dim=-1,
).flatten(-2)
o_f32 = o_f32.clone()