#!/usr/bin/env python3
"""
B200 test: Proves the attention O-projection root cause and fix.

Loads real model weights from the NVFP4 checkpoint and tests:
  1. OLD path: fused_inv_rope_fp8_quant + FP8 einsum (crashes with BF16 wo_a)
  2. NEW path: BF16 inv RoPE + BMM wo_a + NVFP4 wo_b (should work)

Also tests the NVFP4 linear kernel (wo_b) with real weights.

Usage (on B200):
    python3 tests/test_o_projection_b200.py

Requires:
    Real model weights at /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4
    CuTeDSL, CUDA, Blackwell GPU

The process exits with status 0 when the fix is validated and 1 otherwise.
"""

import json
import os
import sys

import torch
import torch.nn.functional as F

try:
    from safetensors import safe_open
except ImportError:
    # Keep the pure-PyTorch helpers below importable on machines that do not
    # have the checkpoint tooling installed; load_tensor() reports the
    # missing dependency as soon as a weight is actually requested.
    safe_open = None

MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
KERNEL_REPO = "/root/nvfp4-megamoe-kernel"
DEVICE = "cuda:0"
LAYER_IDX = 0

# DeepSeek V4 Pro dimensions
HIDDEN_SIZE = 7168
NUM_HEADS = 128
HEAD_DIM = 512
NOPE_DIM = 448
ROPE_DIM = 64
Q_LORA_RANK = 1536
O_LORA_RANK = 1024
O_GROUPS = 16  # from config (not TP-sharded)
HEADS_PER_GROUP = NUM_HEADS // O_GROUPS  # 8

NUM_TOKENS = 4

# Minimum cosine similarity between the NVFP4 kernel output and the
# dequantized reference for the wo_b test to count as a pass.
COS_THRESHOLD = 0.98

# Loaded tensors keyed by (model_dir, key).
_cache = {}


def load_tensor(key, wm, model_dir):
    """Load one tensor from a sharded safetensors checkpoint, with caching.

    Args:
        key: Tensor name to load.
        wm: Mapping from tensor name to the shard file that contains it
            (the "weight_map" of the checkpoint index).
        model_dir: Directory holding the shard files.

    Returns:
        The tensor as returned by safetensors for ``framework="pt"``.

    Raises:
        ImportError: If safetensors is not installed.
        KeyError: If ``key`` is not present in ``wm``.
    """
    # Key on the directory as well so two checkpoints never share an entry.
    cache_key = (model_dir, key)
    if cache_key in _cache:
        return _cache[cache_key]
    if safe_open is None:
        raise ImportError("safetensors is required to load checkpoint weights")
    shard_path = os.path.join(model_dir, wm[key])
    with safe_open(shard_path, framework="pt") as f:
        t = f.get_tensor(key)
    _cache[cache_key] = t
    return t


# ── OLD PATH: What the unpatched vLLM forward does ──────────────────

def old_path_o_projection(o, positions, cos_sin_cache, wo_a_weight_bf16):
    """Simulates the OLD (broken) attention forward.

    The old path does:
        o_fp8, o_scale = fused_inv_rope_fp8_quant(o, ...)
        wo_a_scale = self.wo_a.weight_scale_inv  <- DOESN'T EXIST on BF16 wo_a
        deepseek_v4_fp8_einsum(o_fp8, o_scale, wo_a_fp8, wo_a_scale, z, ...)

    Since wo_a is BF16 (no weight_scale_inv), this crashes. We simulate the
    crash by showing that weight_scale_inv does not exist on the weight.

    Args:
        o: Unused; kept so the signature mirrors new_path_o_projection.
        positions: Unused; see ``o``.
        cos_sin_cache: Unused; see ``o``.
        wo_a_weight_bf16: The wo_a weight to probe.

    Returns:
        Always None: the old path cannot produce a valid output.
    """
    # The old code does: wo_a_fp8 = self.wo_a.weight (BF16!)
    # Then: wo_a_scale = self.wo_a.weight_scale_inv (AttributeError!)
    print(" OLD PATH: wo_a.weight is BF16 (shape={}, dtype={})".format(
        wo_a_weight_bf16.shape, wo_a_weight_bf16.dtype))

    if isinstance(wo_a_weight_bf16, torch.Tensor):
        # Same attribute access the old forward performs on wo_a.
        try:
            _ = wo_a_weight_bf16.weight_scale_inv
        except AttributeError:
            print(" ✅ CONFIRMED: weight_scale_inv does NOT exist → AttributeError in vLLM")
            print(" This is the root cause: the FP8 einsum path crashes because")
            print(" wo_a has quant_config=None (BF16) but the forward expects FP8.")
        else:
            print(" ❌ UNEXPECTED: weight_scale_inv exists (shouldn't for BF16)")

    return None  # Can't produce valid output


