127 lines
4.6 KiB
Python
127 lines
4.6 KiB
Python
"""Test _apply_inv_rope_bf16: inverse RoPE should undo forward RoPE."""
|
|
import torch
|
|
import math
|
|
|
|
def apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim):
    """Forward GPT-J style (interleaved) RoPE on the rotary slice of ``x``.

    The last dimension of ``x`` is treated as ``[nope | rope | rest]``.

    * The first ``nope_dim`` features are copied through unchanged.
    * Any features after ``nope_dim + rope_dim`` are copied through unchanged.
    * The ``rope_dim`` features in between are rotated pairwise. Each pair
      ``(x[2i], x[2i + 1])`` is rotated by the angle stored for that token's
      position: ``(e, o) -> (e*cos - o*sin, o*cos + e*sin)``.

    The rotation is computed in float32 and the result is cast back to
    ``x.dtype``.

    Args:
        x: Tensor whose first dimension indexes tokens (one per entry of
            ``positions``) and whose last dimension is the feature dimension,
            e.g. ``(num_tokens, num_heads, head_dim)`` or
            ``(num_tokens, head_dim)``.
        positions: 1-D integer tensor; ``positions[t]`` is the row of
            ``cos_sin_cache`` used for token ``t``.
        cos_sin_cache: 2-D table indexed by position. Columns
            ``[0, rope_dim // 2)`` hold cos values and columns
            ``[rope_dim // 2, rope_dim)`` hold sin values.
        nope_dim: Number of leading features left un-rotated.
        rope_dim: Number of features to rotate; must be even.

    Returns:
        A new tensor with the shape and dtype of ``x``. When ``rope_dim == 0``
        or ``x`` is empty, ``x`` itself is returned.

    Raises:
        ValueError: If ``rope_dim`` is odd, or if ``nope_dim + rope_dim``
            exceeds the size of the last dimension of ``x``.
    """
    if rope_dim == 0 or x.numel() == 0:
        return x
    if rope_dim % 2:
        raise ValueError(f"rope_dim must be even, got {rope_dim}")
    rope_end = nope_dim + rope_dim
    if rope_end > x.shape[-1]:
        raise ValueError(
            f"nope_dim + rope_dim ({rope_end}) exceeds last dim of x ({x.shape[-1]})"
        )
    half_rot = rope_dim // 2

    x_f32 = x.to(torch.float32)
    cache = cos_sin_cache.index_select(0, positions.to(torch.long))

    # One (cos, sin) row per token, broadcast over every dim between the
    # token dim and the feature dim. Equals (T, 1, half_rot) for rank-3 x.
    view_shape = (positions.shape[0],) + (1,) * (x.dim() - 2) + (half_rot,)
    cos = cache[:, :half_rot].to(torch.float32).view(view_shape)
    sin = cache[:, half_rot:rope_dim].to(torch.float32).view(view_shape)

    # Slice exactly the rotary span so trailing features (if any) pass through.
    rope = x_f32[..., nope_dim:rope_end]
    y_even = rope[..., 0::2]
    y_odd = rope[..., 1::2]
    # Re-interleave the rotated pairs back into [e0, o0, e1, o1, ...] order.
    rope_out = torch.stack(
        (y_even * cos - y_odd * sin, y_odd * cos + y_even * sin),
        dim=-1,
    ).flatten(-2)

    # x.to(float32) returns x itself when x is already float32, so clone
    # before the in-place slice write to avoid mutating the caller's tensor.
    out = x_f32.clone()
    out[..., nope_dim:rope_end] = rope_out
    return out.to(x.dtype)
|
|
|
|
def apply_inv_rope_bf16(o, positions, cos_sin_cache, nope_dim, rope_dim):
    """Inverse GPT-J style (interleaved) RoPE on the rotary slice of ``o``.

    This is the forward rotation with ``sin -> -sin``, i.e. a rotation by the
    negative angle: ``(e, o) -> (e*cos + o*sin, o*cos - e*sin)``.

    Applied to a tensor that was forward-rotated with the same ``positions``,
    cache, ``nope_dim`` and ``rope_dim``, it recovers the original values up
    to dtype rounding.

    The last dimension of ``o`` is treated as ``[nope | rope | rest]``.

    * The first ``nope_dim`` features are copied through unchanged.
    * Any features after ``nope_dim + rope_dim`` are copied through unchanged.
    * Only the ``rope_dim`` features in between are rotated.

    The math is done in float32 and cast back to ``o.dtype``.

    Args:
        o: Tensor whose first dimension indexes tokens (one per entry of
            ``positions``) and whose last dimension is the feature dimension,
            e.g. ``(num_tokens, num_heads, head_dim)`` or
            ``(num_tokens, head_dim)``.
        positions: 1-D integer tensor; ``positions[t]`` is the row of
            ``cos_sin_cache`` used for token ``t``.
        cos_sin_cache: 2-D table indexed by position. Columns
            ``[0, rope_dim // 2)`` hold cos values and columns
            ``[rope_dim // 2, rope_dim)`` hold sin values.
        nope_dim: Number of leading features left un-rotated.
        rope_dim: Number of features to un-rotate; must be even.

    Returns:
        A new tensor with the shape and dtype of ``o``. When ``rope_dim == 0``
        or ``o`` is empty, ``o`` itself is returned.

    Raises:
        ValueError: If ``rope_dim`` is odd, or if ``nope_dim + rope_dim``
            exceeds the size of the last dimension of ``o``.
    """
    if rope_dim == 0 or o.numel() == 0:
        return o
    if rope_dim % 2:
        raise ValueError(f"rope_dim must be even, got {rope_dim}")
    rope_end = nope_dim + rope_dim
    if rope_end > o.shape[-1]:
        raise ValueError(
            f"nope_dim + rope_dim ({rope_end}) exceeds last dim of o ({o.shape[-1]})"
        )
    half_rot = rope_dim // 2

    o_f32 = o.to(torch.float32)
    cache = cos_sin_cache.index_select(0, positions.to(torch.long))

    # One (cos, sin) row per token, broadcast over every dim between the
    # token dim and the feature dim. Equals (T, 1, half_rot) for rank-3 o.
    view_shape = (positions.shape[0],) + (1,) * (o.dim() - 2) + (half_rot,)
    cos = cache[:, :half_rot].to(torch.float32).view(view_shape)
    sin = cache[:, half_rot:rope_dim].to(torch.float32).view(view_shape)

    # Slice exactly the rotary span so trailing features (if any) pass through.
    rope = o_f32[..., nope_dim:rope_end]
    y_even = rope[..., 0::2]
    y_odd = rope[..., 1::2]
    # Rotation by -theta; re-interleave into [e0, o0, e1, o1, ...] order.
    rope_out = torch.stack(
        (y_even * cos + y_odd * sin, y_odd * cos - y_even * sin),
        dim=-1,
    ).flatten(-2)

    # o.to(float32) returns o itself when o is already float32, so clone
    # before the in-place slice write to avoid mutating the caller's tensor.
    out = o_f32.clone()
    out[..., nope_dim:rope_end] = rope_out
    return out.to(o.dtype)
