Files
nvfp4-megamoe-kernel/tests/test_model_forward_b200.py

237 lines
9.6 KiB
Python

#!/usr/bin/env python3
"""
Reproduce the vLLM empty-output bug outside the container.
Strategy: Run the model in FULL BF16 (dequantized weights) and compare
against CuTeDSL at each projection. Also check: does the activation global
scale ("gs") computed during warmup cause issues at inference time?
Key diagnostic: inspect CuTeDSL runner.run() to see if it uses the
fixed warmup gs or recomputes per-call.
Usage (on B200):
source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
python3 tests/test_model_forward_b200.py
"""
import sys, os, json, torch, torch.nn.functional as F, inspect
from safetensors import safe_open
# Repository checkout; placed first on sys.path (presumably so the `cutedsl`
# package imported by make_runner()/main() resolves to this tree — confirm).
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)
# Checkpoint directory: main() reads model.safetensors.index.json from here and
# P() opens the shard files it lists.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Device every loaded tensor and every test input is placed on.
DEV = "cuda:0"
# Model dimensions. H is the width of the warmup input fed to the q_a runner;
# NH x HD is the per-token shape of the stand-in attention output in main().
# NOPE and ROPE are not referenced in the code shown (presumably the
# non-rotary / rotary split of HD — confirm).
H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64
# OG groups of HPG heads each; main() views o_a_proj as (OG, OL, HPG * HD).
# QL is not referenced in the code shown.
QL = 1536; OL = 1024; OG = 16; HPG = NH // OG
# Epsilon handed to rms() by main().
EPS = 1e-6
# 16-entry lookup table indexed by a 4-bit code in dequant(); entries 8-15 are
# the negated entries 0-7. `-0` is an int literal, so entry 8 is stored as +0.0.
E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32)
# Tensor-name -> tensor memo filled by P().
_cache = {}
def P(k, wm, md):
    """Return the checkpoint tensor named ``k``, memoised in ``_cache``.

    Args:
        k: Tensor name; used as the key into ``wm`` and into the cache.
        wm: Mapping from tensor name to the shard file that contains it.
        md: Directory holding the shard files.

    Returns:
        The tensor produced by ``get_tensor(k)`` on the opened shard.
    """
    # The memo is keyed on the tensor name alone; ``md`` is not part of the key.
    try:
        return _cache[k]
    except KeyError:
        pass

    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        tensor = shard.get_tensor(k)
    _cache[k] = tensor
    return tensor
def dequant(w, sf, gs):
    """Expand packed 4-bit weights into a BF16 matrix.

    Each element of ``w`` carries two 4-bit codes: the low nibble becomes the
    even output column and the high nibble the odd one. Codes are mapped
    through the ``E2M1`` table, multiplied by the per-block scale ``sf`` (each
    scale covers 16 consecutive output columns) and by the scalar ``gs``.

    Args:
        w: 2-D integer tensor of packed codes, shape ``(O, I / 2)``.
        sf: Per-block scales; repeated 16x along dim 1 and cropped to ``(O, I)``.
        gs: Multiplier applied to every element.

    Returns:
        ``(O, I)`` tensor in ``torch.bfloat16`` on the device of ``w``.
    """
    device = w.device
    table = E2M1.to(device)

    low_vals = table[(w & 0xF).long()]
    high_vals = table[((w >> 4) & 0xF).long()]

    rows, packed_cols = w.shape
    cols = packed_cols * 2

    # Interleave: low nibbles fill even columns, high nibbles fill odd columns.
    values = torch.empty(rows, cols, dtype=torch.float32, device=device)
    values[:, 0::2] = low_vals
    values[:, 1::2] = high_vals

    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    scaled = values * block_scale * gs
    return scaled.to(torch.bfloat16)