# ── NEW PATH: Our patched forward ───────────────────────────────────

def apply_inv_rope_bf16(o, positions, cos_sin_cache,
                        nope_dim=NOPE_DIM, rope_dim=ROPE_DIM):
    """Inverse RoPE in pure PyTorch, computed in the dtype of ``o``.

    The last ``rope_dim`` features of each head (everything from index
    ``nope_dim`` on) are treated as interleaved (even, odd) pairs and rotated
    by the negative of the angle stored in ``cos_sin_cache``.

    Args:
        o: Tensor indexed as (token, head, feature).
        positions: Integer index tensor with one cache row per token.
        cos_sin_cache: 2-D table whose first ``rope_dim // 2`` columns are
            read as cosines and whose remaining columns are read as sines.
        nope_dim: Number of leading features per head left untouched.
        rope_dim: Number of trailing features per head that are rotated.

    Returns:
        A new tensor shaped like ``o``. When ``rope_dim`` is 0 or ``o`` is
        empty, ``o`` itself is returned (no copy is made).
    """
    if rope_dim == 0 or o.numel() == 0:
        return o

    half_rope = rope_dim // 2
    # unsqueeze(1) adds the head axis so the table broadcasts over heads.
    cos_all = cos_sin_cache[positions, :half_rope].unsqueeze(1).to(o.dtype)
    sin_all = cos_sin_cache[positions, half_rope:].unsqueeze(1).to(o.dtype)

    o_rope = o[:, :, nope_dim:]
    o_even = o_rope[:, :, 0::2]
    o_odd = o_rope[:, :, 1::2]

    inv_even = o_even * cos_all + o_odd * sin_all
    inv_odd = -o_even * sin_all + o_odd * cos_all

    result = o.clone()
    rope_out = result[:, :, nope_dim:]  # view: writes land in ``result``
    rope_out[:, :, 0::2] = inv_even
    rope_out[:, :, 1::2] = inv_odd
    return result


def new_path_o_projection(o, positions, cos_sin_cache, wo_a_weight_bf16):
    """NEW path: BF16 inv RoPE + BMM wo_a.

    Args:
        o: Attention output indexed as (token, head, feature).
        positions: Per-token positions used to index ``cos_sin_cache``.
        cos_sin_cache: Table consumed by apply_inv_rope_bf16.
        wo_a_weight_bf16: 2-D wo_a weight of shape
            (O_GROUPS * rank, features_per_group).

    Returns:
        z of shape (T, O_GROUPS, rank), ready for wo_b. The rank is derived
        from the weight, so it equals O_LORA_RANK for the real checkpoint.

    Raises:
        ValueError: If the weight's first dimension is not a multiple of
            O_GROUPS.
    """
    # Step 1: Inverse RoPE (BF16)
    o_inv = apply_inv_rope_bf16(o, positions, cos_sin_cache)
    print(f" Inverse RoPE: shape={o_inv.shape} amax={o_inv.amax():.4f} "
          f"NaN={torch.isnan(o_inv).any()}")

    # Step 2: wo_a BMM
    num_tokens = o_inv.shape[0]
    # wo_a weight: (O_GROUPS * O_LORA_RANK, HEADS_PER_GROUP * HEAD_DIM)
    out_dim, hidden_dim = wo_a_weight_bf16.shape
    if out_dim % O_GROUPS != 0:
        raise ValueError(
            f"wo_a out dim {out_dim} is not divisible by O_GROUPS={O_GROUPS}")
    rank = out_dim // O_GROUPS

    # reshape (not view) so a non-contiguous input is handled as well.
    o_grouped = o_inv.reshape(num_tokens, O_GROUPS, hidden_dim)
    wo_a_w = wo_a_weight_bf16.reshape(O_GROUPS, rank, hidden_dim)

    # One (T, hidden) @ (hidden, rank) product per group, then back to
    # token-major layout.
    z = torch.bmm(
        o_grouped.permute(1, 0, 2),
        wo_a_w.transpose(1, 2),
    ).permute(1, 0, 2)
    print(f" wo_a BMM: shape={z.shape} amax={z.amax():.4f} "
          f"NaN={torch.isnan(z).any()}")

    return z


# ── NVFP4 wo_b test ─────────────────────────────────────────────────

