"""Test BF16 inverse RoPE + wo_a BMM (no GPU needed).

Validates the O projection path we patched into the attention forward.
"""

import torch
import math


def apply_inv_rope_bf16(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Same as the patched version in deepseek_v4_attention.py.

    Undo an interleaved (GPT-J style) rotary embedding on the trailing
    rotary channels of ``o``. The leading ``nope_dim`` channels are copied
    through unchanged. cos/sin are cast to ``o.dtype`` before use, so the
    rotation itself is computed in the input dtype (bf16 in production).

    Args:
        o: Tensor indexed as ``o[token, head, channel]``. The channel axis is
            assumed to be ``nope_dim + rope_dim`` wide (not checked here).
        positions: Per-token indices into ``cos_sin_cache``; assumed 1-D so
            that ``unsqueeze(1)`` broadcasts one angle set over all heads.
        cos_sin_cache: Rows laid out as ``[cos(0..half) | sin(0..half)]``.
        nope_dim: Number of leading channels without rotary embedding.
        rope_dim: Number of trailing rotary channels (even).

    Returns:
        A new tensor with the inverse rotation applied, or ``o`` itself when
        ``rope_dim == 0`` or ``o`` is empty.
    """
    # NOTE: this body intentionally mirrors the production copy (it is the
    # code under test); keep the two in sync rather than "improving" it here.
    if rope_dim == 0 or o.numel() == 0:
        return o
    half_rope = rope_dim // 2
    cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype)
    sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype)
    o_rope = o[:, :, nope_dim:]
    o_even = o_rope[:, :, 0::2]
    o_odd = o_rope[:, :, 1::2]
    # Rotation by -theta: the transpose of the forward rotation matrix.
    inv_even = o_even * cos_all + o_odd * sin_all
    inv_odd = -o_even * sin_all + o_odd * cos_all
    result = o.clone()
    # Chained basic slicing returns views, so these writes land in `result`.
    result[:, :, nope_dim:][:, :, 0::2] = inv_even
    result[:, :, nope_dim:][:, :, 1::2] = inv_odd
    return result


def apply_gptj_rope(
    x: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Apply forward GPT-J style RoPE (for testing roundtrip).

    Rotates each interleaved (even, odd) channel pair of the trailing rotary
    channels of ``x`` by the angle stored for its token's position. The
    leading ``nope_dim`` channels are copied unchanged. Uses the same
    argument conventions as ``apply_inv_rope_bf16`` and, like it, returns
    ``x`` itself when ``rope_dim == 0`` or ``x`` is empty.
    """
    if rope_dim == 0 or x.numel() == 0:
        return x
    half_rope = rope_dim // 2
    cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(x.dtype)
    sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(x.dtype)
    x_rope = x[:, :, nope_dim:]
    x_even = x_rope[:, :, 0::2]
    x_odd = x_rope[:, :, 1::2]
    rot_even = x_even * cos_all - x_odd * sin_all
    rot_odd = x_even * sin_all + x_odd * cos_all
    result = x.clone()
    result[:, :, nope_dim:][:, :, 0::2] = rot_even
    result[:, :, nope_dim:][:, :, 1::2] = rot_odd
    return result


def _build_cos_sin_cache(
    max_pos: int, rope_dim: int, base: float = 10000.0
) -> torch.Tensor:
    """Return a float32 ``(max_pos, rope_dim)`` table ``[cos | sin]``.

    Row ``p`` holds the rotary angles for position ``p`` in ``0..max_pos-1``.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
    t = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)  # (max_pos, half_rope)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (max_pos, rope_dim)


def test_inv_rope_roundtrip():
    """inv_rope(forward_rope(x)) should recover x."""
    torch.manual_seed(42)
    T, H, D = 4, 8, 512  # tokens, heads, head_dim
    nope_dim, rope_dim = 448, 64
    max_pos = 100

    # cos/sin table for positions 0..max_pos-1
    cos_sin_cache = _build_cos_sin_cache(max_pos, rope_dim)

    x = torch.randn(T, H, D, dtype=torch.bfloat16) * 0.1
    positions = torch.tensor([0, 5, 10, 50], dtype=torch.int64)

    # Apply forward RoPE, then inverse
    rotated = apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim)
    recovered = apply_inv_rope_bf16(
        rotated, positions, cos_sin_cache, nope_dim, rope_dim
    )

    # NoPE portion unchanged
    nope_diff = (recovered[:, :, :nope_dim] - x[:, :, :nope_dim]).abs().max().item()
    assert nope_diff == 0, f"NoPE should be unchanged, max diff: {nope_diff}"

    # RoPE portion should roundtrip within BF16 precision
    rope_diff = (recovered[:, :, nope_dim:] - x[:, :, nope_dim:]).abs().max().item()
    assert rope_diff < 0.02, f"RoPE roundtrip error too high: {rope_diff}"

    print(f"✅ inv_rope roundtrip: NoPE diff={nope_diff}, RoPE diff={rope_diff:.6f}")


def test_wo_a_bmm():
    """wo_a BMM should match einsum 'tgd,grd->tgr'."""
    torch.manual_seed(42)
    T = 3
    n_local_groups = 4
    heads_per_group = 2
    head_dim = 512
    o_lora_rank = 128
    n_local_heads = n_local_groups * heads_per_group

    # wo_a weight: (n_groups * o_lora_rank, heads_per_group * head_dim)
    wo_a_weight = torch.randn(
        n_local_groups * o_lora_rank,
        heads_per_group * head_dim,
        dtype=torch.bfloat16,
    )

    # Attention output (after inv RoPE): (T, n_local_heads, head_dim)
    o_inv = torch.randn(T, n_local_heads, head_dim, dtype=torch.bfloat16)

    # BMM path (our implementation)
    hidden_dim = heads_per_group * head_dim
    o_grouped = o_inv.view(T, n_local_groups, hidden_dim)
    wo_a_w = wo_a_weight.view(n_local_groups, o_lora_rank, hidden_dim)
    z_bmm = torch.bmm(
        o_grouped.permute(1, 0, 2),
        wo_a_w.transpose(1, 2),
    ).permute(1, 0, 2)
    assert z_bmm.shape == (T, n_local_groups, o_lora_rank), z_bmm.shape

    # Reference: fp32 einsum over the same bf16 inputs, kept in fp32.
    z_ref = torch.einsum("tgd,grd->tgr", o_grouped.float(), wo_a_w.float())

    # Each output sums hidden_dim (=1024) products of standard-normal values,
    # so |z| is typically in the tens, where one bf16 ulp is already
    # 0.125-0.5. A fixed absolute tolerance like 0.01 is below the format's
    # resolution and fails on any single rounding-order difference. Compare
    # relative to magnitude instead (rtol ~ bf16 epsilon; small atol covers
    # fp32 reduction-order noise on near-zero outputs). Subtract in fp32 so
    # the check itself does not add bf16 rounding.
    z_bmm_f = z_bmm.float()
    diff = (z_bmm_f - z_ref).abs().max().item()
    torch.testing.assert_close(
        z_bmm_f,
        z_ref,
        rtol=1.6e-2,
        atol=1e-2,
        msg=lambda m: f"wo_a BMM vs einsum mismatch (max diff {diff}): {m}",
    )

    print(f"✅ wo_a BMM matches einsum: max diff={diff:.6f}")


def test_inv_rope_at_zero():
    """At position 0, cos=1, sin=0, so inv_rope should be identity on RoPE dims."""
    torch.manual_seed(42)
    T, H, D = 2, 4, 512
    nope_dim, rope_dim = 448, 64

    cos_sin_cache = _build_cos_sin_cache(10, rope_dim)  # (10, rope_dim)

    # At pos 0, cos=1, sin=0
    x = torch.randn(T, H, D, dtype=torch.bfloat16) * 0.1
    positions = torch.zeros(T, dtype=torch.int64)

    # Forward RoPE at pos 0 should be identity (cos=1, sin=0)
    rotated = apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim)
    diff = (rotated - x).abs().max().item()
    assert diff < 1e-5, f"RoPE at pos=0 should be identity, diff={diff}"

    # Inverse RoPE on unrotated input at pos 0 should also be identity
    inv = apply_inv_rope_bf16(x, positions, cos_sin_cache, nope_dim, rope_dim)
    diff2 = (inv - x).abs().max().item()
    assert diff2 < 1e-5, f"inv RoPE at pos=0 should be identity, diff={diff2}"

    print(f"✅ inv_rope at pos=0 is identity (diff={diff2:.8f})")


if __name__ == "__main__":
    test_inv_rope_roundtrip()
    test_wo_a_bmm()
    test_inv_rope_at_zero()
    print("\n✅ All attention O-projection tests passed")