def rms(x, w, eps=1e-6):
    """Apply RMS normalisation over the last dimension, then scale by ``w``.

    The mean of squares is computed in FP32; the final product is cast back
    to the dtype of ``x``.

    Args:
        x: Input tensor; normalised along its last dimension.
        w: Per-feature weight multiplied into the normalised values.
        eps: Constant added to the mean square before the reciprocal sqrt.

    Returns:
        Tensor with the shape and dtype of ``x``.
    """
    mean_sq = x.float().pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    normalised = (x * inv_rms).float()
    weighted = w.float() * normalised
    return weighted.to(x.dtype)
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build and initialise a ``CuTeDSLNvfp4Linear`` from checkpoint tensors.

    Args:
        w: Packed weight tensor; reinterpreted as ``float4_e2m1fn_x2`` and
            transposed before being handed to the runner.
        sf: Block scales; cast to ``float8_e4m3fn`` if needed, then transposed.
        gs_t: Tensor holding one global scale, or several.
        inf: ``in_features`` passed to the runner.
        outf: ``out_features`` passed to the runner.
        fused: When true and ``gs_t`` has exactly two entries, the two scales
            are folded into the block scales so one global scale (their
            maximum) can be used.
        lw: Optional sequence whose first entry is the split index for the
            fused case; defaults to ``outf // 2``.

    Returns:
        The runner after ``finalize_weights()`` and ``_ensure_initialized()``.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    fp8 = torch.float8_e4m3fn
    packed = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    scales = sf if sf.dtype == fp8 else sf.to(fp8)
    scales = scales.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        first_gs = gs_t[0].item()
        second_gs = gs_t[1].item()
        global_scale = max(first_gs, second_gs)
        if first_gs != second_gs:
            rescaled = scales.float()
            split = lw[0] if lw else outf // 2
            # NOTE(review): this slices dim 0 of the already-transposed scale
            # tensor — confirm that is the axis the split index refers to.
            rescaled[:split] *= first_gs / global_scale
            rescaled[split:] *= second_gs / global_scale
            scales = rescaled.to(fp8)
    elif gs_t.numel() > 1:
        global_scale = gs_t.max().item()
    else:
        global_scale = gs_t.item()

    runner = CuTeDSLNvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=8192,
        device=str(w.device),
    )
    runner.fp4 = [packed]
    runner.sf = [scales]
    runner.gs = [global_scale]
    runner.finalize_weights()
    runner._ensure_initialized()
    return runner
