# nvfp4-megamoe-kernel/tests/test_inv_rope.py
# (127 lines, 4.6 KiB, Python)
"""Test _apply_inv_rope_bf16: inverse RoPE should undo forward RoPE."""
import torch
import math
def apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim):
    """Apply forward GPT-J style (interleaved) RoPE to the rotary slice of ``x``.

    The last dimension of ``x`` is treated as ``[nope | rope | rest]``: the
    first ``nope_dim`` channels are passed through, the next ``rope_dim``
    channels are rotated as interleaved (even, odd) pairs by ``+theta``, and
    any channels after ``nope_dim + rope_dim`` are passed through unchanged.

    Args:
        x: Tensor whose first dim indexes tokens and whose last dim is the
            head dim, e.g. ``(num_tokens, num_heads, head_dim)`` or
            ``(num_tokens, head_dim)``.
        positions: 1-D integer tensor with one entry per token (row index
            into ``cos_sin_cache``).
        cos_sin_cache: 2-D tensor whose row ``p`` holds ``cos`` in columns
            ``[0, rope_dim // 2)`` and ``sin`` in ``[rope_dim // 2, rope_dim)``.
        nope_dim: Number of leading channels that are not rotated.
        rope_dim: Number of rotated channels; must be even.

    Returns:
        A new tensor with the same shape and dtype as ``x``. The math is done
        in float32 and cast back once. ``x`` itself is returned when
        ``rope_dim == 0`` or ``x`` is empty.

    Raises:
        ValueError: If ``rope_dim`` is odd or ``nope_dim + rope_dim`` exceeds
            the last dimension of ``x``.
    """
    if rope_dim == 0 or x.numel() == 0:
        return x
    if rope_dim % 2:
        raise ValueError(f"rope_dim must be even, got {rope_dim}")
    rope_end = nope_dim + rope_dim
    if rope_end > x.shape[-1]:
        raise ValueError(
            f"nope_dim + rope_dim ({rope_end}) exceeds head dim ({x.shape[-1]})"
        )
    half_rot = rope_dim // 2

    # One (cos, sin) row per token, reshaped so it broadcasts over every
    # middle dim of x (e.g. heads) and lines up with the pair axis.
    view_shape = (positions.shape[0],) + (1,) * (x.dim() - 2) + (half_rot,)
    cache = cos_sin_cache.index_select(0, positions.to(torch.long))
    cos = cache[:, :half_rot].to(torch.float32).reshape(view_shape)
    sin = cache[:, half_rot:2 * half_rot].to(torch.float32).reshape(view_shape)

    x_f32 = x.to(torch.float32)
    rope = x_f32[..., nope_dim:rope_end]
    y_even = rope[..., 0::2]
    y_odd = rope[..., 1::2]
    # Rotate each (even, odd) pair by +theta, then re-interleave the pairs.
    rope_out = torch.stack(
        (y_even * cos - y_odd * sin, y_odd * cos + y_even * sin),
        dim=-1,
    ).flatten(-2)

    # .to(float32) returns x itself when x is already float32, so copy
    # before writing to avoid mutating the caller's tensor.
    out = x_f32.clone()
    out[..., nope_dim:rope_end] = rope_out
    return out.to(x.dtype)
def apply_inv_rope_bf16(o, positions, cos_sin_cache, nope_dim, rope_dim):
    """Apply inverse GPT-J style (interleaved) RoPE to the rotary slice of ``o``.

    This is the forward rotation with ``sin -> -sin``, i.e. each interleaved
    (even, odd) pair is rotated by ``-theta``. The last dimension of ``o`` is
    treated as ``[nope | rope | rest]``: only the ``rope_dim`` channels
    starting at ``nope_dim`` are rotated; all other channels pass through.

    Args:
        o: Tensor whose first dim indexes tokens and whose last dim is the
            head dim, e.g. ``(num_tokens, num_heads, head_dim)`` or
            ``(num_tokens, head_dim)``.
        positions: 1-D integer tensor with one entry per token (row index
            into ``cos_sin_cache``).
        cos_sin_cache: 2-D tensor whose row ``p`` holds ``cos`` in columns
            ``[0, rope_dim // 2)`` and ``sin`` in ``[rope_dim // 2, rope_dim)``.
        nope_dim: Number of leading channels that are not rotated.
        rope_dim: Number of rotated channels; must be even.

    Returns:
        A new tensor with the same shape and dtype as ``o``. The math is done
        in float32 and cast back once. ``o`` itself is returned when
        ``rope_dim == 0`` or ``o`` is empty.

    Raises:
        ValueError: If ``rope_dim`` is odd or ``nope_dim + rope_dim`` exceeds
            the last dimension of ``o``.
    """
    if rope_dim == 0 or o.numel() == 0:
        return o
    if rope_dim % 2:
        raise ValueError(f"rope_dim must be even, got {rope_dim}")
    rope_end = nope_dim + rope_dim
    if rope_end > o.shape[-1]:
        raise ValueError(
            f"nope_dim + rope_dim ({rope_end}) exceeds head dim ({o.shape[-1]})"
        )
    half_rot = rope_dim // 2

    # One (cos, sin) row per token, reshaped so it broadcasts over every
    # middle dim of o (e.g. heads) and lines up with the pair axis.
    view_shape = (positions.shape[0],) + (1,) * (o.dim() - 2) + (half_rot,)
    cache = cos_sin_cache.index_select(0, positions.to(torch.long))
    cos = cache[:, :half_rot].to(torch.float32).reshape(view_shape)
    sin = cache[:, half_rot:2 * half_rot].to(torch.float32).reshape(view_shape)

    o_f32 = o.to(torch.float32)
    rope = o_f32[..., nope_dim:rope_end]
    y_even = rope[..., 0::2]
    y_odd = rope[..., 1::2]
    # Rotate each (even, odd) pair by -theta, then re-interleave the pairs.
    rope_out = torch.stack(
        (y_even * cos + y_odd * sin, y_odd * cos - y_even * sin),
        dim=-1,
    ).flatten(-2)

    # .to(float32) returns o itself when o is already float32, so copy
    # before writing to avoid mutating the caller's tensor.
    out = o_f32.clone()
    out[..., nope_dim:rope_end] = rope_out
    return out.to(o.dtype)
