"""Standalone test: Attention projections using CuTeDSL NVFP4 linear runner. Tests q_a_proj, q_b_proj, kv_proj, o_b_proj against BF16 reference. o_a_proj is BF16 (not NVFP4) — not tested here. Usage: python3 test_attention.py """ import torch import torch.nn.functional as F import sys, os, json from safetensors import safe_open MODEL_PATH = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" DEVICE = "cuda:0" LAYER_IDX = 0 HIDDEN_SIZE = 7168 NUM_TOKENS = 4 E2M1_LUT = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5, -2., -3., -4., -6.], dtype=torch.float32) _cache = {} def load_tensor(key, wm, model_dir): if key in _cache: return _cache[key] shard_path = os.path.join(model_dir, wm[key]) with safe_open(shard_path, framework="pt") as f: t = f.get_tensor(key) _cache[key] = t return t def dequant_nvfp4(packed_uint8, scale_e4m3, global_scale): """Dequantize NVFP4 weight to BF16 for reference.""" device = packed_uint8.device lut = E2M1_LUT.to(device) lower = lut[(packed_uint8 & 0x0F).long()] upper = lut[((packed_uint8 >> 4) & 0x0F).long()] out_features = packed_uint8.shape[0] in_features = packed_uint8.shape[1] * 2 unpacked = torch.empty(out_features, in_features, dtype=torch.float32, device=device) unpacked[:, 0::2] = lower unpacked[:, 1::2] = upper block_scale = scale_e4m3.float() block_expanded = block_scale.repeat_interleave(16, dim=1)[:out_features, :in_features] return (unpacked * block_expanded * global_scale).to(torch.bfloat16) def test_projection(name, weight, weight_sf, weight_gs, hidden_states, in_features, out_features): """Test a single NVFP4 projection.""" sys.path.insert(0, "/root/nvfp4-megamoe-kernel") from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear # Convert weight to CuTeDSL format: (out, in_packed) uint8 → (in_packed, out) float4 fp4 = [weight.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()] sf = [weight_sf.permute(1, 0).contiguous()] gs = [weight_gs] runner = CuTeDSLNvfp4Linear( in_features=in_features, out_features=out_features, max_num_tokens=8192, device=DEVICE, ) runner.fp4 = fp4 runner.sf = sf runner.gs = gs runner.finalize_weights() # Warmup runner._ensure_initialized() runner.compute_activation_global_scale(hidden_states) # Run CuTeDSL with torch.no_grad(): output = runner.run(hidden_states) # BF16 reference bf16_w = dequant_nvfp4(weight, weight_sf, weight_gs) with torch.no_grad(): ref = hidden_states @ bf16_w.T # Compare cos = F.cosine_similarity(ref.flatten().unsqueeze(0), output.flatten().unsqueeze(0)).item() mse = (ref - output).pow(2).mean().item() status = "✅" if cos >= 0.98 else "❌" print(f" {name}: cosine={cos:.6f} MSE={mse:.6e} amax_ref={ref.amax():.4f} amax_out={output.amax():.4f} {status}") return cos def main(): torch.cuda.set_device(0) torch.manual_seed(42) with open(os.path.join(MODEL_PATH, "model.safetensors.index.json")) as f: wm = json.load(f)["weight_map"] P = lambda key: load_tensor(key, wm, MODEL_PATH).to(DEVICE) prefix = f"model.layers.{LAYER_IDX}.self_attn" print("=== Attention Projection Tests (CuTeDSL NVFP4 Linear) ===\n") # Load weights and determine dimensions from shapes projs = { "q_a_proj": {"key": f"{prefix}.q_a_proj"}, "q_b_proj": {"key": f"{prefix}.q_b_proj"}, "kv_proj": {"key": f"{prefix}.kv_proj"}, "o_b_proj": {"key": f"{prefix}.o_b_proj"}, } for name, info in projs.items(): key = info["key"] w = P(f"{key}.weight") sf = P(f"{key}.weight_scale") gs = P(f"{key}.weight_scale_2").item() out_features = w.shape[0] in_features = w.shape[1] * 2 # unpacked info["weight"] = w info["sf"] = sf info["gs"] = gs info["in_features"] = in_features info["out_features"] = out_features print(f" {name}: weight={w.shape} → in={in_features} out={out_features} gs={gs:.8f}") print() # Test each projection # q_a_proj: input is hidden_states (HIDDEN_SIZE=7168) hidden = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device=DEVICE) * 2.0 cos_qa = test_projection("q_a_proj", projs["q_a_proj"]["weight"], projs["q_a_proj"]["sf"], projs["q_a_proj"]["gs"], hidden, projs["q_a_proj"]["in_features"], projs["q_a_proj"]["out_features"]) # q_b_proj: input is q_a output (1536 features) q_a_out_features = projs["q_a_proj"]["out_features"] q_a_out = torch.randn(NUM_TOKENS, q_a_out_features, dtype=torch.bfloat16, device=DEVICE) * 2.0 cos_qb = test_projection("q_b_proj", projs["q_b_proj"]["weight"], projs["q_b_proj"]["sf"], projs["q_b_proj"]["gs"], q_a_out, projs["q_b_proj"]["in_features"], projs["q_b_proj"]["out_features"]) # kv_proj: input is hidden_states (7168) cos_kv = test_projection("kv_proj", projs["kv_proj"]["weight"], projs["kv_proj"]["sf"], projs["kv_proj"]["gs"], hidden, projs["kv_proj"]["in_features"], projs["kv_proj"]["out_features"]) # o_b_proj: input is o_a output (16384 features after attention) o_b_in_features = projs["o_b_proj"]["in_features"] o_b_input = torch.randn(NUM_TOKENS, o_b_in_features, dtype=torch.bfloat16, device=DEVICE) * 2.0 cos_ob = test_projection("o_b_proj", projs["o_b_proj"]["weight"], projs["o_b_proj"]["sf"], projs["o_b_proj"]["gs"], o_b_input, projs["o_b_proj"]["in_features"], projs["o_b_proj"]["out_features"]) print(f"\n=== SUMMARY ===") results = {"q_a_proj": cos_qa, "q_b_proj": cos_qb, "kv_proj": cos_kv, "o_b_proj": cos_ob} all_pass = True for name, cos in results.items(): status = "✅" if cos >= 0.98 else "❌" if cos < 0.98: all_pass = False print(f" {name}: cosine={cos:.6f} {status}") if all_pass: print("\n✅ ALL PASS") else: print("\n❌ SOME FAILED") if __name__ == "__main__": main()