Fix RoPE shape bug (interleave needs separate even/odd assembly)
This commit is contained in:
@@ -187,11 +187,17 @@ def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
|
||||
nope = hd - rope_dim
|
||||
cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) FP32
|
||||
sin = sin_cache[positions].unsqueeze(1)
|
||||
out = x.clone()
|
||||
x_rope = x[:, :, nope:].float() # FP32 for accurate rotation
|
||||
out[:, :, nope:] = (x_rope[..., 0::2] * cos - x_rope[..., 1::2] * sin).to(torch.bfloat16)
|
||||
out[:, :, nope:][..., 1::2] = (x_rope[..., 0::2] * sin + x_rope[..., 1::2] * cos).to(torch.bfloat16)
|
||||
return out
|
||||
x_even = x_rope[..., 0::2]
|
||||
x_odd = x_rope[..., 1::2]
|
||||
rot_even = x_even * cos - x_odd * sin
|
||||
rot_odd = x_even * sin + x_odd * cos
|
||||
result = x.clone()
|
||||
rope_out = torch.empty_like(x_rope)
|
||||
rope_out[..., 0::2] = rot_even
|
||||
rope_out[..., 1::2] = rot_odd
|
||||
result[:, :, nope:] = rope_out.to(torch.bfloat16)
|
||||
return result
|
||||
|
||||
|
||||
def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
|
||||
@@ -201,11 +207,17 @@ def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
|
||||
nope = hd - rope_dim
|
||||
cos = cos_cache[positions].unsqueeze(1)
|
||||
sin = sin_cache[positions].unsqueeze(1)
|
||||
out = o.clone()
|
||||
o_rope = o[:, :, nope:].float()
|
||||
out[:, :, nope:] = (o_rope[..., 0::2] * cos + o_rope[..., 1::2] * sin).to(torch.bfloat16)
|
||||
out[:, :, nope:][..., 1::2] = (-o_rope[..., 0::2] * sin + o_rope[..., 1::2] * cos).to(torch.bfloat16)
|
||||
return out
|
||||
o_even = o_rope[..., 0::2]
|
||||
o_odd = o_rope[..., 1::2]
|
||||
inv_even = o_even * cos + o_odd * sin
|
||||
inv_odd = -o_even * sin + o_odd * cos
|
||||
result = o.clone()
|
||||
rope_out = torch.empty_like(o_rope)
|
||||
rope_out[..., 0::2] = inv_even
|
||||
rope_out[..., 1::2] = inv_odd
|
||||
result[:, :, nope:] = rope_out.to(torch.bfloat16)
|
||||
return result
|
||||
|
||||
class SimpleKVCache:
|
||||
"""Per-layer KV cache for decode. Stores BF16 K,V accumulated across steps.
|
||||
|
||||
@@ -71,13 +71,17 @@ def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
|
||||
nope = hd - rope_dim
|
||||
cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) FP32
|
||||
sin = sin_cache[positions].unsqueeze(1)
|
||||
out = x.clone()
|
||||
# Compute in FP32 for numerical stability
|
||||
x_rope = x[:, :, nope:].float()
|
||||
out[:, :, nope:] = (x_rope[..., 0::2] * cos - x_rope[..., 1::2] * sin).to(torch.bfloat16)
|
||||
# Second pass for odd elements (need original even values)
|
||||
out[:, :, nope:][..., 1::2] = (x_rope[..., 0::2] * sin + x_rope[..., 1::2] * cos).to(torch.bfloat16)
|
||||
return out
|
||||
x_rope = x[:, :, nope:].float() # (T, n_h, rope_dim)
|
||||
x_even = x_rope[..., 0::2] # (T, n_h, half)
|
||||
x_odd = x_rope[..., 1::2]
|
||||
rot_even = x_even * cos - x_odd * sin
|
||||
rot_odd = x_even * sin + x_odd * cos
|
||||
result = x.clone()
|
||||
rope_out = torch.empty_like(x_rope)
|
||||
rope_out[..., 0::2] = rot_even
|
||||
rope_out[..., 1::2] = rot_odd
|
||||
result[:, :, nope:] = rope_out.to(torch.bfloat16)
|
||||
return result
|
||||
|
||||
def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
|
||||
"""Apply inverse RoPE (conjugate rotation). Computes in FP32 for accuracy."""
|
||||
@@ -85,11 +89,17 @@ def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
|
||||
nope = hd - rope_dim
|
||||
cos = cos_cache[positions].unsqueeze(1)
|
||||
sin = sin_cache[positions].unsqueeze(1)
|
||||
out = o.clone()
|
||||
o_rope = o[:, :, nope:].float()
|
||||
out[:, :, nope:] = (o_rope[..., 0::2] * cos + o_rope[..., 1::2] * sin).to(torch.bfloat16)
|
||||
out[:, :, nope:][..., 1::2] = (-o_rope[..., 0::2] * sin + o_rope[..., 1::2] * cos).to(torch.bfloat16)
|
||||
return out
|
||||
o_even = o_rope[..., 0::2]
|
||||
o_odd = o_rope[..., 1::2]
|
||||
inv_even = o_even * cos + o_odd * sin
|
||||
inv_odd = -o_even * sin + o_odd * cos
|
||||
result = o.clone()
|
||||
rope_out = torch.empty_like(o_rope)
|
||||
rope_out[..., 0::2] = inv_even
|
||||
rope_out[..., 1::2] = inv_odd
|
||||
result[:, :, nope:] = rope_out.to(torch.bfloat16)
|
||||
return result
|
||||
|
||||
def load_weights_to_cpu(checkpoint_dir):
|
||||
from safetensors.torch import load_file
|
||||
|
||||
Reference in New Issue
Block a user