#!/usr/bin/env python3
"""
Full decoder layer 0 test: ALL components using CuTeDSL kernels, NO vLLM.

Tests each attention + FFN projection individually (CuTeDSL vs BF16 ref).
Each CuTeDSL NVFP4 projection is compared against a BF16 matmul with the
dequantized checkpoint weight; a projection passes when the cosine
similarity of the two outputs is at least COS_THRESH.

Usage (on B200):
    source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
    python3 tests/test_full_layer_b200.py

The process exits with status 0 when every projection passes and 1
otherwise, so the script can be used as a CI gate.
"""
import json
import math  # noqa: F401  (kept from the original import list)
import os
import sys

import torch
import torch.nn.functional as F
from safetensors import safe_open

REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"

# Model config (layer 0, compress_ratio=128 → C4A)
H = 7168
NH = 128
HD = 512
NOPE = 448
ROPE = 64
QL = 1536
OL = 1024
OG = 16
HPG = NH // OG
HC = 4
SL = 10.0
EPS = 1e-6
INTER = 3072
NT = 4

# Minimum cosine similarity (kernel output vs BF16 reference) counted as a pass.
COS_THRESH = 0.98

# Lookup table: 4-bit E2M1 code (0..15) -> float value; used by dequant().
E2M1 = torch.tensor(
    [0, .5, 1., 1.5, 2., 3., 4., 6., -0, -.5, -1., -1.5, -2., -3., -4., -6.],
    dtype=torch.float32,
)

# Tensors already read from the checkpoint, keyed by tensor name.
_cache = {}


def P(k, wm, md):
    """Return checkpoint tensor *k*, reading it from disk at most once.

    Args:
        k: Tensor name as listed in the safetensors index.
        wm: ``weight_map`` from ``model.safetensors.index.json``
            (tensor name -> shard file name).
        md: Directory containing the shard files.
    """
    if k in _cache:
        return _cache[k]
    with safe_open(os.path.join(md, wm[k]), framework="pt") as f:
        t = f.get_tensor(k)
    _cache[k] = t
    return t


def dequant(w, sf, gs):
    """Dequantize an NVFP4 weight to BF16 for use as a reference.

    Args:
        w: Packed FP4 codes of shape (out, in // 2); each byte holds two
            E2M1 codes, low nibble first.
        sf: Per-16-element block scales; entry [i, j] scales unpacked
            columns 16*j .. 16*j+15 of row i. After expansion it is cropped
            to (out, in), so padded scale tensors are accepted.
        gs: Global (per-tensor) scale as a Python number.

    Returns:
        BF16 tensor of shape (out, in).
    """
    dev = w.device
    lut = E2M1.to(dev)
    n_out, half_in = w.shape
    n_in = half_in * 2
    u = torch.empty(n_out, n_in, dtype=torch.float32, device=dev)
    u[:, 0::2] = lut[(w & 0xF).long()]         # low nibble -> even column
    u[:, 1::2] = lut[((w >> 4) & 0xF).long()]  # high nibble -> odd column
    bs = sf.float().repeat_interleave(16, dim=1)[:n_out, :n_in]
    return (u * bs * gs).to(torch.bfloat16)


def rms(x, w, eps=1e-6):
    """RMSNorm reference: ``w * x / sqrt(mean(x**2) + eps)``.

    Computed in FP32 and returned in ``x.dtype``. Not used by the checks
    in this script; kept for ad-hoc debugging.
    """
    v = x.float().pow(2).mean(-1, keepdim=True)
    return (w.float() * (x * torch.rsqrt(v + eps)).float()).to(x.dtype)


def _to_kernel_layout(t):
    """Transpose a 2-D (out, in-ish) tensor to the contiguous (in-ish, out) layout the runners take."""
    return t.permute(1, 0).contiguous()


def _fold_global_scales(sf, split, g1, g2):
    """Let two row-stacked halves of a fused weight share one global scale.

    ``sf`` is in checkpoint layout (out rows, in/16 columns). Rows
    ``[:split]`` belong to the half with global scale ``g1`` and rows
    ``[split:]`` to the half with ``g2``. Each half's block scales are
    multiplied by ``g / max(g1, g2)`` so that ``max(g1, g2)`` can be used
    as the single global scale.

    Returns:
        ``(block_scales, shared_global_scale)``. The scales come back as
        float8_e4m3fn when a rescale happened and unchanged otherwise.
    """
    gs = max(g1, g2)
    if g1 != g2:
        s32 = sf.float()
        s32[:split] *= g1 / gs
        s32[split:] *= g2 / gs
        sf = s32.to(torch.float8_e4m3fn)
    return sf, gs


