"""Unit test: wo_a NVFP4 grouped linear + inverse RoPE. Tests the CuTeDSL NVFP4 grouped GEMM that replaces DeepGEMM's fp8_einsum for the wo_a (o-projection first half) in DeepSeek V4 attention. Also tests inverse_rope_bf16 against a synthetic reference. Usage (B200): python3 tests/test_wo_a.py Requires: CuTeDSL, CUDA, Blackwell GPU """ import sys import os import torch import torch.nn.functional as F # Add repo root to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from dsv4.ops.rope import inverse_rope_bf16 from dsv4.layers.grouped_linear import Nvfp4GroupedLinear DEVICE = "cuda:0" # DeepSeek V4 Pro dimensions N_LOCAL_GROUPS = 8 HEADS_PER_GROUP = 16 # 128 heads / 8 groups HEAD_DIM = 512 NOPE_DIM = 448 ROPE_DIM = 64 O_LORA_RANK = 1536 GROUP_IN = HEADS_PER_GROUP * HEAD_DIM # 8192 NUM_TOKENS = 4 def test_inverse_rope(): """Test inverse_rope_bf16: apply RoPE then inverse → should recover original.""" print("\n=== Test: inverse_rope_bf16 ===") torch.manual_seed(42) num_tokens = 4 num_heads = N_LOCAL_GROUPS * HEADS_PER_GROUP max_pos = 128 # Build cos_sin_cache (same format as vLLM: cos||sin concatenated) rope_dim = ROPE_DIM half_rope = rope_dim // 2 base = 10000.0 inv_freq = 1.0 / (base ** (torch.arange(0, half_rope, dtype=torch.float32, device=DEVICE) / half_rope)) pos = torch.arange(max_pos, dtype=torch.float32, device=DEVICE) freqs = torch.outer(pos, inv_freq) # (max_pos, half_rope) cos_sin_cache = torch.cat([freqs.cos(), freqs.sin(), ], dim=-1) # (max_pos, rope_dim) # Random attention output o = torch.randn(num_tokens, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=DEVICE) * 2.0 positions = torch.randint(0, max_pos, (num_tokens,), dtype=torch.int64, device=DEVICE) # Apply RoPE (forward), then inverse # Forward RoPE (GPT-J interleaved): o_rope = o[:, :, NOPE_DIM:].clone() cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype) sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype) o_even = o_rope[:, :, 0::2] o_odd = o_rope[:, :, 1::2] rope_even = o_even * cos_all - o_odd * sin_all rope_odd = o_even * sin_all + o_odd * cos_all o_fwd = o.clone() o_fwd[:, :, NOPE_DIM:][:, :, 0::2] = rope_even o_fwd[:, :, NOPE_DIM:][:, :, 1::2] = rope_odd # Apply inverse RoPE o_inv = inverse_rope_bf16(o_fwd, positions, cos_sin_cache, NOPE_DIM, ROPE_DIM) # Compare with original cos = F.cosine_similarity( o.flatten().unsqueeze(0).float(), o_inv.flatten().unsqueeze(0).float() ).item() mse = (o.float() - o_inv.float()).pow(2).mean().item() status = "✅" if cos > 0.999 else "❌" print(f" inverse_rope → original: cosine={cos:.6f} MSE={mse:.6e} {status}") return cos def test_wo_a_grouped_linear(): """Test CuTeDSL NVFP4 wo_a grouped linear against BF16 reference.""" print("\n=== Test: wo_a NVFP4 Grouped Linear ===") torch.manual_seed(42) num_tokens = NUM_TOKENS # Random attention output (after inverse RoPE) o = torch.randn(num_tokens, N_LOCAL_GROUPS * HEADS_PER_GROUP, HEAD_DIM, dtype=torch.bfloat16, device=DEVICE) * 2.0 # Random wo_a weight (BF16, grouped format) # In vLLM, wo_a is ColumnParallelLinear with is_bmm=True # Weight shape: (n_local_groups, heads_per_group * head_dim, o_lora_rank) wo_a_weight = torch.randn( N_LOCAL_GROUPS, GROUP_IN, O_LORA_RANK, dtype=torch.bfloat16, device=DEVICE ) * 0.1 # BF16 reference: grouped matmul o_grouped = o.reshape(num_tokens, N_LOCAL_GROUPS, GROUP_IN) z_ref = torch.empty(num_tokens, N_LOCAL_GROUPS, O_LORA_RANK, dtype=torch.bfloat16, device=DEVICE) for g in range(N_LOCAL_GROUPS): # (tokens, GROUP_IN) × (GROUP_IN, O_LORA_RANK) → (tokens, O_LORA_RANK) z_ref[:, g, :] = o_grouped[:, g, :] @ wo_a_weight[g] # CuTeDSL NVFP4 runner runner = Nvfp4GroupedLinear( n_local_groups=N_LOCAL_GROUPS, heads_per_group=HEADS_PER_GROUP, head_dim=HEAD_DIM, o_lora_rank=O_LORA_RANK, max_num_tokens=8192, device=DEVICE, ) runner.set_bf16_weight(wo_a_weight) runner.finalize_weights() # Warmup + compute activation global scale runner._ensure_initialized() runner.compute_activation_global_scale(o) # Run with torch.no_grad(): z_out = runner.run(o) # Compare cos = F.cosine_similarity( z_ref.flatten().unsqueeze(0).float(), z_out.flatten().unsqueeze(0).float() ).item() mse = (z_ref.float() - z_out.float()).pow(2).mean().item() status = "✅" if cos >= 0.98 else "❌" print(f" wo_a grouped linear: cosine={cos:.6f} MSE={mse:.6e} {status}") print(f" z_ref amax={z_ref.amax():.4f} z_out amax={z_out.amax():.4f}") return cos def main(): torch.cuda.set_device(0) print("=== wo_a NVFP4 Grouped Linear + Inverse RoPE Tests ===") cos_rope = test_inverse_rope() cos_woa = test_wo_a_grouped_linear() print(f"\n=== SUMMARY ===") results = {"inverse_rope": cos_rope, "wo_a_grouped_linear": cos_woa} all_pass = True for name, cos in results.items(): threshold = 0.999 if name == "inverse_rope" else 0.98 status = "✅" if cos >= threshold else "❌" if cos < threshold: all_pass = False print(f" {name}: cosine={cos:.6f} {status}") if all_pass: print("\n✅ ALL PASS") else: print("\n❌ SOME FAILED") if __name__ == "__main__": main()