"""Test _apply_inv_rope_bf16: inverse RoPE should undo forward RoPE."""
import torch
import math


def _rotate_rope_bf16(
    x: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int,
    rope_dim: int,
    *,
    inverse: bool,
) -> torch.Tensor:
    """Rotate the trailing RoPE slice of ``x``; shared by forward and inverse.

    The last-axis slice ``x[..., nope_dim:]`` is treated as interleaved
    (even, odd) pairs. Each pair is rotated by the angle whose cosine/sine
    are read from ``cos_sin_cache`` at the row given by ``positions``.

    Args:
        x: Input tensor. The cos/sin rows are reshaped to
            ``(positions.shape[0], 1, rope_dim // 2)`` and broadcast against
            the even/odd halves of the slice, so ``x`` is expected to be
            ``(num_tokens, num_heads, head_dim)`` as in the tests below.
        positions: 1-D tensor of row indices into ``cos_sin_cache``.
        cos_sin_cache: Per-position table; columns ``[0, rope_dim // 2)``
            hold cosines and the next ``rope_dim // 2`` columns hold sines.
        nope_dim: Number of leading last-axis entries left untouched.
        rope_dim: Width of the rotated slice.
        inverse: When true, rotate by the negated angle (sin -> -sin).

    Returns:
        ``x`` itself when ``rope_dim == 0`` or ``x`` is empty; otherwise a
        new tensor with the dtype of ``x``.
    """
    if rope_dim == 0 or x.numel() == 0:
        return x

    half_rot = rope_dim // 2
    rows = cos_sin_cache.index_select(0, positions.to(torch.long))

    # Insert a singleton axis so the per-token table broadcasts over heads.
    view_shape = (positions.shape[0], 1, half_rot)
    cos = rows[:, :half_rot].to(torch.float32).view(view_shape)
    sin = rows[:, half_rot:2 * half_rot].to(torch.float32).view(view_shape)
    if inverse:
        # Rotating by the negative angle only flips the sign of sin.
        sin = -sin

    # Do the arithmetic in float32; clone so the caller's data is never
    # modified, even when x is already float32 and .to() returns x itself.
    out = x.to(torch.float32).clone()
    rope = out[..., nope_dim:]
    even = rope[..., 0::2]
    odd = rope[..., 1::2]

    # stack + flatten re-interleaves the rotated (even, odd) pairs. The
    # right-hand side is fully evaluated before it is written back.
    out[..., nope_dim:] = torch.stack(
        (even * cos - odd * sin, odd * cos + even * sin),
        dim=-1,
    ).flatten(-2)
    return out.to(x.dtype)


def apply_rope_bf16(
    x: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int,
    rope_dim: int,
) -> torch.Tensor:
    """Forward GPT-J style RoPE.

    Rotates the interleaved (even, odd) pairs of ``x[..., nope_dim:]`` by the
    per-position angles stored in ``cos_sin_cache`` and leaves
    ``x[..., :nope_dim]`` untouched. See ``_rotate_rope_bf16`` for the
    argument layout.
    """
    return _rotate_rope_bf16(
        x, positions, cos_sin_cache, nope_dim, rope_dim, inverse=False
    )


def apply_inv_rope_bf16(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int,
    rope_dim: int,
) -> torch.Tensor:
    """Inverse GPT-J style RoPE (sin -> -sin).

    Applies the opposite rotation to ``apply_rope_bf16`` on
    ``o[..., nope_dim:]`` and leaves ``o[..., :nope_dim]`` untouched.
    """
    return _rotate_rope_bf16(
        o, positions, cos_sin_cache, nope_dim, rope_dim, inverse=True
    )


def _build_cos_sin_cache(max_pos: int, rope_dim: int) -> torch.Tensor:
    """Build a ``(max_pos, rope_dim)`` float32 table of ``[cos | sin]`` rows.

    Row ``p`` holds ``cos(p * inv_freq)`` followed by ``sin(p * inv_freq)``
    with ``inv_freq[i] = 10000 ** (-i / (rope_dim // 2))``.
    """
    half_rot = rope_dim // 2
    inv_freq = 1.0 / (
        10000.0 ** (torch.arange(0, half_rot, dtype=torch.float32) / half_rot)
    )
    all_positions = torch.arange(max_pos, dtype=torch.float32)
    freqs = all_positions.unsqueeze(1) * inv_freq.unsqueeze(0)
    return torch.cat([torch.cos(freqs), torch.sin(freqs)], dim=1)


def test_inv_rope_roundtrip():
    """Forward RoPE then inverse RoPE should be identity."""
    torch.manual_seed(42)
    num_tokens = 8
    num_heads = 16
    head_dim = 512
    nope_dim = 448
    rope_dim = 64
    max_pos = 1024

    # Build cos/sin cache (like RotaryEmbedding)
    cos_sin_cache = _build_cos_sin_cache(max_pos, rope_dim)

    # Draw positions before x so the random stream is consumed in a fixed order.
    positions = torch.randint(0, max_pos, (num_tokens,))
    x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)

    # Forward RoPE
    x_rope = apply_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)

    # Inverse RoPE
    x_recovered = apply_inv_rope_bf16(
        x_rope, positions, cos_sin_cache, nope_dim, rope_dim
    )

    # Should be identity (within BF16 precision)
    diff = (x.to(torch.float32) - x_recovered.to(torch.float32)).abs().max().item()
    print(f"Max abs diff: {diff:.6f}")
    assert diff < 0.05, f"Roundtrip error too large: {diff}"
    print("PASS: inverse RoPE roundtrip within tolerance")


def test_nope_dim_unchanged():
    """NoPE dimensions should be unchanged by inverse RoPE."""
    torch.manual_seed(42)
    num_tokens = 4
    num_heads = 4
    head_dim = 128
    nope_dim = 96
    rope_dim = 32
    max_pos = 512

    cos_sin_cache = _build_cos_sin_cache(max_pos, rope_dim)

    # Draw positions before x so the random stream is consumed in a fixed order.
    positions = torch.randint(0, max_pos, (num_tokens,))
    x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16)

    x_inv = apply_inv_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)

    # NoPE dims should be unchanged
    nope_diff = (
        x[..., :nope_dim].to(torch.float32)
        - x_inv[..., :nope_dim].to(torch.float32)
    ).abs().max().item()
    print(f"NoPE max diff: {nope_diff:.6f}")
    assert nope_diff == 0.0, "NoPE dimensions should be unchanged"
    print("PASS: NoPE dimensions unchanged")


if __name__ == "__main__":
    test_inv_rope_roundtrip()
    test_nope_dim_unchanged()