#!/usr/bin/env python3
"""
Reproduce the vLLM empty-output bug outside the container.

Strategy: Run the model in FULL BF16 (dequantized weights) and compare
against CuTeDSL at each projection.

Also check: does the warmup gs cause issues at inference time?

Key diagnostic: inspect CuTeDSL runner.run() to see if it uses
the fixed warmup gs or recomputes per-call.

Usage (on B200):
    source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
    python3 tests/test_model_forward_b200.py
"""
import inspect
import json
import os
import sys
from typing import Callable, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F

REPO = "/root/nvfp4-megamoe-kernel"
# Must precede the function-scope `cutedsl` imports below.
sys.path.insert(0, REPO)

MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"
H = 7168
NH = 128
HD = 512
NOPE = 448
ROPE = 64
QL = 1536
OL = 1024
OG = 16
HPG = NH // OG
EPS = 1e-6

# Lookup table from a 4-bit code (0..15) to its value; codes 8..15 are the
# negated counterparts of codes 0..7 (code 8 is stored as plain 0.0 here).
E2M1 = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)

# Tensors already read from disk, keyed by (checkpoint dir, tensor name).
_cache = {}


def P(k: str, wm: dict, md: str) -> torch.Tensor:
    """Load tensor ``k`` from the safetensors shard ``wm[k]`` inside ``md``.

    Results are memoised in the module-level ``_cache``. The key includes
    ``md`` so two checkpoints that share tensor names cannot alias.

    Args:
        k: Tensor name, also the key into ``wm``.
        wm: Mapping from tensor name to shard file name.
        md: Directory that contains the shard files.

    Returns:
        The tensor as returned by ``safe_open(...).get_tensor``.

    Raises:
        KeyError: If ``k`` is not present in ``wm``.
    """
    cache_key = (md, k)
    if cache_key not in _cache:
        # Imported lazily (like cutedsl) so the pure helpers in this file
        # stay importable on machines without safetensors.
        from safetensors import safe_open

        with safe_open(os.path.join(md, wm[k]), framework="pt") as f:
            _cache[cache_key] = f.get_tensor(k)
    return _cache[cache_key]


def dequant(w: torch.Tensor, sf: torch.Tensor, gs: float) -> torch.Tensor:
    """Expand packed 4-bit weights to a BF16 matrix.

    Each element of ``w`` holds two 4-bit codes: the low nibble becomes the
    even output column and the high nibble the odd one. Codes are mapped
    through ``E2M1``, multiplied by the block scale ``sf`` (each scale is
    repeated over 16 consecutive columns) and by the scalar ``gs``.

    Args:
        w: Packed weights of shape (O, I/2), an integer dtype.
        sf: Block scales; indexed as (O, I/16) and cropped to (O, I) after
            repetition, so extra rows/columns are ignored.
        gs: Scalar global scale.

    Returns:
        Tensor of shape (O, I) in ``torch.bfloat16``.
    """
    device = w.device
    lut = E2M1.to(device)
    rows, packed_cols = w.shape
    cols = packed_cols * 2

    unpacked = torch.empty(rows, cols, dtype=torch.float32, device=device)
    unpacked[:, 0::2] = lut[(w & 0xF).long()]
    unpacked[:, 1::2] = lut[((w >> 4) & 0xF).long()]

    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (unpacked * block_scale * gs).to(torch.bfloat16)