# Value of each 4-bit E2M1 code; codes 8..15 are the negated codes 0..7.
E2M1_LUT = torch.tensor([
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.float32)


def dequant_nvfp4(packed_uint8, scale_e4m3, global_scale):
    """Dequantize NVFP4 weight to BF16 for reference.

    Args:
        packed_uint8: (out_features, in_features // 2) uint8 tensor holding
            two 4-bit codes per byte; the low nibble is the even column.
        scale_e4m3: Per-block scales, one per 16 consecutive input columns.
        global_scale: Scalar multiplier applied to every element.

    Returns:
        (out_features, in_features) bfloat16 tensor on the input's device.
    """
    device = packed_uint8.device
    lut = E2M1_LUT.to(device)

    lower = lut[(packed_uint8 & 0x0F).long()]
    upper = lut[((packed_uint8 >> 4) & 0x0F).long()]

    out_features = packed_uint8.shape[0]
    in_features = packed_uint8.shape[1] * 2
    unpacked = torch.empty(out_features, in_features,
                           dtype=torch.float32, device=device)
    unpacked[:, 0::2] = lower
    unpacked[:, 1::2] = upper

    block_scale = scale_e4m3.float()
    # Expand each block scale over its 16 columns; the slice drops any
    # padding rows/columns present in the scale tensor.
    block_expanded = block_scale.repeat_interleave(16, dim=1)[
        :out_features, :in_features]

    return (unpacked * block_expanded * global_scale).to(torch.bfloat16)


def test_wo_b_nvfp4(z, wo_b_weight, wo_b_sf, wo_b_gs):
    """Test wo_b NVFP4 GEMM against BF16 reference.

    Args:
        z: Output of new_path_o_projection, (T, groups, rank).
        wo_b_weight: Packed NVFP4 weight, (out_features, in_features // 2).
        wo_b_sf: Per-block weight scales.
        wo_b_gs: Global weight scale (Python float).

    Returns:
        Cosine similarity between the kernel output and the reference.
    """
    if KERNEL_REPO not in sys.path:
        sys.path.insert(0, KERNEL_REPO)
    from dsv4.layers.linear import Nvfp4Linear

    in_features = wo_b_weight.shape[1] * 2
    out_features = wo_b_weight.shape[0]

    # Convert to CuTeDSL format
    fp4 = [wo_b_weight.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()]
    sf = [wo_b_sf.permute(1, 0).contiguous()]
    gs = [wo_b_gs]

    runner = Nvfp4Linear(
        in_features=in_features,
        out_features=out_features,
        max_num_tokens=8192,
        device=DEVICE,
    )
    runner.fp4 = fp4
    runner.sf = sf
    runner.gs = gs
    runner.finalize_weights()
    runner._ensure_initialized()

    # Warmup: compute activation global scale
    z_flat = z.flatten(1)  # (T, O_GROUPS * O_LORA_RANK)
    runner.compute_activation_global_scale(z_flat)

    # Run CuTeDSL
    with torch.no_grad():
        output = runner.run(z_flat)
    print(f" wo_b CuTeDSL: shape={output.shape} amax={output.amax():.4f} "
          f"NaN={torch.isnan(output).any()}")

    # BF16 reference
    bf16_w = dequant_nvfp4(wo_b_weight, wo_b_sf, wo_b_gs)
    with torch.no_grad():
        ref = z_flat @ bf16_w.T
    print(f" wo_b BF16 ref: shape={ref.shape} amax={ref.amax():.4f}")

    # Compare in float32: a metric stored in a low-precision dtype is too
    # coarse for the threshold and the 6-decimal printout below.
    ref_f32 = ref.float().flatten()
    out_f32 = output.float().flatten()
    cos = F.cosine_similarity(ref_f32.unsqueeze(0), out_f32.unsqueeze(0)).item()
    mse = (ref_f32 - out_f32).pow(2).mean().item()
    status = "✅" if cos >= COS_THRESHOLD else "❌"
    print(f" wo_b cosine={cos:.6f} MSE={mse:.6e} {status}")
    return cos


# This is a helper driven by main(); it takes plain arguments, so keep
# pytest from collecting it as a (fixture-less) test when the tests/
# directory is scanned.
test_wo_b_nvfp4.__test__ = False


def build_cos_sin_cache(max_pos=4096, rope_dim=ROPE_DIM, base=10000.0):
    """Build cos_sin_cache in the same format as vLLM's RotaryEmbedding.

    Args:
        max_pos: Number of positions (rows) in the table.
        rope_dim: Rotary dimension; the table has this many columns.
        base: RoPE frequency base.

    Returns:
        float32 tensor of shape (max_pos, rope_dim): cosines in the first
        ``rope_dim // 2`` columns, sines in the rest.
    """
    half_rope = rope_dim // 2
    exponents = torch.arange(0, half_rope, dtype=torch.float32) / half_rope
    inv_freq = 1.0 / (base ** exponents)
    t = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # (max_pos, half_rope)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # (max_pos, rope_dim)


def main():
    """Run the old-path, new-path and wo_b checks.

    Returns:
        0 when the fix is validated, 1 otherwise.
    """
    torch.cuda.set_device(torch.device(DEVICE))
    torch.manual_seed(42)

    print("=" * 70)
    print(" B200 Test: O-Projection Root Cause + Fix")
    print("=" * 70)

    # Load weight map
    with open(os.path.join(MODEL_PATH, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def load_on_device(key):
        return load_tensor(key, wm, MODEL_PATH).to(DEVICE)

    prefix = f"model.layers.{LAYER_IDX}.self_attn"

    # Load wo_a (BF16) and wo_b (NVFP4) weights
    print("\n--- Loading weights ---")
    wo_a_w = load_on_device(f"{prefix}.o_a_proj.weight")
    print(f" wo_a: shape={wo_a_w.shape} dtype={wo_a_w.dtype}")

    wo_b_w = load_on_device(f"{prefix}.o_b_proj.weight")
    wo_b_sf = load_on_device(f"{prefix}.o_b_proj.weight_scale")
    wo_b_gs = load_on_device(f"{prefix}.o_b_proj.weight_scale_2").item()
    print(f" wo_b: shape={wo_b_w.shape} dtype={wo_b_w.dtype} gs={wo_b_gs:.8f}")

    # Build cos_sin_cache
    cos_sin_cache = build_cos_sin_cache().to(DEVICE)

    # Simulate attention output (what FlashMLA would produce)
    print("\n--- Simulating attention output ---")
    # One position per token, so NUM_TOKENS can be changed in one place.
    positions = torch.arange(NUM_TOKENS, dtype=torch.int64, device=DEVICE)
    o = torch.randn(NUM_TOKENS, NUM_HEADS, HEAD_DIM,
                    dtype=torch.bfloat16, device=DEVICE) * 0.1
    print(f" Attention output: shape={o.shape} amax={o.amax():.4f}")

    # ═══════════════════════════════════════════════════════════════════
    # TEST 1: OLD PATH (should show crash/AttributeError)
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print(" TEST 1: OLD PATH (FP8 einsum — should crash)")
    print("=" * 70)
    old_path_o_projection(o, positions, cos_sin_cache, wo_a_w)

    # ═══════════════════════════════════════════════════════════════════
    # TEST 2: NEW PATH (BF16 inv RoPE + BMM wo_a — should work)
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print(" TEST 2: NEW PATH (BF16 inv RoPE + BMM wo_a)")
    print("=" * 70)
    z = new_path_o_projection(o, positions, cos_sin_cache, wo_a_w)
    z_ok = z is not None and not bool(torch.isnan(z).any())

    if z_ok:
        # ═══════════════════════════════════════════════════════════════
        # TEST 3: wo_b NVFP4 GEMM
        # ═══════════════════════════════════════════════════════════════
        print("\n" + "=" * 70)
        print(" TEST 3: wo_b NVFP4 GEMM (CuTeDSL vs BF16 reference)")
        print("=" * 70)
        cos_wo_b = test_wo_b_nvfp4(z, wo_b_w, wo_b_sf, wo_b_gs)
    else:
        print("\n❌ z is invalid (NaN or None), skipping wo_b test")
        cos_wo_b = 0.0

    # ═══════════════════════════════════════════════════════════════════
    # SUMMARY
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    print(" OLD PATH (FP8 einsum): CRASHES — wo_a has no weight_scale_inv")
    if z is not None:
        print(f" NEW PATH (BF16 inv RoPE + BMM): z amax={z.amax():.4f} "
              f"NaN={torch.isnan(z).any()}")
    else:
        print(" NEW PATH (BF16 inv RoPE + BMM): no output")
    print(f" wo_b NVFP4 cosine: {cos_wo_b:.6f} "
          f"{'✅' if cos_wo_b >= COS_THRESHOLD else '❌'}")

    passed = z_ok and cos_wo_b >= COS_THRESHOLD
    if passed:
        print("\n✅ ROOT CAUSE CONFIRMED + FIX VALIDATED")
        print(" The attention forward was crashing because wo_a is BF16")
        print(" but the FP8 einsum path expected weight_scale_inv.")
        print(" Our patched forward (BF16 inv RoPE + BMM) fixes this.")
    else:
        print("\n❌ FIX INCOMPLETE — further investigation needed")
    return 0 if passed else 1


if __name__ == "__main__":
    raise SystemExit(main())