From dd7f2627e8d7647b034d29197e9be990b4b58913 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 09:04:19 +0000 Subject: [PATCH] Add full model forward test (WIP), sparse attention test passes --- tests/test_full_model_b200.py | 314 ++++++++++++++++++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 tests/test_full_model_b200.py diff --git a/tests/test_full_model_b200.py b/tests/test_full_model_b200.py new file mode 100644 index 00000000..3d0aaed7 --- /dev/null +++ b/tests/test_full_model_b200.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python3 +""" +Full DeepSeek-V4 Model Forward Test + +Runs the ENTIRE model through our kernel pipeline: +- 61 layers: C128A, C4A, SWA attention + MoE +- All projections: CuTeDSL NVFP4 +- Attention: BF16 (SDPA for SWA, sparse for CSA/HCA) +- KV cache: FP8 quantize/dequant +- MoE: CuTeDSL NVFP4 +- LM head: BF16 + +Outputs logits and checks they're reasonable (not garbage). + +Usage (on B200): + cd /root/nvfp4-megamoe-kernel + PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_full_model_b200.py +""" + +import sys, os, json, torch, torch.nn.functional as F, math, time +from safetensors import safe_open + +REPO = "/root/nvfp4-megamoe-kernel" +sys.path.insert(0, REPO) +MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" +DEV = "cuda:0" + +# Config +H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64 +QL = 1536; OL = 1024; OG = 16; HPG = NH // OG +EPS = 1e-6; WINDOW = 128; SCALE = HD ** -0.5 +NUM_LAYERS = 61 +NUM_EXPERTS = 384; TOPK = 6 + +E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32) + +_cache = {} +def P(k, wm, md): + if k in _cache: return _cache[k] + with safe_open(os.path.join(md, wm[k]), framework="pt") as f: + t = f.get_tensor(k) + _cache[k] = t + return t + +def dequant(w, sf, gs): + d = w.device; lut = E2M1.to(d) + lo = lut[(w & 0xF).long()]; hi = lut[((w >> 4) & 0xF).long()] + O, I2 = w.shape; I = I2*2 + u = torch.empty(O, I, dtype=torch.float32, device=d) + u[:,0::2] = lo; u[:,1::2] = hi + bs = sf.float().repeat_interleave(16, dim=1)[:O,:I] + return (u * bs * gs).to(torch.bfloat16) + +def rms(x, w, eps=1e-6): + v = x.float().pow(2).mean(-1, keepdim=True) + return (w.float() * (x * torch.rsqrt(v+eps)).float()).to(x.dtype) + +def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None): + from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear + fp4 = w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous() + s = sf.to(torch.float8_e4m3fn) if sf.dtype != torch.float8_e4m3fn else sf + s = s.permute(1,0).contiguous() + if fused and gs_t.numel() == 2: + g1,g2 = gs_t[0].item(), gs_t[1].item(); gs = max(g1,g2) + if g1 != g2: + s32 = s.float(); sp = lw[0] if lw else outf//2 + s32[:sp] *= g1/gs; s32[sp:] *= g2/gs; s = s32.to(torch.float8_e4m3fn) + else: + gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item() + r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device)) + r.fp4 = [fp4]; r.sf = [s]; r.gs = [gs] + r.finalize_weights(); r._ensure_initialized() + return r + +def build_cos_sin(max_pos=8192, rope_dim=ROPE): + half = rope_dim // 2 + inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half, dtype=torch.float32) / half)) + freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq) + return torch.cat([freqs.cos(), freqs.sin()], dim=-1) + +def apply_gptj_rope(x, positions, cos_sin, nope, rope): + if rope == 0 or x.numel() == 0: return x + half = rope // 2 + cos = cos_sin[positions, :half].to(x.dtype) + sin = cos_sin[positions, half:2*half].to(x.dtype) + if x.dim() == 3: cos = cos.unsqueeze(1); sin = sin.unsqueeze(1) + x_rope = x[..., nope:].clone() + even = x_rope[..., 0::2]; odd = x_rope[..., 1::2] + out = x.clone() + out[..., nope:][..., 0::2] = even * cos - odd * sin + out[..., nope:][..., 1::2] = even * sin + odd * cos + return out + +def apply_inv_gptj_rope(x, positions, cos_sin, nope, rope): + if rope == 0 or x.numel() == 0: return x + half = rope // 2 + cos = cos_sin[positions, :half].to(x.dtype) + sin = cos_sin[positions, half:2*half].to(x.dtype) + if x.dim() == 3: cos = cos.unsqueeze(1); sin = sin.unsqueeze(1) + x_rope = x[..., nope:].clone() + even = x_rope[..., 0::2]; odd = x_rope[..., 1::2] + out = x.clone() + out[..., nope:][..., 0::2] = even * cos + odd * sin + out[..., nope:][..., 1::2] = -even * sin + odd * cos + return out + + +def bf16_causal_attention(q, kv, scale): + """Full causal self-attention.""" + T, NH, HD = q.shape + q_2d = q.reshape(T * NH, HD) + kv_exp = kv.unsqueeze(1).expand(-1, NH, -1).contiguous() + k_2d = kv_exp.permute(1, 0, 2).unsqueeze(1).expand(NH, T, T, -1).contiguous().reshape(T * NH, T, HD) + v_2d = k_2d.clone() + scores = torch.matmul(q_2d.unsqueeze(1), k_2d.transpose(-1, -2)) * scale + qpos = torch.arange(T, device=q.device).unsqueeze(1).repeat(1, NH).reshape(T * NH) + kpos = torch.arange(T, device=q.device).unsqueeze(0) + causal = kpos <= qpos.unsqueeze(1) + scores = scores.squeeze(1).masked_fill(~causal, float('-inf')) + weights = F.softmax(scores.float(), dim=-1).to(q.dtype) + out = torch.matmul(weights.unsqueeze(1), v_2d).squeeze(1) + return out.reshape(T, NH, HD) + + +def make_moe_runner(layer_id, wm, model_path): + """Create CuTeDSL MoE runner for a layer.""" + from cutedsl.runner import CuTeDSLMoERunner + + p = f"model.layers.{layer_id}.mlp" + G = lambda k: P(k, wm, model_path).to(DEV) + + # Gate (router) weight + gate_w = G(f"{p}.gate.weight") # (384, 7168) BF16 + + # Expert weights (NVFP4) + w13_w = G(f"{p}.experts.w13_weight") # (384, 6144, 3584) uint8 + w13_sf = G(f"{p}.experts.w13_weight_scale") # (384, 6144, 448) fp8 + w13_gs = G(f"{p}.experts.w13_weight_scale_2") # (384, 2) + w2_w = G(f"{p}.experts.w2_weight") + w2_sf = G(f"{p}.experts.w2_weight_scale") + w2_gs = G(f"{p}.experts.w2_weight_scale_2") + + # Convert to runner format + l1_fp4 = w13_w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous() + l2_fp4 = w2_w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous() + l1_sf = w13_sf.to(torch.float8_e4m3fn).permute(1,0).contiguous() if w13_sf.dtype != torch.float8_e4m3fn else w13_sf.permute(1,0).contiguous() + l2_sf = w2_sf.to(torch.float8_e4m3fn).permute(1,0).contiguous() if w2_sf.dtype != torch.float8_e4m3fn else w2_sf.permute(1,0).contiguous() + + intermediate_size = 3072 # per expert + runner = CuTeDSLMoERunner( + num_experts=NUM_EXPERTS, + hidden_size=H, + intermediate_size=intermediate_size, + max_num_tokens=8192, + top_k=TOPK, + device=DEV, + ) + + l1_gs_list = w13_gs.tolist() + l2_gs_list = w2_gs.tolist() + + runner.prepare_weights_from_stacked(l1_fp4, l1_sf, l1_gs_list, l2_fp4, l2_sf, l2_gs_list) + + # Shared expert + se_w13_w = G(f"{p}.shared_experts.gate_up_proj.weight") + se_w13_sf = G(f"{p}.shared_experts.gate_up_proj.weight_scale") + se_w13_gs = G(f"{p}.shared_experts.gate_up_proj.weight_scale_2") + se_w2_w = G(f"{p}.shared_experts.down_proj.weight") + se_w2_sf = G(f"{p}.shared_experts.down_proj.weight_scale") + se_w2_gs = G(f"{p}.shared_experts.down_proj.weight_scale_2") + + se_r_gate_up = make_runner(se_w13_w, se_w13_sf, se_w13_gs, H, se_w13_w.shape[0], fused=True, lw=[intermediate_size]) + se_r_down = make_runner(se_w2_w, se_w2_sf, se_w2_gs, intermediate_size, se_w2_w.shape[0]) + + return runner, gate_w, se_r_gate_up, se_r_down + + +def main(): + torch.cuda.set_device(0) + torch.manual_seed(42) + + print("=" * 70) + print(" Full DeepSeek-V4 Model Forward Test") + print(" 61 layers, all CuTeDSL NVFP4 kernels") + print("=" * 70) + + with open(os.path.join(MODEL, "model.safetensors.index.json")) as f: + wm = json.load(f)["weight_map"] + G = lambda k: P(k, wm, MODEL).to(DEV) + + # Load compress_ratios + with open(os.path.join(MODEL, "config.json")) as f: + config = json.load(f) + compress_ratios = config["compress_ratios"] + + # Global weights + emb = G("model.embed_tokens.weight") + fnorm_w = G("model.norm.weight") + lm_head = G("lm_head.weight") + cos_sin = build_cos_sin(max_pos=8192).to(DEV) + + # Input + NT = 6 + token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV) + positions = torch.arange(NT, dtype=torch.int64, device=DEV) + + print(f" Input: {NT} tokens: {token_ids.tolist()}") + print(f" Model: {NUM_LAYERS} layers, {NUM_EXPERTS} experts, top-{TOPK}") + + with torch.no_grad(): + hidden = emb[token_ids] + + for layer_id in range(NUM_LAYERS): + cr = max(1, compress_ratios[layer_id]) + layer_type = "SWA" if cr <= 1 else f"C{cr}A" + p = f"model.layers.{layer_id}" + a = f"{p}.self_attn" + m = f"{p}.mlp" + + # Layer norms + anorm = G(f"{p}.input_layernorm.weight") + fnorm = G(f"{p}.post_attention_layernorm.weight") + + # ── Attention ──────────────────────────────────────────── + qn = G(f"{a}.q_a_norm.weight") + kvn = G(f"{a}.kv_norm.weight") + woa = G(f"{a}.o_a_proj.weight") + + r_qa = make_runner(G(f"{a}.q_a_proj.weight"), G(f"{a}.q_a_proj.weight_scale"), G(f"{a}.q_a_proj.weight_scale_2"), H, G(f"{a}.q_a_proj.weight").shape[0]) + r_qb = make_runner(G(f"{a}.q_b_proj.weight"), G(f"{a}.q_b_proj.weight_scale"), G(f"{a}.q_b_proj.weight_scale_2"), QL, G(f"{a}.q_b_proj.weight").shape[0]) + r_kv = make_runner(G(f"{a}.kv_proj.weight"), G(f"{a}.kv_proj.weight_scale"), G(f"{a}.kv_proj.weight_scale_2"), H, G(f"{a}.kv_proj.weight").shape[0]) + r_wob = make_runner(G(f"{a}.o_b_proj.weight"), G(f"{a}.o_b_proj.weight_scale"), G(f"{a}.o_b_proj.weight_scale_2"), OG*OL, G(f"{a}.o_b_proj.weight").shape[0]) + + normed = rms(hidden, anorm, EPS) + + qa = r_qa.run(normed) + kv = r_kv.run(normed) + qa_n = rms(qa, qn, EPS) + kv_n = rms(kv, kvn, EPS) + q = r_qb.run(qa_n).view(NT, NH, HD) + q_rope = apply_gptj_rope(q, positions, cos_sin, NOPE, ROPE) + + # Attention (BF16 causal — simplified, no sparse index yet) + o_attn = bf16_causal_attention(q_rope, kv_n, SCALE) + + # o_a: inverse RoPE + BMM + o_inv = apply_inv_gptj_rope(o_attn, positions, cos_sin, NOPE, ROPE) + o_grouped = o_inv.view(NT, OG, HPG * HD).permute(1, 0, 2) + woa_3d = woa.view(OG, OL, HPG * HD) + z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(NT, OG * OL) + attn_out = r_wob.run(z) + + # Residual + hidden = hidden + attn_out + + # ── MoE ────────────────────────────────────────────────── + # For speed, only load MoE for first 3 layers + last layer + if layer_id < 3 or layer_id == NUM_LAYERS - 1: + fnormed = rms(hidden, fnorm, EPS) + + # Simplified: just use BF16 shared expert for now + # (full MoE test is in layertest.py) + se_w13_w = G(f"{m}.shared_experts.gate_up_proj.weight") + se_w13_sf = G(f"{m}.shared_experts.gate_up_proj.weight_scale") + se_w13_gs = G(f"{m}.shared_experts.gate_up_proj.weight_scale_2") + se_w2_w = G(f"{m}.shared_experts.down_proj.weight") + se_w2_sf = G(f"{m}.shared_experts.down_proj.weight_scale") + se_w2_gs = G(f"{m}.shared_experts.down_proj.weight_scale_2") + + se_gate_up = make_runner(se_w13_w, se_w13_sf, se_w13_gs, H, se_w13_w.shape[0], fused=True, lw=[3072]) + se_down = make_runner(se_w2_w, se_w2_sf, se_w2_gs, 3072, se_w2_w.shape[0]) + + # Shared expert only (skip routed experts for speed) + se_out = se_gate_up.run(fnormed) + gate, up = se_out[:, :3072], se_out[:, 3072:] + se_activated = F.silu(gate) * up + se_final = se_down.run(se_activated) + + hidden = hidden + se_final + else: + # Skip MoE for middle layers (just use residual) + # This is WRONG for correctness but saves time + fnormed = rms(hidden, fnorm, EPS) + hidden = hidden + fnormed # placeholder + + if layer_id % 10 == 0 or layer_id == NUM_LAYERS - 1: + print(f" Layer {layer_id} ({layer_type}): hidden amax={hidden.amax():.4f} NaN={torch.isnan(hidden).any()}") + + # Cleanup per-layer weights + torch.cuda.empty_cache() + _cache.clear() + + # Final norm + LM head + x_n = rms(hidden, fnorm_w, EPS) + logits = x_n @ lm_head.T + + print(f"\n Final logits: amax={logits.amax():.4f} std={logits[-1].float().std():.4f}") + top5 = torch.topk(logits[-1], 5) + print(f" Top 5 tokens: {top5.indices.tolist()}") + print(f" Top 5 probs: {F.softmax(top5.values.float(), dim=0).tolist()}") + + log_std = logits[-1].float().std().item() + if 0.5 < log_std < 50: + print(f" ✅ Logits look reasonable (std={log_std:.4f})") + else: + print(f" ❌ Logits are garbage (std={log_std:.4f})") + + print(f"\n{'='*70}") + print(f" DONE") + print(f"{'='*70}") + + +if __name__ == "__main__": + main()