def test_inv_rope_roundtrip():
    """Forward RoPE then inverse RoPE should be identity.

    Uses a (num_tokens, num_heads, head_dim) BF16 input whose 512-channel
    head is split into 448 NoPE + 64 RoPE channels. Each RoPE pass rounds
    back to BF16, so the roundtrip is only exact up to BF16 precision.

    NOTE(review): this exercises the reference implementations defined in
    this file, not the kernel's ``_apply_inv_rope_bf16`` named in the module
    docstring — confirm that is intended.
    """
    torch.manual_seed(42)
    num_tokens = 8
    num_heads = 16
    head_dim = 512
    nope_dim = 448
    rope_dim = 64
    max_pos = 1024
    # Build a full cos/sin cache (like RotaryEmbedding):
    # row p = [cos(p * inv_freq) | sin(p * inv_freq)].
    half_rot = rope_dim // 2
    inv_freq = 1.0 / (
        10000.0 ** (torch.arange(0, half_rot, dtype=torch.float32) / half_rot)
    )
    positions = torch.randint(0, max_pos, (num_tokens,))
    freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
    cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=1)
    x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)

    # Forward RoPE
    x_rope = apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)
    # Guard against a vacuous pass: if forward RoPE were a no-op, the
    # roundtrip below would be trivially the identity.
    rope_change = (
        (x_rope[..., nope_dim:].to(torch.float32) - x[..., nope_dim:].to(torch.float32))
        .abs()
        .max()
        .item()
    )
    assert rope_change > 0.1, f"Forward RoPE did not rotate the RoPE dims: {rope_change}"

    # Inverse RoPE
    x_recovered = apply_inv_rope_bf16(x_rope, positions, cos_sin_cache, nope_dim, rope_dim)
    # Should be identity (within BF16 precision)
    diff = (x.to(torch.float32) - x_recovered.to(torch.float32)).abs().max().item()
    print(f"Max abs diff: {diff:.6f}")
    assert diff < 0.05, f"Roundtrip error too large: {diff}"
    print("PASS: inverse RoPE roundtrip within tolerance")
def test_nope_dim_unchanged():
    """NoPE dimensions should be unchanged by inverse RoPE.

    The first ``nope_dim`` channels must come back bit-identical: BF16 ->
    float32 -> BF16 is lossless for values that were never rotated. The RoPE
    channels are also checked to have actually been rotated, so the NoPE
    check cannot pass just because the function returned its input.
    """
    torch.manual_seed(42)
    num_tokens = 4
    num_heads = 4
    head_dim = 128
    nope_dim = 96
    rope_dim = 32
    max_pos = 512
    # Full cos/sin cache: row p = [cos(p * inv_freq) | sin(p * inv_freq)].
    half_rot = rope_dim // 2
    inv_freq = 1.0 / (
        10000.0 ** (torch.arange(0, half_rot, dtype=torch.float32) / half_rot)
    )
    positions = torch.randint(0, max_pos, (num_tokens,))
    freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
    cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=1)
    x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)

    x_inv = apply_inv_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)
    assert x_inv.shape == x.shape, f"Shape changed: {x_inv.shape} vs {x.shape}"
    assert x_inv.dtype == x.dtype, f"Dtype changed: {x_inv.dtype} vs {x.dtype}"
    # NoPE dims should be unchanged
    nope_diff = (
        (x[..., :nope_dim].to(torch.float32) - x_inv[..., :nope_dim].to(torch.float32))
        .abs()
        .max()
        .item()
    )
    print(f"NoPE max diff: {nope_diff:.6f}")
    assert nope_diff == 0.0, "NoPE dimensions should be unchanged"
    # Sanity check: the RoPE slice must actually have been rotated.
    assert not torch.equal(x_inv[..., nope_dim:], x[..., nope_dim:]), (
        "RoPE dimensions should be rotated"
    )
    print("PASS: NoPE dimensions unchanged")
if __name__ == "__main__":
    # Allow running the checks directly without pytest, in declaration order.
    for _run_check in (test_inv_rope_roundtrip, test_nope_dim_unchanged):
        _run_check()