The original attention forward uses fused_inv_rope_fp8_quant + deepseek_v4_fp8_einsum, which requires wo_a to have FP8 weights and a weight_scale_inv. Our checkpoint has wo_a in BF16, so the original path crashes (produces empty output).

Replace the O projection with:
1. _apply_inv_rope_bf16: pure PyTorch inverse RoPE (no FP8)
2. BMM grouped linear for wo_a (BF16)
3. NVFP4 wo_b via CuTeDSL

Also fixes the activation global scale bug from the previous commit:
- input_global_scale_inv IS the activation gs; don't re-invert it
- w13_input_scale_orig (after undoing convert) IS the MoE gs

Test: tests/test_o_projection.py validates the inv RoPE roundtrip and wo_a BMM correctness.
160 lines
5.7 KiB
Python
160 lines
5.7 KiB
Python
"""Test BF16 inverse RoPE + wo_a BMM (no GPU needed).

Validates the O projection path we patched into the attention forward.
"""
|
|
|
|
import torch
|
|
import math
|
|
|
|
|
|
def apply_inv_rope_bf16(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Undo an interleaved (even/odd pair) rotary embedding on the tail of ``o``.

    Test-local copy of the helper patched into deepseek_v4_attention.py.

    Args:
        o: Tensor indexed here as three-dimensional; the tests in this file
            pass (tokens, heads, head_dim). Channels from ``nope_dim`` onward
            are treated as the rotary part.
        positions: Row indices into ``cos_sin_cache``, one per entry of
            ``o``'s first dimension.
        cos_sin_cache: Lookup table whose first ``rope_dim // 2`` columns are
            read as cosines and whose remaining columns are read as sines.
        nope_dim: Number of leading channels that are left untouched.
        rope_dim: Width of the rotary part; ``0`` disables the rotation.

    Returns:
        ``o`` itself when ``rope_dim`` is 0 or ``o`` is empty; otherwise a
        clone of ``o`` whose rotary channels have been rotated by the negative
        angle. The input is never modified in place.
    """
    nothing_to_rotate = rope_dim == 0 or o.numel() == 0
    if nothing_to_rotate:
        return o

    split = rope_dim // 2
    dtype = o.dtype

    # unsqueeze(1) lets the per-position factors broadcast over dimension 1.
    cos = cos_sin_cache[positions, :split].unsqueeze(1).to(dtype)
    sin = cos_sin_cache[positions, split:].unsqueeze(1).to(dtype)

    # Views into the *input*; they stay valid while the clone is written.
    rotary_in = o[:, :, nope_dim:]
    even = rotary_in[:, :, 0::2]
    odd = rotary_in[:, :, 1::2]

    out = o.clone()
    rotary_out = out[:, :, nope_dim:]
    # Transpose of the forward rotation matrix: rotate each pair by -theta.
    rotary_out[:, :, 0::2] = even * cos + odd * sin
    rotary_out[:, :, 1::2] = -even * sin + odd * cos
    return out
|
|
|
|
|
|
def apply_gptj_rope(
    x: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    nope_dim: int = 448,
    rope_dim: int = 64,
) -> torch.Tensor:
    """Apply forward GPT-J style (interleaved pair) RoPE to the tail of ``x``.

    Exists so the tests can build a rotated input and check that the inverse
    helper recovers the original.

    Args:
        x: Tensor indexed here as three-dimensional; channels from
            ``nope_dim`` onward are rotated.
        positions: Row indices into ``cos_sin_cache``, one per entry of
            ``x``'s first dimension.
        cos_sin_cache: Lookup table whose first ``rope_dim // 2`` columns are
            read as cosines and whose remaining columns are read as sines.
        nope_dim: Number of leading channels that are left untouched.
        rope_dim: Width of the rotary part.

    Returns:
        A clone of ``x`` with the rotary channels rotated; ``x`` itself is not
        modified.
    """
    split = rope_dim // 2
    dtype = x.dtype

    # unsqueeze(1) lets the per-position factors broadcast over dimension 1.
    cos = cos_sin_cache[positions, :split].unsqueeze(1).to(dtype)
    sin = cos_sin_cache[positions, split:].unsqueeze(1).to(dtype)

    # Views into the *input*; they stay valid while the clone is written.
    rotary_in = x[:, :, nope_dim:]
    even = rotary_in[:, :, 0::2]
    odd = rotary_in[:, :, 1::2]

    out = x.clone()
    rotary_out = out[:, :, nope_dim:]
    # Standard 2-D rotation applied to each (even, odd) channel pair.
    rotary_out[:, :, 0::2] = even * cos - odd * sin
    rotary_out[:, :, 1::2] = even * sin + odd * cos
    return out
|
|
|
|
|
|
def test_inv_rope_roundtrip():
    """Forward RoPE followed by inverse RoPE must give back the input.

    The non-rotary channels have to match exactly; the rotary channels only
    have to match within a BF16-sized error budget.
    """
    torch.manual_seed(42)
    num_tokens, num_heads, head_dim = 4, 8, 512
    nope_dim, rope_dim = 448, 64
    max_pos = 100

    # One cache row per position in range(max_pos), laid out as [cos | sin].
    exponents = torch.arange(0, rope_dim, 2).float() / rope_dim
    inv_freq = 1.0 / (10000.0 ** exponents)
    pos_ids = torch.arange(max_pos, dtype=torch.float32)
    angles = torch.einsum("i,j -> ij", pos_ids, inv_freq)
    cos_sin_cache = torch.cat([angles.cos(), angles.sin()], dim=-1)

    x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16) * 0.1
    positions = torch.tensor([0, 5, 10, 50], dtype=torch.int64)

    # Rotate forward, then undo the rotation.
    rotated = apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim)
    recovered = apply_inv_rope_bf16(
        rotated, positions, cos_sin_cache, nope_dim, rope_dim
    )

    # Leading (non-rotary) channels must pass through bit-for-bit.
    nope_err = (recovered[:, :, :nope_dim] - x[:, :, :nope_dim]).abs()
    nope_diff = nope_err.max().item()
    assert nope_diff == 0, f"NoPE should be unchanged, max diff: {nope_diff}"

    # Rotary channels go through two rounded rotations, so allow some slack.
    rope_err = (recovered[:, :, nope_dim:] - x[:, :, nope_dim:]).abs()
    rope_diff = rope_err.max().item()
    assert rope_diff < 0.02, f"RoPE roundtrip error too high: {rope_diff}"
    print(f"✅ inv_rope roundtrip: NoPE diff={nope_diff}, RoPE diff={rope_diff:.6f}")
|
|
|
|
|
|
def test_wo_a_bmm():
    """Check the grouped BF16 BMM used for wo_a against einsum 'tgd,grd->tgr'.

    The BMM result is BF16, so it is compared with an unrounded FP32 einsum
    reference using a tolerance of one BF16 spacing at the largest reference
    magnitude. A fixed absolute tolerance would be smaller than the BF16
    resolution at this output scale (sums of 1024 unit-variance products) and
    would make the check depend on accumulation order.
    """
    torch.manual_seed(42)
    T = 3
    n_local_groups = 4
    heads_per_group = 2
    head_dim = 512
    o_lora_rank = 128
    n_local_heads = n_local_groups * heads_per_group
    hidden_dim = heads_per_group * head_dim

    # wo_a weight: (n_groups * o_lora_rank, heads_per_group * head_dim)
    wo_a_weight = torch.randn(
        n_local_groups * o_lora_rank, hidden_dim, dtype=torch.bfloat16
    )

    # Attention output (after inv RoPE): (T, n_local_heads, head_dim)
    o_inv = torch.randn(T, n_local_heads, head_dim, dtype=torch.bfloat16)

    # BMM path (our implementation): batch over groups, then restore the
    # (T, groups, rank) layout.
    o_grouped = o_inv.view(T, n_local_groups, hidden_dim)
    wo_a_w = wo_a_weight.view(n_local_groups, o_lora_rank, hidden_dim)
    z_bmm = torch.bmm(
        o_grouped.permute(1, 0, 2),
        wo_a_w.transpose(1, 2),
    ).permute(1, 0, 2)

    # Reference: einsum on the same views, evaluated in FP32 and left unrounded.
    z_ref = torch.einsum("tgd,grd->tgr", o_grouped.float(), wo_a_w.float())

    # Guard against a silent broadcast hiding a layout mistake.
    assert z_bmm.shape == z_ref.shape, (
        f"wo_a BMM shape {tuple(z_bmm.shape)} != einsum shape {tuple(z_ref.shape)}"
    )

    diff = (z_bmm.float() - z_ref).abs().max().item()
    # BF16 keeps 8 significant bits, so adjacent values differ by at most
    # 2**-7 relative; a layout bug would produce errors on the order of the
    # output magnitude itself, far above this bound.
    tol = z_ref.abs().max().item() * 2.0 ** -7
    assert diff <= tol, f"wo_a BMM vs einsum diff: {diff}"
    print(f"✅ wo_a BMM matches einsum: max diff={diff:.6f}")
|
|
|
|
|
|
def test_inv_rope_at_zero():
    """Both RoPE directions must be the identity at position 0.

    The cache row for position 0 holds cos=1 and sin=0, so neither the forward
    nor the inverse rotation may change the input.
    """
    torch.manual_seed(42)
    num_tokens, num_heads, head_dim = 2, 4, 512
    nope_dim, rope_dim = 448, 64

    # Ten cache rows laid out as [cos | sin]; only row 0 is used below.
    exponents = torch.arange(0, rope_dim, 2).float() / rope_dim
    inv_freq = 1.0 / (10000.0 ** exponents)
    pos_ids = torch.arange(10, dtype=torch.float32)
    angles = torch.einsum("i,j -> ij", pos_ids, inv_freq)
    cos_sin_cache = torch.cat([angles.cos(), angles.sin()], dim=-1)

    x = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16) * 0.1
    positions = torch.zeros(num_tokens, dtype=torch.int64)

    # Forward direction at position 0.
    forward_out = apply_gptj_rope(x, positions, cos_sin_cache, nope_dim, rope_dim)
    fwd_diff = (forward_out - x).abs().max().item()
    assert fwd_diff < 1e-5, f"RoPE at pos=0 should be identity, diff={fwd_diff}"

    # Inverse direction at position 0, applied to the unrotated input.
    inverse_out = apply_inv_rope_bf16(
        x, positions, cos_sin_cache, nope_dim, rope_dim
    )
    inv_diff = (inverse_out - x).abs().max().item()
    assert inv_diff < 1e-5, f"inv RoPE at pos=0 should be identity, diff={inv_diff}"
    print(f"✅ inv_rope at pos=0 is identity (diff={inv_diff:.8f})")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run every check in order; any failing assert aborts
    # before the summary line is printed.
    for check in (test_inv_rope_roundtrip, test_wo_a_bmm, test_inv_rope_at_zero):
        check()
    print("\n✅ All attention O-projection tests passed")
|