259 lines
12 KiB
Python
259 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
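Decoder layer 0 test: every NVFP4 projection via CuTeDSL kernels, NO vLLM.

Compares each NVFP4 attention projection (q_a, q_b, kv, o_b, compressor
kv/gate) and the shared-expert FFN individually against a BF16 matmul over
the dequantized checkpoint weights (CuTeDSL vs BF16 ref). MHC and sink
weights are only loaded and printed, not numerically checked.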
|
|
|
|
Usage (on B200):
|
|
source /root/nvfp4-megamoe-kernel/tests/.venv/bin/activate
|
|
python3 tests/test_full_layer_b200.py
|
|
"""
|
|
|
|
import sys, os, json, math, torch, torch.nn.functional as F
|
|
from safetensors import safe_open
|
|
|
|
# Locations used by the test: kernel repo checkout, NVFP4 checkpoint, GPU.
REPO = "/root/nvfp4-megamoe-kernel"
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"

# Put the repo first on the import path — presumably so the `cutedsl.*`
# modules imported later resolve from this checkout.
sys.path.insert(0, REPO)
|
|
|
|
# Model config (layer 0, compress_ratio=128 → C4A)
# H: hidden size (width of the test activations and shared-expert input).
# NH/HD/NOPE/ROPE: presumably head count, head dim and its no-rope/rope split.
H, NH, HD = 7168, 128, 512
NOPE, ROPE = 448, 64
# QL: input width of q_b_proj; OG * OL: input width of o_b_proj.
QL, OL, OG = 1536, 1024, 16
HPG = NH // OG  # heads per output group, going by the name
HC = 4
# SL: swiglu clamp limit; INTER: shared-expert intermediate size.
SL, EPS = 10.0, 1e-6
INTER = 3072
NT = 4  # number of token rows in every test activation

# Value of each 4-bit E2M1 code, indexed by the code itself (see dequant).
E2M1 = torch.tensor(
    [0, .5, 1., 1.5, 2., 3., 4., 6.]            # codes 0-7: non-negative
    + [-0, -.5, -1., -1.5, -2., -3., -4., -6.],  # codes 8-15: negated
    dtype=torch.float32,
)
|
|
|
|
_cache = {}


def P(k, wm, md):
    """Return checkpoint tensor *k*, reading it from disk at most once.

    Args:
        k: tensor name in the checkpoint.
        wm: the index's ``weight_map``; ``wm[k]`` is the shard file holding *k*.
        md: model directory that the shard file names are relative to.

    Tensors are memoized in the module-level ``_cache`` keyed by *k* only.
    """
    try:
        return _cache[k]
    except KeyError:
        pass
    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        loaded = shard.get_tensor(k)
    _cache[k] = loaded
    return loaded
|
|
|
|
def dequant(w, sf, gs):
    """Expand a packed NVFP4 weight into a dense BF16 matrix.

    Each byte of *w* holds two 4-bit E2M1 codes: the low nibble becomes the
    even output column and the high nibble the following odd column. Every
    run of 16 columns is multiplied by one block-scale entry from *sf*, and
    the whole matrix by the global scale *gs*.

    Args:
        w: packed codes of shape (O, I // 2); an integer dtype is assumed
            (bitwise ops are applied to it).
        sf: block scales; rows/columns beyond (O, I // 16) are ignored.
        gs: global scale (Python number or broadcastable tensor).

    Returns:
        (O, I) bfloat16 tensor on ``w.device``.
    """
    table = E2M1.to(w.device)
    rows, packed_cols = w.shape
    cols = packed_cols * 2
    even = table[(w & 0xF).long()]
    odd = table[((w >> 4) & 0xF).long()]
    # Interleave per row as [even0, odd0, even1, odd1, ...].
    values = torch.stack((even, odd), dim=-1).reshape(rows, cols)
    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (values * block_scale * gs).to(torch.bfloat16)
|
|
|
|
def rms(x, w, eps=1e-6):
    """RMS-normalize *x* over its last axis, scale by *w*, keep ``x.dtype``.

    The mean of squares is taken in float32 and the final product with *w*
    is done in float32 before casting back to the input dtype.
    """
    inv_rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    # Mixing x with the float32 statistic type-promotes; .float() pins it.
    normalized = (x * inv_rms).float()
    scaled = w.float() * normalized
    return scaled.to(x.dtype)
|
|
|
|
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None, max_num_tokens=8192):
    """Wrap one NVFP4 checkpoint projection in an initialized CuTeDSL runner.

    Args:
        w: packed FP4 weight as stored in the checkpoint, shape
            (out_features, in_features // 2) — callers derive ``inf``/``outf``
            from ``w.shape`` that way.
        sf: per-block scales, one row per output feature (same row layout as
            *w*, as ``dequant`` assumes).
        gs_t: global-scale tensor (``weight_scale_2``). With ``fused=True`` it
            may hold two entries, one per stacked output half.
        inf: logical in_features of the GEMM.
        outf: logical out_features of the GEMM.
        fused: treat *w* as two projections stacked along the output dim,
            each with its own global scale.
        lw: optional sequence whose first element is the number of output
            rows belonging to the first stacked projection; defaults to
            ``outf // 2``.
        max_num_tokens: forwarded to ``CuTeDSLNvfp4Linear``.

    Returns:
        A ``CuTeDSLNvfp4Linear`` after ``finalize_weights()`` and
        ``_ensure_initialized()``.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    # The runner is handed transposed (in, out) copies of weights and scales.
    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    # NOTE(review): `.to()` converts values; if scales ever arrive as raw
    # uint8 bit patterns this would need `.view(torch.float8_e4m3fn)` instead.
    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)

    if fused and gs_t.numel() == 2:
        g1, g2 = gs_t[0].item(), gs_t[1].item()
        gs = max(g1, g2)
        if g1 != g2:
            # Fold each half's global scale into its block scales so one shared
            # global scale (the max) can be used. The halves are split along
            # the OUTPUT dim, i.e. the rows of the untransposed scale tensor,
            # so this must happen before the permute below.
            split = lw[0] if lw else outf // 2
            s32 = s.float()
            s32[:split] *= g1 / gs
            s32[split:] *= g2 / gs
            s = s32.to(torch.float8_e4m3fn)
    else:
        gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()

    s = s.permute(1, 0).contiguous()

    runner = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf,
                                max_num_tokens=max_num_tokens, device=str(w.device))
    runner.fp4 = [fp4]
    runner.sf = [s]
    runner.gs = [gs]
    runner.finalize_weights()
    runner._ensure_initialized()
    return runner
|
|
|
|
def cosim(a, b):
    """Cosine similarity of *a* and *b* viewed as flat vectors, as a float.

    Both inputs are upcast to float32 and *b* is moved to *a*'s device.
    """
    flat_a = a.reshape(1, -1).float()
    flat_b = b.reshape(1, -1).float().to(a.device)
    similarity = F.cosine_similarity(flat_a, flat_b)
    return similarity.item()
|
|
|
|
# Minimum cosine similarity (CuTeDSL output vs BF16 reference) that counts as a pass.
_PASS_COSINE = 0.98


def _mark(c):
    """Return the report glyph for cosine *c*; NaN compares False and fails."""
    return '✅' if c >= _PASS_COSINE else '❌'


def _load_nvfp4(G, prefix):
    """Fetch ``(weight, weight_scale, weight_scale_2)`` for projection *prefix* via *G*."""
    return (G(f"{prefix}.weight"),
            G(f"{prefix}.weight_scale"),
            G(f"{prefix}.weight_scale_2"))


def _runner(w, sf, gs):
    """Build a CuTeDSL runner whose in/out features are derived from packed *w*."""
    return make_runner(w, sf, gs, w.shape[1] * 2, w.shape[0])


def _report(results, key, label, out, ref):
    """Store cosine(out, ref) in *results* under *key* and print one report line."""
    c = cosim(out, ref)
    results[key] = c
    print(f" {label}: cosine={c:.6f} {_mark(c)} amax={out.amax():.4f} ref={ref.amax():.4f}")


def _compare_projection(results, key, runner, x, w, sf, gs, label=None):
    """Run *runner* on *x* and compare against ``x @ dequant(w, sf, gs).T``."""
    with torch.no_grad():
        out = runner.run(x)
    ref = x @ dequant(w, sf, gs.item()).T
    _report(results, key, key if label is None else label, out, ref)


def main():
    """Compare every layer-0 CuTeDSL NVFP4 projection against a BF16 reference.

    Loads layer-0 weights from ``MODEL``, builds one CuTeDSL runner per NVFP4
    attention/compressor projection plus the shared-expert runner, and prints
    the cosine similarity of each kernel output against a BF16 matmul over
    the dequantized weights. MHC and sink tensors are only printed.

    Returns:
        True if every comparison reaches ``_PASS_COSINE`` (a NaN cosine counts
        as a failure), otherwise False.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)
    print("=" * 70)
    print(" Layer 0: CuTeDSL NVFP4 vs BF16 Reference")
    print("=" * 70)

    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def G(k):
        """Load checkpoint tensor *k* (memoized by ``P``) onto ``DEV``."""
        return P(k, wm, MODEL).to(DEV)

    p = "model.layers.0"
    a = f"{p}.self_attn"
    m = f"{p}.mlp"

    # ── Load attention weights (correct checkpoint key names) ─────────
    print("\n--- Loading weights ---")
    qa_w, qa_sf, qa_gs = _load_nvfp4(G, f"{a}.q_a_proj")
    qb_w, qb_sf, qb_gs = _load_nvfp4(G, f"{a}.q_b_proj")
    kv_w, kv_sf, kv_gs = _load_nvfp4(G, f"{a}.kv_proj")
    woa = G(f"{a}.o_a_proj.weight")  # BF16
    wob_w, wob_sf, wob_gs = _load_nvfp4(G, f"{a}.o_b_proj")

    # Compressor (C4A path)
    ckv_w, ckv_sf, ckv_gs = _load_nvfp4(G, f"{a}.compressor.kv_proj")
    cg_w, cg_sf, cg_gs = _load_nvfp4(G, f"{a}.compressor.gate_proj")
    cpb = G(f"{a}.compressor.position_bias")
    sinks = G(f"{a}.sinks")

    # Norm weights are not exercised below; loading them still verifies that
    # the checkpoint keys exist.
    for key in (f"{a}.q_a_norm.weight", f"{a}.kv_norm.weight",
                f"{p}.input_layernorm.weight", f"{p}.post_attention_layernorm.weight",
                f"{a}.compressor.kv_norm.weight"):
        G(key)

    # MHC
    hca_fn = G(f"{p}.attn_hc.fn")
    hcf_fn = G(f"{p}.ffn_hc.fn")
    hca_b = G(f"{p}.attn_hc.base")
    hcf_b = G(f"{p}.ffn_hc.base")
    hca_s = G(f"{p}.attn_hc.scale")
    hcf_s = G(f"{p}.ffn_hc.scale")

    for nm, t in [("q_a_proj", qa_w), ("q_b_proj", qb_w), ("kv_proj", kv_w),
                  ("o_a_proj", woa), ("o_b_proj", wob_w),
                  ("comp.kv_proj", ckv_w), ("comp.gate_proj", cg_w),
                  ("sinks", sinks), ("comp.position_bias", cpb),
                  ("attn_hc.fn", hca_fn)]:
        print(f" {nm}: shape={t.shape} dtype={t.dtype}")

    # ── Create CuTeDSL runners ────────────────────────────────────────
    print("\n--- Creating CuTeDSL runners ---")
    r_qa = _runner(qa_w, qa_sf, qa_gs)
    r_qb = _runner(qb_w, qb_sf, qb_gs)
    r_kv = _runner(kv_w, kv_sf, kv_gs)
    r_wob = _runner(wob_w, wob_sf, wob_gs)
    # Compressor runners
    r_ckv = _runner(ckv_w, ckv_sf, ckv_gs)
    r_cg = _runner(cg_w, cg_sf, cg_gs)

    for nm, w in [("q_a", qa_w), ("q_b", qb_w), ("kv", kv_w), ("wo_b", wob_w),
                  ("comp.kv", ckv_w), ("comp.gate", cg_w)]:
        print(f" {nm}: in={w.shape[1]*2} out={w.shape[0]}")

    # Warmup: calibrate activation global scales on random data shaped like
    # each runner's input.
    print(" Warming up...")
    d1 = torch.randn(NT, H, dtype=torch.bfloat16, device=DEV) * 2.0
    for r in [r_qa, r_kv, r_ckv, r_cg]:
        r.compute_activation_global_scale(d1)
    d2 = torch.randn(NT, QL, dtype=torch.bfloat16, device=DEV) * 2.0
    r_qb.compute_activation_global_scale(d2)
    d3 = torch.randn(NT, OG * OL, dtype=torch.bfloat16, device=DEV) * 2.0
    r_wob.compute_activation_global_scale(d3)
    print(" Done.")

    # ── Per-projection BF16 vs CuTeDSL comparison ────────────────────
    print("\n" + "=" * 70)
    print(" PROJECTION-LEVEL: CuTeDSL vs BF16")
    print("=" * 70)
    torch.manual_seed(123)
    tx = torch.randn(NT, H, dtype=torch.bfloat16, device=DEV) * 2.0

    results = {}
    _compare_projection(results, 'q_a_proj', r_qa, tx, qa_w, qa_sf, qa_gs)
    _compare_projection(results, 'kv_proj', r_kv, tx, kv_w, kv_sf, kv_gs)
    tq = torch.randn(NT, QL, dtype=torch.bfloat16, device=DEV) * 2.0
    _compare_projection(results, 'q_b_proj', r_qb, tq, qb_w, qb_sf, qb_gs)
    tz = torch.randn(NT, OG * OL, dtype=torch.bfloat16, device=DEV) * 2.0
    _compare_projection(results, 'wo_b_proj', r_wob, tz, wob_w, wob_sf, wob_gs)
    _compare_projection(results, 'comp.kv_proj', r_ckv, tx, ckv_w, ckv_sf, ckv_gs)
    _compare_projection(results, 'comp.gate_proj', r_cg, tx, cg_w, cg_sf, cg_gs,
                        label='comp.gate')

    # ── Shared expert ─────────────────────────────────────────────────
    print("\n--- Shared Expert: CuTeDSL vs BF16 ---")
    from cutedsl.shared_expert_pipeline import CuTeDSLSharedExpertRunner

    se = f"{m}.shared_experts"
    sgw, sgsf, sggs = _load_nvfp4(G, f"{se}.gate_proj")
    suw, susf, sugs = _load_nvfp4(G, f"{se}.up_proj")
    sdw, sdsf, sdgs = _load_nvfp4(G, f"{se}.down_proj")
    sggs, sugs, sdgs = sggs.item(), sugs.item(), sdgs.item()

    si = INTER
    # Gate and up are stacked along the output dim into one L1 GEMM with a
    # single global scale (the max); each half's ratio to that max is folded
    # into its block-scale rows before transposing.
    sgu_w = torch.cat([sgw, suw], 0)
    sgu_sf = torch.cat([sgsf, susf], 0)
    smgs = max(sggs, sugs)
    if sggs != sugs:
        s32 = sgu_sf.float()
        s32[:si] *= sggs / smgs
        s32[si:] *= sugs / smgs
        sgu_sf = s32.to(torch.float8_e4m3fn)

    ser = CuTeDSLSharedExpertRunner(hidden_size=H, intermediate_size=si, max_num_tokens=8192,
                                    device=DEV, swiglu_limit=SL)
    ser.l1_fp4 = [sgu_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()]
    ser.l1_sf = [sgu_sf.permute(1, 0).contiguous()]
    ser.l1_gs = [smgs]
    ser.l2_fp4 = [sdw.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()]
    ser.l2_sf = [sdsf.permute(1, 0).contiguous()]
    ser.l2_gs = [sdgs]
    ser.finalize_weights()
    ser._ensure_initialized()

    tse = torch.randn(NT, H, dtype=torch.bfloat16, device=DEV) * 2.0
    ser.compute_activation_global_scales(tse)
    with torch.no_grad():
        so = ser.run(tse)

    gb = dequant(sgw, sgsf, sggs)
    ub = dequant(suw, susf, sugs)
    db = dequant(sdw, sdsf, sdgs)
    with torch.no_grad():
        g_ = tse @ gb.T
        u_ = tse @ ub.T
        # Clamped SwiGLU, using SL as the limit (matches swiglu_limit above).
        act = F.silu(g_.clamp(max=SL)) * u_.clamp(min=-SL, max=SL)
        sref = act @ db.T
    _report(results, 'shared_expert', 'shared_expert', so, sref)

    # ── MHC sanity ────────────────────────────────────────────────────
    print("\n--- MHC weight shapes ---")
    print(f" attn_hc.fn: {hca_fn.shape} ffn_hc.fn: {hcf_fn.shape}")
    print(f" attn_hc.base: {hca_b.shape} ffn_hc.base: {hcf_b.shape}")
    print(f" attn_hc.scale: {hca_s.shape} ffn_hc.scale: {hcf_s.shape}")
    print(f" attn_hc.scale values: {hca_s.tolist()}")
    print(f" sinks: {sinks.shape}")

    # ── Summary ───────────────────────────────────────────────────────
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    for name, c in results.items():
        print(f" {name}: {c:.6f} {_mark(c)}")
    # Same predicate as the ✅/❌ glyph, so a NaN cosine is a failure here too.
    all_pass = all(c >= _PASS_COSINE for c in results.values())
    if all_pass:
        print("\n All projections pass! CuTeDSL kernels match BF16 reference.")
        print(" The bug is in vLLM's pipeline, not our kernels.")
    else:
        print("\n Some projections FAIL. Need to debug those specific kernels.")
    return all_pass
|
|
if __name__ == "__main__":
    # Exit non-zero when main() explicitly reports failure by returning False,
    # so CI notices a failing comparison; any other return value (including
    # None) is treated as success.
    sys.exit(0 if main() is not False else 1)
|