Files
nvfp4-megamoe-kernel/tests/test_full_layer_b200.py

259 lines
12 KiB
Python

#!/usr/bin/env python3
"""
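Decoder layer 0 component test using CuTeDSL kernels directly (no vLLM).
Compares each NVFP4 attention/compressor projection and the shared expert
(CuTeDSL output vs. a BF16 matmul on dequantized weights) by cosine similarity;
MHC weights are only loaded and their shapes printed.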
Usage (on B200):
source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
python3 tests/test_full_layer_b200.py
"""
import sys, os, json, math, torch, torch.nn.functional as F
from safetensors import safe_open
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)  # lets the `cutedsl` package (imported inside functions below) resolve
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"
# Model config (layer 0, compress_ratio=128 → C4A)
H = 7168  # hidden width: input of q_a / kv / compressor projections and the shared expert
NH = 128  # presumably the attention head count (only used to derive HPG here)
HD = 512  # presumably the head dim; equals NOPE + ROPE (unused in this test)
NOPE = 448  # presumably the non-rotary part of HD (unused in this test)
ROPE = 64  # presumably the rotary part of HD (unused in this test)
QL = 1536  # input width of q_b_proj
OL = 1024  # OG * OL is the input width of o_b_proj
OG = 16
HPG = NH // OG  # presumably heads per output group (unused in this test)
HC = 4  # unused in this test; meaning not visible here — confirm
SL = 10.0  # SwiGLU clamp limit used by the shared-expert runner and its reference
EPS = 1e-6  # norm epsilon; unused here (rms() defaults to the same value)
INTER = 3072  # shared-expert intermediate size
NT = 4  # rows (tokens) in each random test activation
# FP4 E2M1 value for every 4-bit code. Bit 3 is the sign, so codes 8-15 are
# codes 0-7 negated; code 8 is negative zero (written as -0.0 — a bare `-0`
# literal would be integer zero and lose the sign).
E2M1 = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)
_cache = {}


def P(k, wm, md):
    """Load checkpoint tensor *k* from its safetensors shard, memoising it.

    Args:
        k: tensor name in the checkpoint.
        wm: the ``weight_map`` dict from ``model.safetensors.index.json``
            (tensor name -> shard file name, relative to *md*).
        md: model directory containing the shard files.

    Returns:
        The tensor returned by ``safe_open(...).get_tensor(k)``, loaded with
        safe_open's default device (CPU).

    Raises:
        KeyError: if *k* is not listed in *wm*.
    """
    # Key on the directory as well as the name: two checkpoints can share
    # tensor names, and keying on *k* alone would return the wrong weights.
    key = (md, k)
    if key not in _cache:
        shard = os.path.join(md, wm[k])
        with safe_open(shard, framework="pt") as f:
            _cache[key] = f.get_tensor(k)
    return _cache[key]
def dequant(w, sf, gs):
    """Expand a packed NVFP4 weight into a dense bfloat16 matrix.

    Each element of *w* carries two 4-bit E2M1 codes. The low nibble becomes
    the even output column and the high nibble the following odd column.
    Every 16 consecutive columns share one block scale from *sf*, and the
    whole matrix is finally multiplied by the global scale *gs*.

    Args:
        w: packed codes of shape (O, I/2); presumably uint8 — confirm.
        sf: block scales with at least (O, I/16) entries; any extra rows or
            columns (padding) are cropped away.
        gs: global (per-tensor) scale.

    Returns:
        (O, I) bfloat16 tensor on w's device.
    """
    rows, half = w.shape
    cols = 2 * half
    table = E2M1.to(w.device)
    even = table[(w & 0xF).long()]
    odd = table[((w >> 4) & 0xF).long()]
    # Pair each low/high value and flatten the pairs: column 2j <- even[:, j],
    # column 2j+1 <- odd[:, j].
    values = torch.stack((even, odd), dim=-1).reshape(rows, cols)
    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (values * block_scale * gs).to(torch.bfloat16)
def rms(x, w, eps=1e-6):
    """RMS-normalise *x* over its last dimension and scale by weight *w*.

    The mean of squares is taken in float32; the scaled result is cast back
    to ``x.dtype``.
    """
    mean_sq = x.float().square().mean(dim=-1, keepdim=True)
    inv_rms = (mean_sq + eps).rsqrt()
    normed = (x * inv_rms).float()
    return (w.float() * normed).to(x.dtype)
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build and initialise a CuTeDSLNvfp4Linear for one NVFP4 weight.

    Args:
        w: packed FP4 weight, (outf, inf/2); reinterpreted as
            ``torch.float4_e2m1fn_x2`` and transposed for the runner.
        sf: block scales in the same output-major (out, in/16) layout as *w*;
            cast to float8_e4m3fn if needed, then transposed.
        gs_t: global-scale tensor. With ``fused=True`` and two entries it
            holds the scales of the first and second output halves.
        inf: logical input feature count.
        outf: logical output feature count.
        fused: treat *w* as two output-concatenated weights with separate
            global scales, which are folded into *sf*.
        lw: optional sequence whose first element is the output row where
            the second fused weight starts; defaults to ``outf // 2``.

    Returns:
        The initialised CuTeDSLNvfp4Linear runner.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    if fused and gs_t.numel() == 2:
        g1, g2 = gs_t[0].item(), gs_t[1].item()
        gs = max(g1, g2)
        if g1 != g2:
            # Fold the per-half global scales into the block scales so one
            # shared global scale (the max) can be used. The split point is
            # an OUTPUT row, so this must happen while the scales are still
            # (out, in/16) — i.e. before the transpose below. Rescaling after
            # the transpose would split along the input-block axis instead.
            split = lw[0] if lw else outf // 2
            s32 = s.float()
            s32[:split] *= g1 / gs
            s32[split:] *= g2 / gs
            s = s32.to(torch.float8_e4m3fn)
    else:
        # Single scale, or several collapsed to their max.
        gs = gs_t.max().item()
    s = s.permute(1, 0).contiguous()
    r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192,
                           device=str(w.device))
    r.fp4 = [fp4]
    r.sf = [s]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r
def cosim(a, b):
    """Cosine similarity of *a* and *b*, each treated as one flat vector.

    Both operands are upcast to float32 and *b* is moved to *a*'s device
    before comparing; the result is returned as a Python float.
    """
    lhs = a.flatten()[None].float()
    rhs = b.flatten()[None].float().to(a.device)
    similarity = F.cosine_similarity(lhs, rhs)
    return similarity.item()
# Minimum cosine similarity against the BF16 reference that counts as a pass.
COSINE_PASS = 0.98


def _status(c):
    """Return "PASS" when cosine *c* reaches COSINE_PASS, else "FAIL".

    Written as ``c >= COSINE_PASS`` so a NaN cosine is reported as FAIL.
    """
    return "PASS" if c >= COSINE_PASS else "FAIL"


def _check_projection(label, runner, x, w, sf, gs_t):
    """Compare one CuTeDSL NVFP4 projection with a BF16 reference.

    Runs ``runner.run(x)``, computes ``x @ dequant(w, sf, gs_t.item()).T``,
    prints one summary line prefixed with *label* and returns the cosine
    similarity between the two outputs.
    """
    with torch.no_grad():
        co = runner.run(x)
        ref = x @ dequant(w, sf, gs_t.item()).T
    c = cosim(co, ref)
    print(f" {label}: cosine={c:.6f} {_status(c)} amax={co.amax():.4f} ref={ref.amax():.4f}")
    return c


def main():
    """Check layer-0 NVFP4 projections and the shared expert against BF16.

    Loads layer-0 weights from MODEL and builds CuTeDSL runners. Random BF16
    activations are fed through each runner and compared with a matmul on
    the dequantized weights. MHC weights are only loaded and their shapes
    printed.

    Returns:
        0 if every cosine similarity reaches COSINE_PASS, otherwise 1.

    Raises:
        ValueError: if the shared-expert gate/up weights do not have INTER
            rows each (the fused scale fold below relies on that split).
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)
    print("=" * 70)
    print(" Layer 0: CuTeDSL NVFP4 vs BF16 Reference")
    print("=" * 70)
    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def G(k):
        # Fetch checkpoint tensor k (cached by P) and move it to DEV.
        return P(k, wm, MODEL).to(DEV)

    def nvfp4(prefix):
        # Packed weight, block scales and global scale of one NVFP4 linear.
        return (G(f"{prefix}.weight"), G(f"{prefix}.weight_scale"),
                G(f"{prefix}.weight_scale_2"))

    p = "model.layers.0"
    a = f"{p}.self_attn"
    m = f"{p}.mlp"

    # ── Load attention weights (correct checkpoint key names) ─────────
    print("\n--- Loading weights ---")
    qa_w, qa_sf, qa_gs = nvfp4(f"{a}.q_a_proj")
    qb_w, qb_sf, qb_gs = nvfp4(f"{a}.q_b_proj")
    kv_w, kv_sf, kv_gs = nvfp4(f"{a}.kv_proj")
    woa = G(f"{a}.o_a_proj.weight")  # BF16
    wob_w, wob_sf, wob_gs = nvfp4(f"{a}.o_b_proj")
    qn = G(f"{a}.q_a_norm.weight")
    kvn = G(f"{a}.kv_norm.weight")
    anorm = G(f"{p}.input_layernorm.weight")
    fnorm = G(f"{p}.post_attention_layernorm.weight")
    # Compressor (C4A path)
    ckv_w, ckv_sf, ckv_gs = nvfp4(f"{a}.compressor.kv_proj")
    cg_w, cg_sf, cg_gs = nvfp4(f"{a}.compressor.gate_proj")
    ckn = G(f"{a}.compressor.kv_norm.weight")
    cpb = G(f"{a}.compressor.position_bias")
    sinks = G(f"{a}.sinks")
    # MHC
    hca_fn = G(f"{p}.attn_hc.fn")
    hcf_fn = G(f"{p}.ffn_hc.fn")
    hca_b = G(f"{p}.attn_hc.base")
    hcf_b = G(f"{p}.ffn_hc.base")
    hca_s = G(f"{p}.attn_hc.scale")
    hcf_s = G(f"{p}.ffn_hc.scale")
    for nm, t in [("q_a_proj", qa_w), ("q_b_proj", qb_w), ("kv_proj", kv_w),
                  ("o_a_proj", woa), ("o_b_proj", wob_w),
                  ("comp.kv_proj", ckv_w), ("comp.gate_proj", cg_w),
                  ("sinks", sinks), ("comp.position_bias", cpb),
                  ("attn_hc.fn", hca_fn)]:
        print(f" {nm}: shape={t.shape} dtype={t.dtype}")

    # ── Create CuTeDSL runners ────────────────────────────────────────
    print("\n--- Creating CuTeDSL runners ---")
    r_qa = make_runner(qa_w, qa_sf, qa_gs, qa_w.shape[1] * 2, qa_w.shape[0])
    r_qb = make_runner(qb_w, qb_sf, qb_gs, qb_w.shape[1] * 2, qb_w.shape[0])
    r_kv = make_runner(kv_w, kv_sf, kv_gs, kv_w.shape[1] * 2, kv_w.shape[0])
    r_wob = make_runner(wob_w, wob_sf, wob_gs, wob_w.shape[1] * 2, wob_w.shape[0])
    # Compressor runners
    r_ckv = make_runner(ckv_w, ckv_sf, ckv_gs, ckv_w.shape[1] * 2, ckv_w.shape[0])
    r_cg = make_runner(cg_w, cg_sf, cg_gs, cg_w.shape[1] * 2, cg_w.shape[0])
    for label, wt in (("q_a", qa_w), ("q_b", qb_w), ("kv", kv_w), ("wo_b", wob_w),
                      ("comp.kv", ckv_w), ("comp.gate", cg_w)):
        print(f" {label}: in={wt.shape[1] * 2} out={wt.shape[0]}")

    # Warmup: calibrate each runner's activation global scale on random data
    # of the same width and magnitude as the test inputs.
    print(" Warming up...")
    d1 = torch.randn(NT, H, dtype=torch.bfloat16, device=DEV) * 2.0
    for r in (r_qa, r_kv, r_ckv, r_cg):
        r.compute_activation_global_scale(d1)
    d2 = torch.randn(NT, QL, dtype=torch.bfloat16, device=DEV) * 2.0
    r_qb.compute_activation_global_scale(d2)
    d3 = torch.randn(NT, OG * OL, dtype=torch.bfloat16, device=DEV) * 2.0
    r_wob.compute_activation_global_scale(d3)
    print(" Done.")

    # ── Per-projection BF16 vs CuTeDSL comparison ────────────────────
    print("\n" + "=" * 70)
    print(" PROJECTION-LEVEL: CuTeDSL vs BF16")
    print("=" * 70)
    torch.manual_seed(123)
    tx = torch.randn(NT, H, dtype=torch.bfloat16, device=DEV) * 2.0
    results = {}
    results['q_a_proj'] = _check_projection("q_a_proj", r_qa, tx, qa_w, qa_sf, qa_gs)
    results['kv_proj'] = _check_projection("kv_proj", r_kv, tx, kv_w, kv_sf, kv_gs)
    tq = torch.randn(NT, QL, dtype=torch.bfloat16, device=DEV) * 2.0
    results['q_b_proj'] = _check_projection("q_b_proj", r_qb, tq, qb_w, qb_sf, qb_gs)
    tz = torch.randn(NT, OG * OL, dtype=torch.bfloat16, device=DEV) * 2.0
    results['wo_b_proj'] = _check_projection("wo_b_proj", r_wob, tz, wob_w, wob_sf, wob_gs)
    results['comp.kv_proj'] = _check_projection("comp.kv_proj", r_ckv, tx, ckv_w, ckv_sf, ckv_gs)
    results['comp.gate_proj'] = _check_projection("comp.gate", r_cg, tx, cg_w, cg_sf, cg_gs)

    # ── Shared expert ─────────────────────────────────────────────────
    print("\n--- Shared Expert: CuTeDSL vs BF16 ---")
    from cutedsl.shared_expert_pipeline import CuTeDSLSharedExpertRunner
    sgw, sgsf, sggs_t = nvfp4(f"{m}.shared_experts.gate_proj")
    suw, susf, sugs_t = nvfp4(f"{m}.shared_experts.up_proj")
    sdw, sdsf, sdgs_t = nvfp4(f"{m}.shared_experts.down_proj")
    sggs, sugs, sdgs = sggs_t.item(), sugs_t.item(), sdgs_t.item()
    si = INTER
    if sgw.shape[0] != si or suw.shape[0] != si:
        raise ValueError(
            f"shared expert gate/up rows {sgw.shape[0]}/{suw.shape[0]} != INTER={si}")
    sgu_w = torch.cat([sgw, suw], 0)
    sgu_sf = torch.cat([sgsf, susf], 0)
    smgs = max(sggs, sugs)
    if sggs != sugs:
        # Fold the separate gate/up global scales into the block scales so the
        # fused L1 weight can use one global scale (the max). Scales are still
        # output-major here, so [:si] is the gate half.
        s32 = sgu_sf.float()
        s32[:si] *= sggs / smgs
        s32[si:] *= sugs / smgs
        sgu_sf = s32.to(torch.float8_e4m3fn)
    ser = CuTeDSLSharedExpertRunner(hidden_size=H, intermediate_size=si, max_num_tokens=8192,
                                    device=DEV, swiglu_limit=SL)
    ser.l1_fp4 = [sgu_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()]
    ser.l1_sf = [sgu_sf.permute(1, 0).contiguous()]
    ser.l1_gs = [smgs]
    ser.l2_fp4 = [sdw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()]
    ser.l2_sf = [sdsf.permute(1, 0).contiguous()]
    ser.l2_gs = [sdgs]
    ser.finalize_weights()
    ser._ensure_initialized()
    tse = torch.randn(NT, H, dtype=torch.bfloat16, device=DEV) * 2.0
    ser.compute_activation_global_scales(tse)
    with torch.no_grad():
        so = ser.run(tse)
    gb = dequant(sgw, sgsf, sggs)
    ub = dequant(suw, susf, sugs)
    db = dequant(sdw, sdsf, sdgs)
    with torch.no_grad():
        g_ = tse @ gb.T
        u_ = tse @ ub.T
        # Clamped SwiGLU matching swiglu_limit=SL: gate clamped from above
        # only, up clamped symmetrically.
        act = F.silu(g_.clamp(max=SL)) * u_.clamp(min=-SL, max=SL)
        sref = act @ db.T
    c = cosim(so, sref)
    results['shared_expert'] = c
    print(f" shared_expert: cosine={c:.6f} {_status(c)} amax={so.amax():.4f} ref={sref.amax():.4f}")

    # ── MHC sanity ────────────────────────────────────────────────────
    print("\n--- MHC weight shapes ---")
    print(f" attn_hc.fn: {hca_fn.shape} ffn_hc.fn: {hcf_fn.shape}")
    print(f" attn_hc.base: {hca_b.shape} ffn_hc.base: {hcf_b.shape}")
    print(f" attn_hc.scale: {hca_s.shape} ffn_hc.scale: {hcf_s.shape}")
    print(f" attn_hc.scale values: {hca_s.tolist()}")
    print(f" sinks: {sinks.shape}")

    # ── Summary ───────────────────────────────────────────────────────
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    for name, c in results.items():
        print(f" {name}: {c:.6f} {_status(c)}")
    # Positive check: a NaN cosine (e.g. NaN/Inf kernel output) fails here,
    # whereas `c < 0.98` would have let it through as a pass.
    all_pass = all(c >= COSINE_PASS for c in results.values())
    if all_pass:
        print("\n All projections pass! CuTeDSL kernels match BF16 reference.")
        print(" The bug is in vLLM's pipeline, not our kernels.")
    else:
        print("\n Some projections FAIL. Need to debug those specific kernels.")
    return 0 if all_pass else 1


if __name__ == "__main__":
    # Propagate the verdict so shells / CI see a non-zero status on failure.
    sys.exit(main())