|
|
|
|
|
|
def test_inv_rope_roundtrip():
    """Forward RoPE then inverse RoPE should be identity.

    The inputs are BF16 and both RoPE helpers cast their float32 result back
    to the input dtype. The check therefore allows an absolute error of
    0.05 for the two BF16 roundings.

    It also asserts that the forward pass actually changed the rotary dims,
    so that no-op implementations cannot pass the roundtrip check.
    """
    torch.manual_seed(42)
    num_tokens = 8
    num_heads = 16
    head_dim = 512
    nope_dim = 448
    rope_dim = 64
    max_pos = 1024

    # Build a full cos/sin cache (like RotaryEmbedding): one row per position,
    # cos values in the first half_rot columns and sin values in the next
    # half_rot columns.
    half_rot = rope_dim // 2
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half_rot, dtype=torch.float32) / half_rot))
    # Sample positions before x so the RNG draw order matches the rest of the test.
    positions = torch.randint(0, max_pos, (num_tokens,))
    all_positions = torch.arange(max_pos, dtype=torch.float32)
    freqs = all_positions.unsqueeze(1) * inv_freq.unsqueeze(0)
    cos_sin_cache = torch.cat([torch.cos(freqs), torch.sin(freqs)], dim=1)

    x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)

    # Forward RoPE
    x_rope = apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)

    # Guard against a vacuous pass: the rotary dims must actually move.
    rope_change = (
        (x_rope.to(torch.float32) - x.to(torch.float32))[..., nope_dim:].abs().max().item()
    )
    assert rope_change > 0.1, f"Forward RoPE did not change rotary dims: {rope_change}"

    # Inverse RoPE
    x_recovered = apply_inv_rope_bf16(x_rope, positions, cos_sin_cache, nope_dim, rope_dim)

    # Should be identity (within BF16 precision)
    diff = (x.to(torch.float32) - x_recovered.to(torch.float32)).abs().max().item()
    print(f"Max abs diff: {diff:.6f}")
    assert diff < 0.05, f"Roundtrip error too large: {diff}"
    print("PASS: inverse RoPE roundtrip within tolerance")
|
|
|
|
|
|
def test_nope_dim_unchanged():
    """NoPE dimensions should be unchanged by inverse RoPE.

    The NoPE slice must come back bit-for-bit identical. The test also checks
    that the rotary slice did change, so that an identity implementation
    cannot pass.
    """
    torch.manual_seed(42)
    num_tokens = 4
    num_heads = 4
    head_dim = 128
    nope_dim = 96
    rope_dim = 32
    max_pos = 512

    # Full cos/sin cache: one row per position, cos values in the first
    # half_rot columns and sin values in the next half_rot columns.
    half_rot = rope_dim // 2
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half_rot, dtype=torch.float32) / half_rot))
    # Sample positions before x so the RNG draw order matches the rest of the test.
    positions = torch.randint(0, max_pos, (num_tokens,))
    all_positions = torch.arange(max_pos, dtype=torch.float32)
    freqs = all_positions.unsqueeze(1) * inv_freq.unsqueeze(0)
    cos_sin_cache = torch.cat([torch.cos(freqs), torch.sin(freqs)], dim=1)

    x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)
    x_inv = apply_inv_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)

    # NoPE dims should be unchanged (exact, not approximate).
    nope_diff = (
        (x[..., :nope_dim].to(torch.float32) - x_inv[..., :nope_dim].to(torch.float32))
        .abs()
        .max()
        .item()
    )
    print(f"NoPE max diff: {nope_diff:.6f}")
    assert torch.equal(x[..., :nope_dim], x_inv[..., :nope_dim]), "NoPE dimensions should be unchanged"

    # Sanity check: the rotary slice must actually have been transformed.
    rope_change = (
        (x_inv.to(torch.float32) - x.to(torch.float32))[..., nope_dim:].abs().max().item()
    )
    assert rope_change > 0.1, f"Inverse RoPE did not change rotary dims: {rope_change}"
    print("PASS: NoPE dimensions unchanged")
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running the checks directly with `python <file>` instead of pytest;
    # they execute in the same order as before.
    for check in (test_inv_rope_roundtrip, test_nope_dim_unchanged):
        check()
|