Fix inverse RoPE formula: swap signs on cross terms
This commit is contained in:
@@ -583,9 +583,9 @@ def _apply_inv_rope_bf16(
|
||||
rope = o_f32[..., nope_dim:]
|
||||
y_even = rope[..., 0::2]
|
||||
y_odd = rope[..., 1::2]
|
||||
# Inverse: sin → -sin
|
||||
# Inverse: sin → -sin (swap signs on the cross terms)
|
||||
rope_out = torch.stack(
|
||||
(y_even * cos - y_odd * sin, y_odd * cos + y_even * sin),
|
||||
(y_even * cos + y_odd * sin, y_odd * cos - y_even * sin),
|
||||
dim=-1,
|
||||
).flatten(-2)
|
||||
o_f32 = o_f32.clone()
|
||||
|
||||
Reference in New Issue
Block a user