def main():
    """Run the layer-0 q_a_proj diagnostics and print the findings.

    Steps, in order:
      1. Print the signature of ``quantize_activation_nvfp4`` and the
         scale-related lines of ``CuTeDSLNvfp4Linear._run_impl``.
      2. Compare the CuTeDSL q_a projection against a dequantized BF16
         reference, once with an activation global scale computed from a
         random warmup input and once with one computed from the real input.
      3. Push the embeddings through a BF16 approximation of layer 0
         (attention core replaced by random values, MoE skipped) and report
         the resulting logits.
      4. Overwrite ``_activation_global_scale`` on a runner and check whether
         the output of ``run()`` reacts to it.

    Requires a CUDA device, the checkpoint under ``MODEL`` and the ``cutedsl``
    package. Everything is reported via ``print``; nothing is returned.
    """

    def absmax(t):
        # "amax" in the quantisation sense: the largest magnitude. A plain
        # ``.amax()`` returns the largest *signed* value and under-reports a
        # tensor whose extreme element is negative.
        return t.abs().amax().item()

    def cosine(lhs, rhs):
        # Cosine similarity of the two tensors viewed as flat FP32 vectors.
        return F.cosine_similarity(
            lhs.flatten().unsqueeze(0).float(),
            rhs.flatten().unsqueeze(0).float(),
        ).item()

    torch.cuda.set_device(0)
    print("=" * 70)
    print(" Diagnose: Why does vLLM produce empty output?")
    print("=" * 70)
    with open(os.path.join(MODEL, "model.safetensors.index.json"), encoding="utf-8") as f:
        wm = json.load(f)["weight_map"]

    def G(k):
        # Load (memoised) tensor ``k`` and move it to the test device.
        return P(k, wm, MODEL).to(DEV)

    # ── INSPECT: How does CuTeDSL runner.run() use gs? ────────────────
    print("\n--- INSPECTING CuTeDSL runner internals ---")
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
    from cutedsl.bridge import quantize_activation_nvfp4
    print("\n quantize_activation_nvfp4 signature:")
    print(f" {inspect.signature(quantize_activation_nvfp4)}")
    print("\n CuTeDSLNvfp4Linear._run_impl source (key lines):")
    keywords = ('global_scale', '_activation', 'quantize', 'return', 'def ')
    try:
        src = inspect.getsource(CuTeDSLNvfp4Linear._run_impl)
    except (AttributeError, OSError, TypeError) as exc:
        # The listing is informational only; do not let a missing source file
        # stop the numerical checks below.
        print(f" (source unavailable: {exc})")
    else:
        for i, line in enumerate(src.split('\n')):
            stripped = line.strip()
            if any(kw in stripped for kw in keywords):
                print(f" L{i}: {stripped}")

    # ── CRITICAL TEST: warmup gs vs per-input gs ──────────────────────
    print("\n--- CRITICAL TEST: warmup gs vs per-input gs ---")
    p = "model.layers.0"
    a = f"{p}.self_attn"
    qa_w = G(f"{a}.q_a_proj.weight")
    qa_sf = G(f"{a}.q_a_proj.weight_scale")
    qa_gs = G(f"{a}.q_a_proj.weight_scale_2")
    # Load embedding + norm weight
    emb = G("model.embed_tokens.weight")
    anorm = G(f"{p}.input_layernorm.weight")
    # Each packed weight element holds two FP4 values, so in_features is
    # twice the stored column count.
    qa_in = qa_w.shape[1] * 2
    qa_out = qa_w.shape[0]

    # Create runner with warmup gs
    r = make_runner(qa_w, qa_sf, qa_gs, qa_in, qa_out)
    torch.manual_seed(42)
    warmup = torch.randn(1, H, dtype=torch.bfloat16, device=DEV) * 2.0
    with torch.no_grad():
        r.compute_activation_global_scale(warmup)
    print(f" Warmup gs (random input amax={absmax(warmup):.4f}): {r._activation_global_scale:.8f}")

    # Get REAL input (embedding output)
    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV)
    with torch.no_grad():
        hidden = emb[token_ids]
        normed = rms(hidden, anorm, EPS)
    real_amax = absmax(normed)
    print(f" Real input (after RMS norm) amax: {real_amax:.4f}")

    # What gs would the real input need? 6.0 is the largest magnitude in the
    # E2M1 table and 448.0 the largest finite float8_e4m3fn value.
    real_gs = real_amax / (6.0 * 448.0)
    print(f" Correct gs for real input: {real_gs:.8f}")
    if real_gs > 0:
        print(f" Ratio warmup/correct: {r._activation_global_scale / real_gs:.4f}")
    else:
        print(" real_gs is 0!")

    # Run with warmup gs
    with torch.no_grad():
        out_warmup = r.run(normed)

    # Run with dynamic gs (recompute for this input)
    r2 = make_runner(qa_w, qa_sf, qa_gs, qa_in, qa_out)
    with torch.no_grad():
        r2.compute_activation_global_scale(normed)
        out_dynamic = r2.run(normed)

    # BF16 reference (also reused by the full-model section below)
    qa_bf16 = dequant(qa_w, qa_sf, qa_gs.item())
    with torch.no_grad():
        ref = normed @ qa_bf16.T
    c_warmup = cosine(out_warmup, ref)
    c_dynamic = cosine(out_dynamic, ref)
    print(f"\n q_a cosine vs BF16 (warmup gs): {c_warmup:.6f} {'✅' if c_warmup >= 0.98 else '❌'}")
    print(f" q_a cosine vs BF16 (dynamic gs): {c_dynamic:.6f} {'✅' if c_dynamic >= 0.98 else '❌'}")
    print(f" amax warmup: {absmax(out_warmup):.4f} amax dynamic: {absmax(out_dynamic):.4f} amax ref: {absmax(ref):.4f}")

    # ── Test: run FULL model in BF16 (1 layer) then check logits ──────
    print("\n--- FULL BF16 model: 1 layer → LM head ---")
    lm_head = G("lm_head.weight")
    fnorm_w = G("model.norm.weight")
    qn = G(f"{a}.q_a_norm.weight")
    # Dequantize the remaining layer 0 attention weights
    qb_bf16 = dequant(
        G(f"{a}.q_b_proj.weight"),
        G(f"{a}.q_b_proj.weight_scale"),
        G(f"{a}.q_b_proj.weight_scale_2").item(),
    )
    kv_bf16 = dequant(
        G(f"{a}.kv_proj.weight"),
        G(f"{a}.kv_proj.weight_scale"),
        G(f"{a}.kv_proj.weight_scale_2").item(),
    )
    woa = G(f"{a}.o_a_proj.weight")  # already BF16
    wob_bf16 = dequant(
        G(f"{a}.o_b_proj.weight"),
        G(f"{a}.o_b_proj.weight_scale"),
        G(f"{a}.o_b_proj.weight_scale_2").item(),
    )
    n_tok = len(token_ids)
    with torch.no_grad():
        x = hidden.clone()
        print(f" Input: amax={absmax(x):.4f}")
        # RMS norm
        x = rms(x, anorm, EPS)
        # Attention projections (BF16)
        qa = x @ qa_bf16.T
        kv = x @ kv_bf16.T
        qa_n = rms(qa, qn, EPS)
        qb = qa_n @ qb_bf16.T
        print(f" q_a: amax={absmax(qa):.4f}, kv: amax={absmax(kv):.4f}, q_b: amax={absmax(qb):.4f}")
        # Skip attention, use random output
        o = torch.randn(n_tok, NH, HD, dtype=torch.bfloat16, device=DEV) * 0.1
        # wo_a: BMM (o_a_proj is (OG*OL, HPG*HD))
        o_grouped = o.view(n_tok, OG, HPG * HD).permute(1, 0, 2)
        woa_3d = woa.view(OG, OL, HPG * HD)
        z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(n_tok, OG * OL)
        # wo_b
        attn_out = z @ wob_bf16.T
        print(f" attn_out (BF16): amax={absmax(attn_out):.4f}")
        # Skip MoE, just add residual
        x = hidden + attn_out
        # Final norm + LM head
        x_normed = rms(x, fnorm_w, EPS)
        logits = x_normed @ lm_head.T
        print(f" logits: amax={absmax(logits):.4f} NaN={torch.isnan(logits).any().item()}")
        top5 = torch.topk(logits[-1], 5)
        print(f" top5 IDs: {top5.indices.tolist()}")
        print(f" top5 logits: {[f'{v:.2f}' for v in top5.values.tolist()]}")
        log_std = logits[-1].float().std().item()
        print(f" logit std: {log_std:.4f}")

    # ── KEY INSIGHT: check if the runner re-reads gs at inference time ─
    print("\n" + "=" * 70)
    print(" KEY: Does runner.run() use FIXED warmup gs or RECOMPUTE?")
    print("=" * 70)
    # Monkey-patch the gs and see if output changes
    r3 = make_runner(qa_w, qa_sf, qa_gs, qa_in, qa_out)
    with torch.no_grad():
        r3.compute_activation_global_scale(normed)
        gs_original = r3._activation_global_scale
        out_original = r3.run(normed).clone()
        # Change gs by 10x
        r3._activation_global_scale = gs_original * 10.0
        out_changed = r3.run(normed).clone()

    c_changed = cosine(out_original, out_changed)
    # Cosine similarity ignores magnitude, so a gs that merely rescales the
    # output would look like "no effect". Check the element-wise difference
    # relative to the output magnitude as well.
    ref_mag = absmax(out_original)
    max_diff = (out_changed.float() - out_original.float()).abs().amax().item()
    if ref_mag > 0:
        rel_diff = max_diff / ref_mag
    else:
        rel_diff = 0.0 if max_diff == 0 else float("inf")

    print(f" Original gs: {gs_original:.8f}")
    print(f" Changed gs: {gs_original * 10:.8f}")
    print(f" Cosine sim after 10x gs change: {c_changed:.6f}")
    print(f" Max relative diff after 10x gs change: {rel_diff:.6f}")
    unchanged = abs(c_changed - 1.0) < 0.001 and rel_diff < 0.01
    if unchanged:
        print(" ➡️ Changing gs has NO effect on output!")
        print(" ➡️ The runner recomputes gs internally at inference time.")
        print(" ➡️ Warmup gs is IRRELEVANT — the bug is elsewhere.")
    else:
        print(" ➡️ Changing gs DOES change the output!")
        print(" ➡️ The runner uses the warmup gs at inference time.")
        print(" ➡️ Wrong warmup gs would cause wrong quantization → garbage.")
# Run the diagnostics only when executed as a script, not on import.
if __name__ == "__main__":
    main()