def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build an initialized ``Nvfp4Linear`` from checkpoint NVFP4 tensors.

    Args:
        w: Packed FP4 weight, (outf, inf // 2).
        sf: Block scales in checkpoint (out, in/16) layout.
        gs_t: Tensor holding the global scale(s).
        inf, outf: Input / output feature counts.
        fused: If True and ``gs_t`` has two entries, treat ``w`` as two
            row-stacked halves with separate global scales and fold them
            into one (see ``_fold_global_scales``).
        lw: Optional sequence whose first element is the row where the
            second fused half starts; defaults to ``outf // 2``.
    """
    from dsv4.layers.linear import Nvfp4Linear

    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    if fused and gs_t.numel() == 2:
        # The fold must happen in (out, in/16) layout, i.e. BEFORE the
        # transpose below: the two halves are split along the output rows.
        split = lw[0] if lw else outf // 2
        s, gs = _fold_global_scales(s, split, gs_t[0].item(), gs_t[1].item())
    else:
        # NOTE(review): with several global scales and fused=False the max is
        # used without rescaling block scales; only exact if they are equal.
        gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()

    r = Nvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192,
                    device=str(w.device))
    r.fp4 = [_to_kernel_layout(w.view(torch.float4_e2m1fn_x2))]
    r.sf = [_to_kernel_layout(s)]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r


def cosim(a, b):
    """Cosine similarity of *a* and *b* viewed as flat FP32 vectors on a's device."""
    return F.cosine_similarity(
        a.flatten().unsqueeze(0).float(),
        b.flatten().unsqueeze(0).float().to(a.device),
    ).item()


def _check(results, key, label, out, ref):
    """Store cosine(out, ref) in ``results[key]`` and print one status line tagged *label*."""
    c = cosim(out, ref)
    results[key] = c
    mark = '✅' if c >= COS_THRESH else '❌'
    print(f" {label}: cosine={c:.6f} {mark} amax={out.amax():.4f} ref={ref.amax():.4f}")
    return c


def main():
    """Compare every layer-0 NVFP4 projection against its BF16 reference.

    Returns:
        0 if every compared projection reaches ``COS_THRESH``, else 1.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)
    print("=" * 70)
    print(" Layer 0: CuTeDSL NVFP4 vs BF16 Reference")
    print("=" * 70)

    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def G(k):
        return P(k, wm, MODEL).to(DEV)

    p = "model.layers.0"
    a = f"{p}.self_attn"
    m = f"{p}.mlp"

    # ── Load attention weights (correct checkpoint key names) ─────────
    print("\n--- Loading weights ---")
    qa_w = G(f"{a}.q_a_proj.weight")
    qa_sf = G(f"{a}.q_a_proj.weight_scale")
    qa_gs = G(f"{a}.q_a_proj.weight_scale_2")
    qb_w = G(f"{a}.q_b_proj.weight")
    qb_sf = G(f"{a}.q_b_proj.weight_scale")
    qb_gs = G(f"{a}.q_b_proj.weight_scale_2")
    kv_w = G(f"{a}.kv_proj.weight")
    kv_sf = G(f"{a}.kv_proj.weight_scale")
    kv_gs = G(f"{a}.kv_proj.weight_scale_2")
    woa = G(f"{a}.o_a_proj.weight")  # BF16
    wob_w = G(f"{a}.o_b_proj.weight")
    wob_sf = G(f"{a}.o_b_proj.weight_scale")
    wob_gs = G(f"{a}.o_b_proj.weight_scale_2")
    # Norm weights are loaded (so missing keys fail loudly) but not compared.
    qn = G(f"{a}.q_a_norm.weight")
    kvn = G(f"{a}.kv_norm.weight")
    anorm = G(f"{p}.input_layernorm.weight")
    fnorm = G(f"{p}.post_attention_layernorm.weight")
    # Compressor (C4A path)
    ckv_w = G(f"{a}.compressor.kv_proj.weight")
    ckv_sf = G(f"{a}.compressor.kv_proj.weight_scale")
    ckv_gs = G(f"{a}.compressor.kv_proj.weight_scale_2")
    cg_w = G(f"{a}.compressor.gate_proj.weight")
    cg_sf = G(f"{a}.compressor.gate_proj.weight_scale")
    cg_gs = G(f"{a}.compressor.gate_proj.weight_scale_2")
    ckn = G(f"{a}.compressor.kv_norm.weight")
    cpb = G(f"{a}.compressor.position_bias")
    sinks = G(f"{a}.sinks")
    # MHC
    hca_fn = G(f"{p}.attn_hc.fn")
    hcf_fn = G(f"{p}.ffn_hc.fn")
    hca_b = G(f"{p}.attn_hc.base")
    hcf_b = G(f"{p}.ffn_hc.base")
    hca_s = G(f"{p}.attn_hc.scale")
    hcf_s = G(f"{p}.ffn_hc.scale")

    for nm, t in [("q_a_proj", qa_w), ("q_b_proj", qb_w), ("kv_proj", kv_w),
                  ("o_a_proj", woa), ("o_b_proj", wob_w),
                  ("comp.kv_proj", ckv_w), ("comp.gate_proj", cg_w),
                  ("sinks", sinks), ("comp.position_bias", cpb),
                  ("attn_hc.fn", hca_fn)]:
        print(f" {nm}: shape={t.shape} dtype={t.dtype}")

    # ── Create CuTeDSL runners ────────────────────────────────────────
    # Packed FP4 stores two values per byte, hence in_features = shape[1] * 2.
    print("\n--- Creating CuTeDSL runners ---")
    r_qa = make_runner(qa_w, qa_sf, qa_gs, qa_w.shape[1] * 2, qa_w.shape[0])
    r_qb = make_runner(qb_w, qb_sf, qb_gs, qb_w.shape[1] * 2, qb_w.shape[0])
    r_kv = make_runner(kv_w, kv_sf, kv_gs, kv_w.shape[1] * 2, kv_w.shape[0])
    r_wob = make_runner(wob_w, wob_sf, wob_gs, wob_w.shape[1] * 2, wob_w.shape[0])
    # Compressor runners
    r_ckv = make_runner(ckv_w, ckv_sf, ckv_gs, ckv_w.shape[1] * 2, ckv_w.shape[0])
    r_cg = make_runner(cg_w, cg_sf, cg_gs, cg_w.shape[1] * 2, cg_w.shape[0])
    for label, w in [("q_a", qa_w), ("q_b", qb_w), ("kv", kv_w),
                     ("wo_b", wob_w), ("comp.kv", ckv_w), ("comp.gate", cg_w)]:
        print(f" {label}: in={w.shape[1] * 2} out={w.shape[0]}")

    # Warmup: calibrate each runner's activation global scale on random data.
    print(" Warming up...")
    d1 = torch.randn(NT, H, dtype=torch.bfloat16, device=DEV) * 2.0
    for r in [r_qa, r_kv, r_ckv, r_cg]:
        r.compute_activation_global_scale(d1)
    d2 = torch.randn(NT, QL, dtype=torch.bfloat16, device=DEV) * 2.0
    r_qb.compute_activation_global_scale(d2)
    d3 = torch.randn(NT, OG * OL, dtype=torch.bfloat16, device=DEV) * 2.0
    r_wob.compute_activation_global_scale(d3)
    print(" Done.")

    # ── Per-projection BF16 vs CuTeDSL comparison ────────────────────
    print("\n" + "=" * 70)
    print(" PROJECTION-LEVEL: CuTeDSL vs BF16")
    print("=" * 70)
    torch.manual_seed(123)
    tx = torch.randn(NT, H, dtype=torch.bfloat16, device=DEV) * 2.0
    results = {}
    with torch.no_grad():
        _check(results, "q_a_proj", "q_a_proj",
               r_qa.run(tx), tx @ dequant(qa_w, qa_sf, qa_gs.item()).T)
        _check(results, "kv_proj", "kv_proj",
               r_kv.run(tx), tx @ dequant(kv_w, kv_sf, kv_gs.item()).T)
        # Inputs are drawn in the same order as before so results are reproducible.
        tq = torch.randn(NT, QL, dtype=torch.bfloat16, device=DEV) * 2.0
        _check(results, "q_b_proj", "q_b_proj",
               r_qb.run(tq), tq @ dequant(qb_w, qb_sf, qb_gs.item()).T)
        tz = torch.randn(NT, OG * OL, dtype=torch.bfloat16, device=DEV) * 2.0
        _check(results, "wo_b_proj", "wo_b_proj",
               r_wob.run(tz), tz @ dequant(wob_w, wob_sf, wob_gs.item()).T)
        _check(results, "comp.kv_proj", "comp.kv_proj",
               r_ckv.run(tx), tx @ dequant(ckv_w, ckv_sf, ckv_gs.item()).T)
        _check(results, "comp.gate_proj", "comp.gate",
               r_cg.run(tx), tx @ dequant(cg_w, cg_sf, cg_gs.item()).T)

    # ── Shared expert ─────────────────────────────────────────────────
    print("\n--- Shared Expert: CuTeDSL vs BF16 ---")
    from dsv4.layers.shared_expert import Nvfp4SharedExpert

    sgw = G(f"{m}.shared_experts.gate_proj.weight")
    sgsf = G(f"{m}.shared_experts.gate_proj.weight_scale")
    sggs = G(f"{m}.shared_experts.gate_proj.weight_scale_2").item()
    suw = G(f"{m}.shared_experts.up_proj.weight")
    susf = G(f"{m}.shared_experts.up_proj.weight_scale")
    sugs = G(f"{m}.shared_experts.up_proj.weight_scale_2").item()
    sdw = G(f"{m}.shared_experts.down_proj.weight")
    sdsf = G(f"{m}.shared_experts.down_proj.weight_scale")
    sdgs = G(f"{m}.shared_experts.down_proj.weight_scale_2").item()

    # Fuse gate and up row-wise into one L1 weight sharing a single global scale.
    si = INTER
    sgu_w = torch.cat([sgw, suw], 0)
    sgu_sf, smgs = _fold_global_scales(torch.cat([sgsf, susf], 0), si, sggs, sugs)

    ser = Nvfp4SharedExpert(hidden_size=H, intermediate_size=si, max_num_tokens=8192,
                            device=DEV, swiglu_limit=SL)
    ser.l1_fp4 = [_to_kernel_layout(sgu_w.view(torch.float4_e2m1fn_x2))]
    ser.l1_sf = [_to_kernel_layout(sgu_sf)]
    ser.l1_gs = [smgs]
    ser.l2_fp4 = [_to_kernel_layout(sdw.view(torch.float4_e2m1fn_x2))]
    ser.l2_sf = [_to_kernel_layout(sdsf)]
    ser.l2_gs = [sdgs]
    ser.finalize_weights()
    ser._ensure_initialized()

    tse = torch.randn(NT, H, dtype=torch.bfloat16, device=DEV) * 2.0
    ser.compute_activation_global_scales(tse)
    with torch.no_grad():
        so = ser.run(tse)

    # BF16 reference: clamped SwiGLU (gate clamped above, up clamped to ±SL).
    gb = dequant(sgw, sgsf, sggs)
    ub = dequant(suw, susf, sugs)
    db = dequant(sdw, sdsf, sdgs)
    with torch.no_grad():
        g_ = tse @ gb.T
        u_ = tse @ ub.T
        act = F.silu(g_.clamp(max=SL)) * u_.clamp(min=-SL, max=SL)
        sref = act @ db.T
    _check(results, "shared_expert", "shared_expert", so, sref)

    # ── MHC sanity ────────────────────────────────────────────────────
    print("\n--- MHC weight shapes ---")
    print(f" attn_hc.fn: {hca_fn.shape} ffn_hc.fn: {hcf_fn.shape}")
    print(f" attn_hc.base: {hca_b.shape} ffn_hc.base: {hcf_b.shape}")
    print(f" attn_hc.scale: {hca_s.shape} ffn_hc.scale: {hcf_s.shape}")
    print(f" attn_hc.scale values: {hca_s.tolist()}")
    print(f" sinks: {sinks.shape}")

    # ── Summary ───────────────────────────────────────────────────────
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    # One predicate decides both the mark and the verdict. NaN compares False,
    # so a NaN cosine (e.g. a kernel emitting NaNs) counts as a failure.
    all_pass = True
    for name, c in results.items():
        ok = c >= COS_THRESH
        all_pass = all_pass and ok
        print(f" {name}: {c:.6f} {'✅' if ok else '❌'}")
    if all_pass:
        print("\n All projections pass! CuTeDSL kernels match BF16 reference.")
        print(" The bug is in vLLM's pipeline, not our kernels.")
        return 0
    print("\n Some projections FAIL. Need to debug those specific kernels.")
    return 1


if __name__ == "__main__":
    sys.exit(main())