Fix RoPE shape bug: build the rotated even and odd halves separately, then interleave them into a full rope_dim-wide tensor before writing it back

This commit is contained in:
2026-05-31 09:15:59 +00:00
parent 9d96c2fbbf
commit a2ee78b564
2 changed files with 41 additions and 19 deletions

View File

@@ -187,11 +187,17 @@ def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
nope = hd - rope_dim
cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) FP32
sin = sin_cache[positions].unsqueeze(1)
out = x.clone()
x_rope = x[:, :, nope:].float() # FP32 for accurate rotation
out[:, :, nope:] = (x_rope[..., 0::2] * cos - x_rope[..., 1::2] * sin).to(torch.bfloat16)
out[:, :, nope:][..., 1::2] = (x_rope[..., 0::2] * sin + x_rope[..., 1::2] * cos).to(torch.bfloat16)
return out
x_even = x_rope[..., 0::2]
x_odd = x_rope[..., 1::2]
rot_even = x_even * cos - x_odd * sin
rot_odd = x_even * sin + x_odd * cos
result = x.clone()
rope_out = torch.empty_like(x_rope)
rope_out[..., 0::2] = rot_even
rope_out[..., 1::2] = rot_odd
result[:, :, nope:] = rope_out.to(torch.bfloat16)
return result
def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
@@ -201,11 +207,17 @@ def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
nope = hd - rope_dim
cos = cos_cache[positions].unsqueeze(1)
sin = sin_cache[positions].unsqueeze(1)
out = o.clone()
o_rope = o[:, :, nope:].float()
out[:, :, nope:] = (o_rope[..., 0::2] * cos + o_rope[..., 1::2] * sin).to(torch.bfloat16)
out[:, :, nope:][..., 1::2] = (-o_rope[..., 0::2] * sin + o_rope[..., 1::2] * cos).to(torch.bfloat16)
return out
o_even = o_rope[..., 0::2]
o_odd = o_rope[..., 1::2]
inv_even = o_even * cos + o_odd * sin
inv_odd = -o_even * sin + o_odd * cos
result = o.clone()
rope_out = torch.empty_like(o_rope)
rope_out[..., 0::2] = inv_even
rope_out[..., 1::2] = inv_odd
result[:, :, nope:] = rope_out.to(torch.bfloat16)
return result
class SimpleKVCache:
"""Per-layer KV cache for decode. Stores BF16 K,V accumulated across steps.

View File

@@ -71,13 +71,17 @@ def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
nope = hd - rope_dim
cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) FP32
sin = sin_cache[positions].unsqueeze(1)
out = x.clone()
# Compute in FP32 for numerical stability
x_rope = x[:, :, nope:].float()
out[:, :, nope:] = (x_rope[..., 0::2] * cos - x_rope[..., 1::2] * sin).to(torch.bfloat16)
# Second pass for odd elements (need original even values)
out[:, :, nope:][..., 1::2] = (x_rope[..., 0::2] * sin + x_rope[..., 1::2] * cos).to(torch.bfloat16)
return out
x_rope = x[:, :, nope:].float() # (T, n_h, rope_dim)
x_even = x_rope[..., 0::2] # (T, n_h, half)
x_odd = x_rope[..., 1::2]
rot_even = x_even * cos - x_odd * sin
rot_odd = x_even * sin + x_odd * cos
result = x.clone()
rope_out = torch.empty_like(x_rope)
rope_out[..., 0::2] = rot_even
rope_out[..., 1::2] = rot_odd
result[:, :, nope:] = rope_out.to(torch.bfloat16)
return result
def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
"""Apply inverse RoPE (conjugate rotation). Computes in FP32 for accuracy."""
@@ -85,11 +89,17 @@ def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
nope = hd - rope_dim
cos = cos_cache[positions].unsqueeze(1)
sin = sin_cache[positions].unsqueeze(1)
out = o.clone()
o_rope = o[:, :, nope:].float()
out[:, :, nope:] = (o_rope[..., 0::2] * cos + o_rope[..., 1::2] * sin).to(torch.bfloat16)
out[:, :, nope:][..., 1::2] = (-o_rope[..., 0::2] * sin + o_rope[..., 1::2] * cos).to(torch.bfloat16)
return out
o_even = o_rope[..., 0::2]
o_odd = o_rope[..., 1::2]
inv_even = o_even * cos + o_odd * sin
inv_odd = -o_even * sin + o_odd * cos
result = o.clone()
rope_out = torch.empty_like(o_rope)
rope_out[..., 0::2] = inv_even
rope_out[..., 1::2] = inv_odd
result[:, :, nope:] = rope_out.to(torch.bfloat16)
return result
def load_weights_to_cpu(checkpoint_dir):
from safetensors.torch import load_file