Add model forward test: reproduce vLLM empty output outside container
This commit is contained in:
329
tests/test_model_forward_b200.py
Normal file
329
tests/test_model_forward_b200.py
Normal file
@@ -0,0 +1,329 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Reproduce the vLLM empty-output bug outside the container.
|
||||
|
||||
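Exercises a partial forward pass: embedding → layer-0 attention projections → final norm + LM head (attention itself, MoE and the remaining decoder layers are not run by this script yet).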
|
||||
Uses CuTeDSL NVFP4 runners for quantized layers, BF16 matmuls for others.
|
||||
|
||||
Compares two approaches:
|
||||
A) Warmup gs (what vLLM does) — 1 token random sample per layer
|
||||
B) Dynamic gs (compute per-batch) — recomputes the activation global scale from the actual input before running
|
||||
|
||||
If A produces garbage and B produces reasonable output, the warmup gs is wrong.
|
||||
|
||||
Usage (on B200):
|
||||
source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
|
||||
python3 tests/test_model_forward_b200.py
|
||||
"""
|
||||
|
||||
import sys, os, json, torch, torch.nn.functional as F, time
|
||||
from safetensors import safe_open
|
||||
|
||||
# --- Paths and device --------------------------------------------------------
# Checkout that provides the `cutedsl` package imported lazily further down.
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)
# Directory holding the safetensors shards and model.safetensors.index.json.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"

# --- Model config -------------------------------------------------------------
# Only H, NH, HD, QL, OL, OG and EPS are referenced by the code in this file;
# the meanings given for the other names are inferred from the identifiers —
# TODO confirm against the checkpoint's config.json.
H = 7168            # hidden width (row width of the embedding table)
NH = 128            # number of attention heads
HD = 512            # per-head dimension
NOPE = 448          # presumably the non-rotary slice of HD (448 + 64 == 512)
ROPE = 64           # presumably the rotary slice of HD
QL = 1536           # width of the q_a output / q_b input
OL = 1024           # per-group width of the o_b input
OG = 16             # number of output groups
HPG = NH // OG      # heads per output group
HC = 4              # not referenced in this file
SL = 10.0           # not referenced in this file
EPS = 1e-6          # epsilon passed to rms()
INTER = 3072        # not referenced in this file
N_EXPERTS = 384     # not referenced in this file
TOP_K = 6           # not referenced in this file
N_LAYERS = 61       # not referenced in this file
VOCAB = 129280      # not referenced in this file

# Decode table for 4-bit E2M1 codes: index == nibble value.  Codes 8..15 mirror
# codes 0..7 with a negative sign; code 8 decodes to +0.0 (the sign of zero is
# not preserved by this table).
E2M1 = torch.tensor(
    [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ],
    dtype=torch.float32,
)
|
||||
|
||||
# Process-wide memo of tensors already read from disk, keyed by tensor name.
_cache = {}


def P(k, wm, md):
    """Return the tensor named *k*, reading it from its safetensors shard once.

    Args:
        k: Tensor name (a key of *wm*).
        wm: Mapping from tensor name to the shard file name that holds it.
        md: Directory containing the shard files.

    Returns:
        The tensor as returned by ``safe_open(...).get_tensor(k)``.

    Raises:
        KeyError: If *k* is neither cached nor present in *wm*.

    Note:
        The memo is keyed by *k* only, so a later call with a different
        *wm* / *md* but the same name returns the first tensor loaded.
    """
    try:
        return _cache[k]
    except KeyError:
        pass

    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        tensor = shard.get_tensor(k)
    _cache[k] = tensor
    return tensor
|
||||
|
||||
def dequant(w, sf, gs):
    """Expand packed 4-bit E2M1 weights into a dense bfloat16 matrix.

    Args:
        w: 2-D integer tensor of shape (O, I/2); each element packs two 4-bit
            codes, the low nibble being the even column and the high nibble
            the odd column of the unpacked row.
        sf: Per-block scales; each entry along dim 1 is applied to 16
            consecutive unpacked columns.
        gs: Scalar multiplier applied to every element.

    Returns:
        A (O, I) bfloat16 tensor on the same device as *w*.
    """
    device = w.device
    table = E2M1.to(device)
    rows, packed_cols = w.shape
    cols = packed_cols * 2

    low = table[(w & 0xF).long()]
    high = table[((w >> 4) & 0xF).long()]
    # Stacking on a new last axis and flattening it interleaves the nibbles:
    # [low0, high0, low1, high1, ...].
    values = torch.stack((low, high), dim=-1).reshape(rows, cols)

    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (values * block_scale * gs).to(torch.bfloat16)
|
||||
|
||||
def rms(x, w, eps=1e-6):
    """Apply RMS normalisation over the last dimension and scale by *w*.

    The mean of squares is accumulated in float32; the result is cast back
    to the dtype of *x*.

    Args:
        x: Input tensor; normalised along its last dimension.
        w: Per-feature weight, broadcast against *x*.
        eps: Value added to the mean square before the reciprocal sqrt.

    Returns:
        Tensor with the same shape and dtype as *x*.
    """
    mean_sq = x.float().square().mean(dim=-1, keepdim=True)
    scaled = x * torch.rsqrt(mean_sq + eps)
    return (w.float() * scaled.float()).to(x.dtype)
|
||||
|
||||
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build and finalise a CuTeDSLNvfp4Linear runner for one weight matrix.

    Args:
        w: Packed weight tensor; reinterpreted as ``torch.float4_e2m1fn_x2``.
        sf: Block scales; converted to ``torch.float8_e4m3fn`` if needed.
        gs_t: Tensor holding the weight global scale(s).
        inf: ``in_features`` passed to the runner.
        outf: ``out_features`` passed to the runner.
        fused: When true and *gs_t* has exactly two entries, the two scales
            are folded into the block scales so one global scale suffices.
        lw: Optional sequence whose first entry is the split index between
            the two fused halves; defaults to ``outf // 2``.

    Returns:
        The initialised runner.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    packed = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    # Tensor.to() returns the tensor itself when the dtype already matches.
    scales = sf.to(torch.float8_e4m3fn).permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        g1 = gs_t[0].item()
        g2 = gs_t[1].item()
        gs = max(g1, g2)
        if g1 != g2:
            # Fold each half's global scale into its block scales so that a
            # single global scale (the larger one) can be handed to the runner.
            rescaled = scales.float()
            split = lw[0] if lw else outf // 2
            # NOTE(review): the slices below index dim 0 of the already
            # transposed scale tensor while the default split is derived from
            # outf — confirm dim 0 is the output-feature axis in this layout.
            rescaled[:split] *= g1 / gs
            rescaled[split:] *= g2 / gs
            scales = rescaled.to(torch.float8_e4m3fn)
    elif gs_t.numel() > 1:
        gs = gs_t.max().item()
    else:
        gs = gs_t.item()

    runner = CuTeDSLNvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=8192,
        device=str(w.device),
    )
    runner.fp4 = [packed]
    runner.sf = [scales]
    runner.gs = [gs]
    runner.finalize_weights()
    runner._ensure_initialized()
    return runner
|
||||
|
||||
def cosim(a, b):
    """Return the cosine similarity of *a* and *b* as a Python float.

    Both tensors are flattened and compared in float32; *b* is moved to the
    device of *a* first.
    """
    lhs = a.reshape(1, -1).float()
    rhs = b.reshape(1, -1).float().to(a.device)
    return F.cosine_similarity(lhs, rhs).item()
|
||||
|
||||
|
||||
class Layer0Runner:
    """Runs layer 0 forward with CuTeDSL kernels.

    Loads the layer-0 attention, compressor and MHC tensors onto ``DEV`` and
    builds one NVFP4 runner per quantised projection (``r_qa``, ``r_qb``,
    ``r_kv``, ``r_wob``, ``r_ckv``, ``r_cg``).
    """

    def __init__(self, wm, model_dir, use_warmup_gs=True):
        """Load weights, create the runners and optionally warm them up.

        Args:
            wm: Mapping from tensor name to safetensors shard file name.
            model_dir: Directory containing the shards.
            use_warmup_gs: When true, each runner's activation global scale
                is computed once from a single random token (the behaviour
                this test attributes to vLLM). When false, the scale is
                left for ``forward_projection`` to set per input.
        """
        def load(key):
            return P(key, wm, model_dir).to(DEV)

        def load_nvfp4(prefix):
            # (packed weight, block scales, global scale) of one projection.
            return (
                load(f"{prefix}.weight"),
                load(f"{prefix}.weight_scale"),
                load(f"{prefix}.weight_scale_2"),
            )

        def runner_for(w, sf, gs):
            # The packed weight stores two values per element along dim 1.
            return make_runner(w, sf, gs, w.shape[1] * 2, w.shape[0])

        p = "model.layers.0"
        a = f"{p}.self_attn"

        # Attention
        self.qa_w, self.qa_sf, self.qa_gs = load_nvfp4(f"{a}.q_a_proj")
        self.qb_w, self.qb_sf, self.qb_gs = load_nvfp4(f"{a}.q_b_proj")
        self.kv_w, self.kv_sf, self.kv_gs = load_nvfp4(f"{a}.kv_proj")
        self.woa = load(f"{a}.o_a_proj.weight")
        self.wob_w, self.wob_sf, self.wob_gs = load_nvfp4(f"{a}.o_b_proj")
        self.qn = load(f"{a}.q_a_norm.weight")
        self.kvn = load(f"{a}.kv_norm.weight")
        self.anorm = load(f"{p}.input_layernorm.weight")
        self.fnorm = load(f"{p}.post_attention_layernorm.weight")

        # Compressor
        self.ckv_w, self.ckv_sf, self.ckv_gs = load_nvfp4(f"{a}.compressor.kv_proj")
        self.cg_w, self.cg_sf, self.cg_gs = load_nvfp4(f"{a}.compressor.gate_proj")

        # MHC
        self.hca_fn = load(f"{p}.attn_hc.fn")
        self.hcf_fn = load(f"{p}.ffn_hc.fn")
        self.hca_b = load(f"{p}.attn_hc.base")
        self.hcf_b = load(f"{p}.ffn_hc.base")
        self.hca_s = load(f"{p}.attn_hc.scale")
        # NOTE(review): every other FFN hyper-connection tensor lives under
        # "ffn_hc."; this key uses "ffn." — confirm it is not a typo for
        # f"{p}.ffn_hc.scale" (a wrong key raises KeyError in P()).
        self.hcf_s = load(f"{p}.ffn.scale")

        # Create runners
        self.r_qa = runner_for(self.qa_w, self.qa_sf, self.qa_gs)
        self.r_qb = runner_for(self.qb_w, self.qb_sf, self.qb_gs)
        self.r_kv = runner_for(self.kv_w, self.kv_sf, self.kv_gs)
        self.r_wob = runner_for(self.wob_w, self.wob_sf, self.wob_gs)
        self.r_ckv = runner_for(self.ckv_w, self.ckv_sf, self.ckv_gs)
        self.r_cg = runner_for(self.cg_w, self.cg_sf, self.cg_gs)

        self.use_warmup_gs = use_warmup_gs
        if use_warmup_gs:
            # Warmup with 1 token (what vLLM does)
            with torch.no_grad():
                d = torch.randn(1, H, dtype=torch.bfloat16, device=DEV) * 2.0
                for runner in (self.r_qa, self.r_kv, self.r_ckv, self.r_cg):
                    runner.compute_activation_global_scale(d)
                d2 = torch.randn(1, QL, dtype=torch.bfloat16, device=DEV) * 2.0
                self.r_qb.compute_activation_global_scale(d2)
                d3 = torch.randn(1, OG * OL, dtype=torch.bfloat16, device=DEV) * 2.0
                self.r_wob.compute_activation_global_scale(d3)

    def forward_projection(self, x, runner, in_features, name):
        """Run a single NVFP4 projection, optionally recomputing gs.

        Args:
            x: Activation tensor handed to ``runner.run``.
            runner: Object exposing ``run(x)`` and a writable
                ``_activation_global_scale`` attribute.
            in_features: Unused; kept for interface compatibility.
            name: Unused; kept for interface compatibility.

        Returns:
            Whatever ``runner.run(x)`` returns.
        """
        if not self.use_warmup_gs:
            # Dynamic gs: recompute for this specific input.  The scale must
            # come from the largest *magnitude*; a signed max underestimates
            # it whenever the extreme value is negative.
            amax = x.abs().amax().item()
            # 6.0 is the largest E2M1 magnitude and 448.0 the largest finite
            # float8_e4m3fn value; 1 / 2688 == 1 / (6 * 448) is the fallback
            # for an all-zero input.
            gs = amax / (6.0 * 448.0) if amax > 0 else 1.0 / 2688.0
            runner._activation_global_scale = gs
        return runner.run(x)
|
||||
|
||||
|
||||
def main():
    """Run the layer-0 NVFP4 diagnostics and print the results.

    Steps: load embedding / LM head / final norm, run the layer-0 q_a, kv and
    q_b projections with warmup activation scales, compare q_a against a
    dequantised BF16 reference, repeat q_a with a scale computed from the
    real input, exercise the wo_b runner, compute LM-head logits, and finally
    print how the runner source refers to the activation global scale.

    Requires a CUDA device, the checkpoint under ``MODEL`` and the
    ``cutedsl`` package importable from ``REPO``.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)

    print("=" * 70)
    print(" Full Model Forward Test: Reproduce vLLM Empty Output")
    print("=" * 70)

    index_path = os.path.join(MODEL, "model.safetensors.index.json")
    with open(index_path, encoding="utf-8") as f:
        wm = json.load(f)["weight_map"]

    def load(key):
        return P(key, wm, MODEL).to(DEV)

    # ── Load embedding ────────────────────────────────────────────────
    print("\n--- Loading embedding layer ---")
    emb_key = "model.embed_tokens.weight"
    emb = load(emb_key)
    print(f" embed_tokens: {emb.shape} dtype={emb.dtype}")

    # ── Load LM head ──────────────────────────────────────────────────
    lm_head_key = "lm_head.weight"
    if lm_head_key in wm:
        lm_head = load(lm_head_key)
    else:
        lm_head = emb  # tied weights
    print(f" lm_head: {lm_head.shape}")

    # ── Load final norm ───────────────────────────────────────────────
    fnorm_key = "model.norm.weight"
    fnorm_w = load(fnorm_key)
    print(f" final_norm: {fnorm_w.shape}")

    # ── Token IDs ─────────────────────────────────────────────────────
    # Hard-coded IDs intended to stand for "The capital of France is".
    # NOTE(review): they are not produced by the model's tokenizer here, so
    # the text they decode to is unverified.  That is acceptable for this
    # test: it looks for garbage output (NaN, flat or exploded logits), not
    # for text quality.
    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV)
    print(f" token_ids: {token_ids.tolist()}")
    NT = len(token_ids)

    # ── Embed ─────────────────────────────────────────────────────────
    print("\n--- Running embedding lookup ---")
    with torch.no_grad():
        hidden = emb[token_ids]  # (NT, H)
    print(f" hidden: {hidden.shape} amax={hidden.amax():.4f} NaN={torch.isnan(hidden).any()}")

    # ── Create layer 0 runner (warmup gs, like vLLM) ─────────────────
    print("\n--- Creating layer 0 runner (warmup gs) ---")
    layer0 = Layer0Runner(wm, MODEL, use_warmup_gs=True)

    # ── Run layer 0 attention projections ─────────────────────────────
    print("\n--- Running layer 0 attention (CuTeDSL, warmup gs) ---")
    with torch.no_grad():
        normed = rms(hidden, layer0.anorm, EPS)
        print(f" normed: amax={normed.amax():.4f} NaN={torch.isnan(normed).any()}")

        qa_out = layer0.r_qa.run(normed)
        print(f" q_a: amax={qa_out.amax():.4f} NaN={torch.isnan(qa_out).any()}")

        kv_out = layer0.r_kv.run(normed)
        print(f" kv: amax={kv_out.amax():.4f} NaN={torch.isnan(kv_out).any()}")

        # q_a norm → q_b
        qa_normed = rms(qa_out, layer0.qn, EPS)
        qb_out = layer0.r_qb.run(qa_normed)
        print(f" q_b: amax={qb_out.amax():.4f} NaN={torch.isnan(qb_out).any()}")

    # ── Compare with BF16 reference ───────────────────────────────────
    print("\n--- Comparing layer 0 q_a with BF16 reference ---")
    qa_bf16 = normed @ dequant(layer0.qa_w, layer0.qa_sf, layer0.qa_gs.item()).T
    c = cosim(qa_out, qa_bf16)
    print(f" q_a cosine (warmup gs): {c:.6f} {'✅' if c>=0.98 else '❌'}")

    # ── Now test with DYNAMIC gs (recomputed per input) ───────────────
    print("\n--- Testing with dynamic gs (per-input) ---")
    # Create a fresh runner and compute gs from the actual input
    r_qa2 = make_runner(
        layer0.qa_w, layer0.qa_sf, layer0.qa_gs,
        layer0.qa_w.shape[1] * 2, layer0.qa_w.shape[0],
    )
    with torch.no_grad():
        r_qa2.compute_activation_global_scale(normed)
        qa_out2 = r_qa2.run(normed)
    c2 = cosim(qa_out2, qa_bf16)
    print(f" q_a cosine (dynamic gs): {c2:.6f} {'✅' if c2>=0.98 else '❌'}")

    # ── Test the FULL model: layer 0 only, then check LM head ─────────
    print("\n--- Full forward: layer 0 → LM head ---")

    # Simple layer 0 forward (attention only, no MoE for speed)
    with torch.no_grad():
        x = hidden.clone()
        normed = rms(x, layer0.anorm, EPS)

        # Attention projections.  Only q_a's output is consumed below; the
        # kv and q_b calls are kept because they exercise those runners.
        qa = layer0.r_qa.run(normed)
        layer0.r_kv.run(normed)
        qa_n = rms(qa, layer0.qn, EPS)
        layer0.r_qb.run(qa_n)

        # Skip actual attention (FlashMLA not available) — just use
        # a random attention output to test the wo_a → wo_b path
        o = torch.randn(NT, NH, HD, dtype=torch.bfloat16, device=DEV) * 0.1

        # wo_a: BF16 matmul.
        # NOTE(review): z / z2 are never consumed, and this product only
        # succeeds if o_a_proj.weight is 2-D with NH * HD columns — confirm
        # the weight layout or drop these lines.
        woa = layer0.woa
        o_2d = o.reshape(NT, NH * HD)
        z = o_2d @ woa.T
        z2 = z.reshape(NT, OG, -1)

        # Simpler: just check if the wo_b projection works
        z_flat = torch.randn(NT, OG * OL, dtype=torch.bfloat16, device=DEV) * 2.0
        wob_out = layer0.r_wob.run(z_flat)
        print(f" wo_b output: amax={wob_out.amax():.4f} NaN={torch.isnan(wob_out).any()}")

    # ── Now run LM head on the hidden state ───────────────────────────
    # NOTE(review): x is still the raw embedding output here — no layer
    # result is written back into it — so these logits do not depend on the
    # NVFP4 projections exercised above.
    print("\n--- LM head (BF16 matmul) ---")
    with torch.no_grad():
        normed_final = rms(x, fnorm_w, EPS)
        logits = normed_final @ lm_head.T  # (NT, VOCAB)
    print(f" logits: {logits.shape} amax={logits.amax():.4f} NaN={torch.isnan(logits).any()}")

    # Check if logits are reasonable
    top5 = torch.topk(logits[-1], 5)
    print(f" top5 token IDs: {top5.indices.tolist()}")
    print(f" top5 logits: {[f'{v:.2f}' for v in top5.values.tolist()]}")

    # Check logit variance (garbage = all same or extreme values)
    last_logits = logits[-1].float()
    log_std = last_logits.std().item()
    log_range = (last_logits.amax() - last_logits.amin()).item()
    print(f" logit std: {log_std:.4f} range: {log_range:.4f}")
    if log_std < 0.01:
        print(" ❌ LOGITS ARE FLAT — model is producing garbage!")
    elif log_std > 100:
        print(" ❌ LOGITS ARE EXPLODED — model is producing garbage!")
    else:
        print(" ✅ Logits look reasonable for a single layer test")

    # ── Key diagnostic: warmup gs vs actual gs ────────────────────────
    print("\n" + "=" * 70)
    print(" DIAGNOSTIC: Warmup gs vs Actual gs")
    print("=" * 70)

    # What gs did warmup compute?
    print(f" r_qa warmup gs: {layer0.r_qa._activation_global_scale:.8f}")
    print(f" r_kv warmup gs: {layer0.r_kv._activation_global_scale:.8f}")
    print(f" r_wob warmup gs: {layer0.r_wob._activation_global_scale:.8f}")

    # What gs would the actual input produce?  Use the largest magnitude:
    # a signed max would understate the scale when the extreme is negative.
    with torch.no_grad():
        actual_gs_qa = normed.abs().amax().item() / (6.0 * 448.0)
    print(f" actual gs for q_a input: {actual_gs_qa:.8f}")
    print(f" ratio warmup/actual for q_a: {layer0.r_qa._activation_global_scale / actual_gs_qa:.4f}" if actual_gs_qa > 0 else " actual gs is 0!")

    # The KEY question: does the runner use warmup gs at inference time,
    # or does quantize_activation_nvfp4 recompute it?
    print("\n--- How does CuTeDSL runner.run() use gs? ---")
    import inspect

    from cutedsl.bridge import quantize_activation_nvfp4
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    run_src = inspect.getsource(CuTeDSLNvfp4Linear.run)
    # Check if it references _activation_global_scale
    if '_activation_global_scale' in run_src:
        print(" run() uses _activation_global_scale (FIXED from warmup)")
    else:
        print(" run() does NOT use _activation_global_scale")

    # Check quantize_activation_nvfp4
    qsrc = inspect.getsource(quantize_activation_nvfp4)
    if 'global_scale' in qsrc:
        print(" quantize_activation_nvfp4 accepts global_scale as parameter")
    if '_activation_global_scale' in qsrc:
        print(" quantize_activation_nvfp4 reads _activation_global_scale")

    # Check _run_impl
    run_impl_src = inspect.getsource(CuTeDSLNvfp4Linear._run_impl)
    print(f"\n _run_impl length: {len(run_impl_src)} chars")
    # Find where gs is used
    for i, line in enumerate(run_impl_src.split('\n')):
        if 'global_scale' in line or '_activation' in line:
            print(f" _run_impl line {i}: {line.strip()}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user