def rms(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Apply RMS normalisation over the last dim and scale by ``w``.

    The mean square is computed in float32; the result is cast back to
    ``x.dtype``.
    """
    mean_sq = x.float().pow(2).mean(-1, keepdim=True)
    normalised = (x * torch.rsqrt(mean_sq + eps)).float()
    return (w.float() * normalised).to(x.dtype)


def abs_max(t: torch.Tensor) -> float:
    """Return the largest magnitude in ``t`` as a Python float.

    ``torch.amax`` alone is the *signed* maximum, which under-reports the
    range whenever the dominant value is negative.
    """
    return t.abs().amax().item()


def required_global_scale(x: torch.Tensor) -> float:
    """Return ``max|x| / (6 * 448)``, the scale this script calls "correct".

    6.0 is the largest magnitude in the ``E2M1`` table and 448.0 matches the
    largest finite ``float8_e4m3fn`` value.
    """
    return abs_max(x) / (6.0 * 448.0)


def flat_cosine(a: torch.Tensor, b: torch.Tensor) -> float:
    """Cosine similarity of two tensors after flattening, in float32."""
    return F.cosine_similarity(
        a.flatten().unsqueeze(0).float(),
        b.flatten().unsqueeze(0).float(),
    ).item()


def rel_max_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return ``max|a - b| / max|a|``.

    Unlike cosine similarity this is sensitive to a pure change of scale.
    Returns 0.0 when both tensors are identical and ``a`` is all zeros, and
    ``inf`` when ``a`` is all zeros but ``b`` is not.
    """
    a32 = a.float()
    diff = (a32 - b.float()).abs().max().item()
    denom = a32.abs().max().item()
    if denom == 0.0:
        return 0.0 if diff == 0.0 else float("inf")
    return diff / denom


def unify_fused_scales(
    s_t: torch.Tensor, g1: float, g2: float, split: int
) -> Tuple[torch.Tensor, float]:
    """Fold two per-projection global scales into the block scales.

    A fused weight whose two output halves carry different global scales is
    re-expressed with the single global scale ``max(g1, g2)`` by multiplying
    each half's block scales by ``g_i / max(g1, g2)``.

    Args:
        s_t: Block scales already transposed to (in_blocks, out_features).
        g1: Global scale of output features ``[:split]``.
        g2: Global scale of output features ``[split:]``.
        split: Number of output features in the first projection.

    Returns:
        ``(scales, gs)``. ``scales`` is ``s_t`` itself when ``g1 == g2``,
        otherwise a new ``float8_e4m3fn`` tensor; ``gs`` is ``max(g1, g2)``.
    """
    gs = max(g1, g2)
    if g1 == g2:
        return s_t, gs
    s32 = s_t.to(torch.float32, copy=True)
    # The output-feature axis is dim 1 after the transpose, so the split
    # must be applied to columns, not rows.
    s32[:, :split] *= g1 / gs
    s32[:, split:] *= g2 / gs
    return s32.to(torch.float8_e4m3fn), gs


def make_runner(
    w: torch.Tensor,
    sf: torch.Tensor,
    gs_t: torch.Tensor,
    inf: int,
    outf: int,
    fused: bool = False,
    lw: Optional[Sequence[int]] = None,
):
    """Build a ``CuTeDSLNvfp4Linear`` from one packed NVFP4 weight.

    The weight is reinterpreted as ``torch.float4_e2m1fn_x2`` and transposed;
    the block scales are cast to ``float8_e4m3fn`` and transposed likewise.

    Args:
        w: Packed weight, (out, in/2).
        sf: Block scales, (out, in/16) — the layout ``dequant`` also uses.
        gs_t: Global scale tensor; one element, or two for a fused weight.
        inf: Input feature count passed to the runner.
        outf: Output feature count passed to the runner.
        fused: Treat ``w`` as two projections stacked on the output axis.
        lw: Optional sizes of the fused projections; ``lw[0]`` is the split
            point, defaulting to ``outf // 2``.

    Returns:
        The runner after ``finalize_weights()`` and ``_ensure_initialized()``.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    s = s.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        split = lw[0] if lw else outf // 2
        s, gs = unify_fused_scales(s, gs_t[0].item(), gs_t[1].item(), split)
    else:
        gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()

    r = CuTeDSLNvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=8192,
        device=str(w.device),
    )
    r.fp4 = [fp4]
    r.sf = [s]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r


def _inspect_runner_internals() -> None:
    """Print the quantizer signature and the gs-related lines of _run_impl."""
    from cutedsl.bridge import quantize_activation_nvfp4
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    print("\n--- INSPECTING CuTeDSL runner internals ---")
    print("\n quantize_activation_nvfp4 signature:")
    sig = inspect.signature(quantize_activation_nvfp4)
    print(f" {sig}")

    print("\n CuTeDSLNvfp4Linear._run_impl source (key lines):")
    try:
        src = inspect.getsource(CuTeDSLNvfp4Linear._run_impl)
    except (OSError, TypeError) as exc:
        # Source inspection is best-effort; the numeric checks still run.
        print(f" (source unavailable: {exc})")
        return
    keywords = ['global_scale', '_activation', 'quantize', 'return', 'def ']
    for i, line in enumerate(src.split('\n')):
        stripped = line.strip()
        if any(kw in stripped for kw in keywords):
            print(f" L{i}: {stripped}")


def _check_warmup_vs_dynamic(
    qa_w: torch.Tensor,
    qa_sf: torch.Tensor,
    qa_gs: torch.Tensor,
    qa_bf16: torch.Tensor,
    normed: torch.Tensor,
) -> None:
    """Compare q_a output with a random-warmup gs against a per-input gs."""
    inf, outf = qa_w.shape[1] * 2, qa_w.shape[0]

    # Runner whose activation gs comes from a random warmup batch.
    r = make_runner(qa_w, qa_sf, qa_gs, inf, outf)
    torch.manual_seed(42)
    warmup = torch.randn(1, H, dtype=torch.bfloat16, device=DEV) * 2.0
    with torch.no_grad():
        r.compute_activation_global_scale(warmup)
    print(f" Warmup gs (random input amax={abs_max(warmup):.4f}): {r._activation_global_scale:.8f}")

    print(f" Real input (after RMS norm) amax: {abs_max(normed):.4f}")

    # What gs would the real input need?
    real_gs = required_global_scale(normed)
    print(f" Correct gs for real input: {real_gs:.8f}")
    print(f" Ratio warmup/correct: {r._activation_global_scale / real_gs:.4f}" if real_gs > 0 else " real_gs is 0!")

    # Run with warmup gs
    with torch.no_grad():
        out_warmup = r.run(normed)

    # Run with dynamic gs (recompute for this input)
    r2 = make_runner(qa_w, qa_sf, qa_gs, inf, outf)
    with torch.no_grad():
        r2.compute_activation_global_scale(normed)
        out_dynamic = r2.run(normed)

    # BF16 reference
    with torch.no_grad():
        ref = normed @ qa_bf16.T

    c_warmup = flat_cosine(out_warmup, ref)
    c_dynamic = flat_cosine(out_dynamic, ref)
    print(f"\n q_a cosine vs BF16 (warmup gs): {c_warmup:.6f} {'✅' if c_warmup>=0.98 else '❌'}")
    print(f" q_a cosine vs BF16 (dynamic gs): {c_dynamic:.6f} {'✅' if c_dynamic>=0.98 else '❌'}")
    print(f" amax warmup: {out_warmup.amax():.4f} amax dynamic: {out_dynamic.amax():.4f} amax ref: {ref.amax():.4f}")


def _run_bf16_layer(
    load: Callable[[str], torch.Tensor],
    a: str,
    hidden: torch.Tensor,
    anorm: torch.Tensor,
    qa_bf16: torch.Tensor,
) -> None:
    """Run layer-0 attention projections in BF16 and print LM-head stats.

    Attention itself and the MoE block are skipped: the attention output is
    replaced by small random values and the residual is added directly.
    """
    print("\n--- FULL BF16 model: 1 layer → LM head ---")

    def dq(name: str) -> torch.Tensor:
        return dequant(
            load(f"{a}.{name}.weight"),
            load(f"{a}.{name}.weight_scale"),
            load(f"{a}.{name}.weight_scale_2").item(),
        )

    lm_head = load("lm_head.weight")
    fnorm_w = load("model.norm.weight")
    qn = load(f"{a}.q_a_norm.weight")

    # Dequantize the remaining layer 0 attention weights
    kv_bf16 = dq("kv_proj")
    qb_bf16 = dq("q_b_proj")
    woa = load(f"{a}.o_a_proj.weight")  # already BF16
    wob_bf16 = dq("o_b_proj")

    n_tok = hidden.shape[0]
    with torch.no_grad():
        print(f" Input: amax={hidden.amax():.4f}")

        # RMS norm
        x = rms(hidden, anorm, EPS)

        # Attention projections (BF16)
        qa = x @ qa_bf16.T
        kv = x @ kv_bf16.T
        qb = rms(qa, qn, EPS) @ qb_bf16.T
        print(f" q_a: amax={qa.amax():.4f}, kv: amax={kv.amax():.4f}, q_b: amax={qb.amax():.4f}")

        # Skip attention, use random output
        o = torch.randn(n_tok, NH, HD, dtype=torch.bfloat16, device=DEV) * 0.1

        # wo_a: BMM (o_a_proj is (OG*OL, HPG*HD))
        o_grouped = o.view(n_tok, OG, HPG * HD).permute(1, 0, 2)
        woa_3d = woa.view(OG, OL, HPG * HD)
        z = torch.bmm(o_grouped, woa_3d.transpose(1, 2))
        z = z.permute(1, 0, 2).reshape(n_tok, OG * OL)

        # wo_b
        attn_out = z @ wob_bf16.T
        print(f" attn_out (BF16): amax={attn_out.amax():.4f}")

        # Skip MoE, just add residual
        x = hidden + attn_out

        # Final norm + LM head
        logits = rms(x, fnorm_w, EPS) @ lm_head.T
        print(f" logits: amax={logits.amax():.4f} NaN={torch.isnan(logits).any()}")

        top5 = torch.topk(logits[-1], 5)
        print(f" top5 IDs: {top5.indices.tolist()}")
        print(f" top5 logits: {[f'{v:.2f}' for v in top5.values.tolist()]}")
        log_std = logits[-1].float().std().item()
        print(f" logit std: {log_std:.4f}")


def _probe_gs_sensitivity(
    qa_w: torch.Tensor,
    qa_sf: torch.Tensor,
    qa_gs: torch.Tensor,
    normed: torch.Tensor,
) -> None:
    """Overwrite the runner's activation gs by 10x and see if run() reacts."""
    print("\n" + "=" * 70)
    print(" KEY: Does runner.run() use FIXED warmup gs or RECOMPUTE?")
    print("=" * 70)

    # Monkey-patch the gs and see if output changes
    r3 = make_runner(qa_w, qa_sf, qa_gs, qa_w.shape[1] * 2, qa_w.shape[0])
    with torch.no_grad():
        r3.compute_activation_global_scale(normed)
        gs_original = r3._activation_global_scale
        # clone(): the second run() below must not be able to overwrite it.
        out_original = r3.run(normed).clone()

        # Change gs by 10x
        r3._activation_global_scale = gs_original * 10.0
        out_changed = r3.run(normed).clone()

    c_changed = flat_cosine(out_original, out_changed)
    # Cosine similarity ignores magnitude, so a uniformly rescaled output
    # would look "unchanged"; the relative difference catches that case.
    rel_changed = rel_max_diff(out_original, out_changed)
    print(f" Original gs: {gs_original:.8f}")
    print(f" Changed gs: {gs_original * 10:.8f}")
    print(f" Cosine sim after 10x gs change: {c_changed:.6f}")
    print(f" Relative max |diff| after 10x gs change: {rel_changed:.6f}")

    if abs(c_changed - 1.0) < 0.001 and rel_changed < 0.001:
        print(" ➡️ Changing gs has NO effect on output!")
        print(" ➡️ The runner recomputes gs internally at inference time.")
        print(" ➡️ Warmup gs is IRRELEVANT — the bug is elsewhere.")
    else:
        print(" ➡️ Changing gs DOES change the output!")
        print(" ➡️ The runner uses the warmup gs at inference time.")
        print(" ➡️ Wrong warmup gs would cause wrong quantization → garbage.")


def main() -> None:
    """Run all diagnostics against layer 0 of the checkpoint in ``MODEL``."""
    torch.cuda.set_device(0)
    print("=" * 70)
    print(" Diagnose: Why does vLLM produce empty output?")
    print("=" * 70)

    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def load(k: str) -> torch.Tensor:
        return P(k, wm, MODEL).to(DEV)

    # ── INSPECT: How does CuTeDSL runner.run() use gs? ────────────────
    _inspect_runner_internals()

    # ── CRITICAL TEST: warmup gs vs per-input gs ──────────────────────
    print("\n--- CRITICAL TEST: warmup gs vs per-input gs ---")
    p = "model.layers.0"
    a = f"{p}.self_attn"
    qa_w = load(f"{a}.q_a_proj.weight")
    qa_sf = load(f"{a}.q_a_proj.weight_scale")
    qa_gs = load(f"{a}.q_a_proj.weight_scale_2")

    # Load embedding + norm weight
    emb = load("model.embed_tokens.weight")
    anorm = load(f"{p}.input_layernorm.weight")

    # REAL input (embedding output)
    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV)
    with torch.no_grad():
        hidden = emb[token_ids]
        normed = rms(hidden, anorm, EPS)

    # Dequantized once and shared by the reference matmul and the BF16 layer.
    qa_bf16 = dequant(qa_w, qa_sf, qa_gs.item())

    _check_warmup_vs_dynamic(qa_w, qa_sf, qa_gs, qa_bf16, normed)

    # ── Test: run FULL model in BF16 (1 layer) then check logits ──────
    _run_bf16_layer(load, a, hidden, anorm, qa_bf16)

    # ── KEY INSIGHT: check if the runner re-reads gs at inference time ─
    _probe_gs_sensitivity(qa_w, qa_sf, qa_gs, normed)


if __name__ == "__main__":
    main()