Add full layer NaN test (attention + MoE, multi-layer chain)
This commit is contained in:
348
tests/test_full_layer_nan_b200.py
Normal file
348
tests/test_full_layer_nan_b200.py
Normal file
@@ -0,0 +1,348 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DeepSeek-V4 Full Layer Forward Test
|
||||
|
||||
Tests a complete transformer layer (attention + MoE) with real weights.
|
||||
If this produces NaN, we can bisect which component causes it.
|
||||
|
||||
Usage (on B200):
|
||||
cd /root/nvfp4-megamoe-kernel
|
||||
PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_full_layer_nan_b200.py
|
||||
"""
|
||||
|
||||
import sys, os, json, torch, torch.nn.functional as F
|
||||
from safetensors import safe_open
|
||||
|
||||
REPO = "/root/nvfp4-megamoe-kernel"
|
||||
sys.path.insert(0, REPO)
|
||||
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
DEV = "cuda:0"
|
||||
|
||||
H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64
|
||||
QL = 1536; OL = 1024; OG = 16; HPG = NH // OG
|
||||
INTERMEDIATE = 3072
|
||||
TOPK = 6
|
||||
EPS = 1e-6; WINDOW = 128; SCALE = HD ** -0.5
|
||||
|
||||
E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32)
|
||||
|
||||
_cache = {}
|
||||
def P(k, wm, md):
|
||||
if k in _cache: return _cache[k]
|
||||
with safe_open(os.path.join(md, wm[k]), framework="pt") as f:
|
||||
t = f.get_tensor(k)
|
||||
_cache[k] = t
|
||||
return t
|
||||
|
||||
def rms(x, w, eps=1e-6):
|
||||
v = x.float().pow(2).mean(-1, keepdim=True)
|
||||
return (w.float() * (x * torch.rsqrt(v+eps)).float()).to(x.dtype)
|
||||
|
||||
def make_runner(w, sf, gs_t, inf, outf):
|
||||
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
|
||||
fp4 = w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous()
|
||||
s = sf.to(torch.float8_e4m3fn) if sf.dtype != torch.float8_e4m3fn else sf
|
||||
s = s.permute(1,0).contiguous()
|
||||
gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()
|
||||
r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
|
||||
r.fp4 = [fp4]; r.sf = [s]; r.gs = [gs]
|
||||
r.finalize_weights(); r._ensure_initialized()
|
||||
return r
|
||||
|
||||
def build_cos_sin(max_pos=4096, rope_dim=ROPE):
|
||||
half = rope_dim // 2
|
||||
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half, dtype=torch.float32) / half))
|
||||
freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
|
||||
return torch.cat([freqs.cos(), freqs.sin()], dim=-1)
|
||||
|
||||
def apply_gptj_rope(x, positions, cos_sin, nope_dim, rope_dim):
|
||||
if rope_dim == 0 or x.numel() == 0: return x
|
||||
half = rope_dim // 2
|
||||
cos = cos_sin[positions, :half].to(x.dtype)
|
||||
sin = cos_sin[positions, half:2*half].to(x.dtype)
|
||||
if x.dim() == 3: cos = cos.unsqueeze(1); sin = sin.unsqueeze(1)
|
||||
x_rope = x[..., nope_dim:].clone()
|
||||
even = x_rope[..., 0::2]; odd = x_rope[..., 1::2]
|
||||
out = x.clone()
|
||||
out[..., nope_dim:][..., 0::2] = even * cos - odd * sin
|
||||
out[..., nope_dim:][..., 1::2] = even * sin + odd * cos
|
||||
return out
|
||||
|
||||
def apply_inv_gptj_rope(x, positions, cos_sin, nope_dim, rope_dim):
|
||||
if rope_dim == 0 or x.numel() == 0: return x
|
||||
half = rope_dim // 2
|
||||
cos = cos_sin[positions, :half].to(x.dtype)
|
||||
sin = cos_sin[positions, half:2*half].to(x.dtype)
|
||||
if x.dim() == 3: cos = cos.unsqueeze(1); sin = sin.unsqueeze(1)
|
||||
x_rope = x[..., nope_dim:].clone()
|
||||
even = x_rope[..., 0::2]; odd = x_rope[..., 1::2]
|
||||
out = x.clone()
|
||||
out[..., nope_dim:][..., 0::2] = even * cos + odd * sin
|
||||
out[..., nope_dim:][..., 1::2] = -even * sin + odd * cos
|
||||
return out
|
||||
|
||||
def kv_quantize_fp8(kv_bf16):
|
||||
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
|
||||
scale = fp8_max / amax
|
||||
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
|
||||
inv_scale = (amax / fp8_max).to(torch.bfloat16)
|
||||
return kv_fp8, inv_scale
|
||||
|
||||
def kv_dequantize_fp8(kv_fp8, inv_scale):
|
||||
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
|
||||
|
||||
def causal_prefill_attention(q, kv, scale):
|
||||
T, NH, HD = q.shape
|
||||
q_t = q.permute(1, 0, 2)
|
||||
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1)
|
||||
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale)
|
||||
return out.permute(1, 0, 2)
|
||||
|
||||
|
||||
def test_full_layer(layer_id, num_tokens=8, num_moe_experts=16):
|
||||
"""Test a complete transformer layer with attention + MoE."""
|
||||
from cutedsl.runner import CuTeDSLMoERunner
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.empty_cache()
|
||||
_cache.clear()
|
||||
|
||||
with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
|
||||
wm = json.load(f)["weight_map"]
|
||||
G = lambda k: P(k, wm, MODEL).to(DEV)
|
||||
|
||||
p = f"model.layers.{layer_id}"; a = f"{p}.self_attn"
|
||||
m = f"{p}.mlp"
|
||||
cr = 128 if layer_id == 0 else (0 if layer_id == 60 else 4)
|
||||
lt = f"C{cr}A" if cr > 1 else "SWA"
|
||||
|
||||
emb = G("model.embed_tokens.weight")
|
||||
anorm = G(f"{p}.input_layernorm.weight")
|
||||
qn = G(f"{a}.q_a_norm.weight"); kvn = G(f"{a}.kv_norm.weight")
|
||||
woa = G(f"{a}.o_a_proj.weight")
|
||||
fnorm = G(f"{p}.post_attention_layernorm.weight")
|
||||
|
||||
qa_w = G(f"{a}.q_a_proj.weight"); qa_sf = G(f"{a}.q_a_proj.weight_scale"); qa_gs = G(f"{a}.q_a_proj.weight_scale_2")
|
||||
qb_w = G(f"{a}.q_b_proj.weight"); qb_sf = G(f"{a}.q_b_proj.weight_scale"); qb_gs = G(f"{a}.q_b_proj.weight_scale_2")
|
||||
kv_w = G(f"{a}.kv_proj.weight"); kv_sf = G(f"{a}.kv_proj.weight_scale"); kv_gs = G(f"{a}.kv_proj.weight_scale_2")
|
||||
wob_w = G(f"{a}.o_b_proj.weight"); wob_sf = G(f"{a}.o_b_proj.weight_scale"); wob_gs = G(f"{a}.o_b_proj.weight_scale_2")
|
||||
|
||||
r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
|
||||
r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
|
||||
r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])
|
||||
r_wob = make_runner(wob_w, wob_sf, wob_gs, OG*OL, wob_w.shape[0])
|
||||
cos_sin = build_cos_sin(max_pos=4096).to(DEV)
|
||||
woa_3d = woa.view(OG, OL, HPG * HD)
|
||||
|
||||
# MoE weights (only first num_moe_experts to fit in memory)
|
||||
gate_ws, gate_sfs, gate_gss = [], [], []
|
||||
up_ws, up_sfs, up_gss = [], [], []
|
||||
down_ws, down_sfs, down_gss = [], [], []
|
||||
for i in range(num_moe_experts):
|
||||
e = f"{m}.experts.{i}"
|
||||
gate_ws.append(G(f"{e}.gate_proj.weight"))
|
||||
gate_sfs.append(G(f"{e}.gate_proj.weight_scale"))
|
||||
gate_gss.append(G(f"{e}.gate_proj.weight_scale_2"))
|
||||
up_ws.append(G(f"{e}.up_proj.weight"))
|
||||
up_sfs.append(G(f"{e}.up_proj.weight_scale"))
|
||||
up_gss.append(G(f"{e}.up_proj.weight_scale_2"))
|
||||
down_ws.append(G(f"{e}.down_proj.weight"))
|
||||
down_sfs.append(G(f"{e}.down_proj.weight_scale"))
|
||||
down_gss.append(G(f"{e}.down_proj.weight_scale_2"))
|
||||
|
||||
w13_w = torch.cat([torch.stack(gate_ws), torch.stack(up_ws)], dim=1)
|
||||
w13_sf = torch.cat([torch.stack(gate_sfs), torch.stack(up_sfs)], dim=1)
|
||||
w13_gs = torch.cat([torch.stack(gate_gss), torch.stack(up_gss)], dim=0)
|
||||
w2_w = torch.stack(down_ws)
|
||||
w2_sf = torch.stack(down_sfs)
|
||||
w2_gs = torch.stack(down_gss)
|
||||
|
||||
# Free per-expert lists
|
||||
del gate_ws, gate_sfs, gate_gss, up_ws, up_sfs, up_gss, down_ws, down_sfs, down_gss
|
||||
|
||||
moe_runner = CuTeDSLMoERunner(
|
||||
num_experts=num_moe_experts,
|
||||
hidden_size=H,
|
||||
intermediate_size=INTERMEDIATE,
|
||||
max_num_tokens=8192,
|
||||
top_k=TOPK,
|
||||
device=str(DEV),
|
||||
)
|
||||
|
||||
l1_fp4 = w13_w.view(torch.float4_e2m1fn_x2)
|
||||
l2_fp4 = w2_w.view(torch.float4_e2m1fn_x2)
|
||||
l1_sf = w13_sf.to(torch.float8_e4m3fn)
|
||||
l2_sf = w2_sf.to(torch.float8_e4m3fn)
|
||||
|
||||
moe_runner.prepare_weights_from_stacked(
|
||||
l1_fp4, l1_sf, w13_gs.flatten().tolist(),
|
||||
l2_fp4, l2_sf, w2_gs.flatten().tolist(),
|
||||
)
|
||||
|
||||
del w13_w, w13_sf, w13_gs, w2_w, w2_sf, w2_gs, l1_fp4, l2_fp4, l1_sf, l2_sf
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Shared expert
|
||||
se_gate_w = G(f"{m}.shared_experts.gate_proj.weight"); se_gate_sf = G(f"{m}.shared_experts.gate_proj.weight_scale"); se_gate_gs = G(f"{m}.shared_experts.gate_proj.weight_scale_2")
|
||||
se_up_w = G(f"{m}.shared_experts.up_proj.weight"); se_up_sf = G(f"{m}.shared_experts.up_proj.weight_scale"); se_up_gs = G(f"{m}.shared_experts.up_proj.weight_scale_2")
|
||||
se_down_w = G(f"{m}.shared_experts.down_proj.weight"); se_down_sf = G(f"{m}.shared_experts.down_proj.weight_scale"); se_down_gs = G(f"{m}.shared_experts.down_proj.weight_scale_2")
|
||||
|
||||
r_se_gate = make_runner(se_gate_w, se_gate_sf, se_gate_gs, H, se_gate_w.shape[0])
|
||||
r_se_up = make_runner(se_up_w, se_up_sf, se_up_gs, H, se_up_w.shape[0])
|
||||
r_se_down = make_runner(se_down_w, se_down_sf, se_down_gs, INTERMEDIATE, se_down_w.shape[0])
|
||||
|
||||
# Run the layer
|
||||
token_ids = torch.randint(1, 1000, (num_tokens,), dtype=torch.long, device=DEV)
|
||||
positions = torch.arange(num_tokens, dtype=torch.int64, device=DEV)
|
||||
|
||||
with torch.no_grad():
|
||||
hidden = emb[token_ids]
|
||||
|
||||
# ── Attention ──────────────────────────────────────────
|
||||
normed = rms(hidden, anorm, EPS)
|
||||
qa = r_qa.run(normed); kv = r_kv.run(normed)
|
||||
qa_n = rms(qa, qn, EPS); kv_n = rms(kv, kvn, EPS)
|
||||
q = r_qb.run(qa_n).view(num_tokens, NH, HD)
|
||||
q_rope = apply_gptj_rope(q, positions, cos_sin, NOPE, ROPE)
|
||||
kv_rope = apply_gptj_rope(kv_n.unsqueeze(1), positions, cos_sin, NOPE, ROPE).squeeze(1)
|
||||
|
||||
o_attn = causal_prefill_attention(q_rope, kv_rope, SCALE)
|
||||
o_inv = apply_inv_gptj_rope(o_attn, positions, cos_sin, NOPE, ROPE)
|
||||
o_grouped = o_inv.reshape(num_tokens, OG, HPG * HD).permute(1, 0, 2)
|
||||
z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(num_tokens, OG * OL)
|
||||
attn_out = r_wob.run(z)
|
||||
|
||||
hidden = hidden + attn_out
|
||||
print(f" Layer {layer_id} ({lt}): after attention: amax={hidden.amax():.4f} NaN={torch.isnan(hidden).any()}")
|
||||
|
||||
# ── MoE ────────────────────────────────────────────────
|
||||
fnormed = rms(hidden, fnorm, EPS)
|
||||
|
||||
# Shared expert
|
||||
gate_out = r_se_gate.run(fnormed)
|
||||
up_out = r_se_up.run(fnormed)
|
||||
activated = F.silu(gate_out) * up_out
|
||||
se_out = r_se_down.run(activated)
|
||||
|
||||
# Routed experts (using MoE runner with subset of experts)
|
||||
topk_ids = torch.randint(0, num_moe_experts, (num_tokens, TOPK), device=DEV)
|
||||
topk_weights = torch.softmax(torch.randn(num_tokens, TOPK, device=DEV), dim=-1)
|
||||
moe_out = moe_runner.run(fnormed, topk_weights, topk_ids)
|
||||
|
||||
hidden = hidden + se_out + moe_out
|
||||
print(f" Layer {layer_id} ({lt}): after MoE: amax={hidden.amax():.4f} NaN={torch.isnan(hidden).any()}")
|
||||
|
||||
del r_qa, r_qb, r_kv, r_wob, r_se_gate, r_se_up, r_se_down, moe_runner
|
||||
torch.cuda.empty_cache()
|
||||
_cache.clear()
|
||||
|
||||
return not torch.isnan(hidden).any()
|
||||
|
||||
|
||||
def test_multi_layer():
|
||||
"""Test multiple layers chained together to see if NaN propagates."""
|
||||
emb = None
|
||||
|
||||
# Load embedding once
|
||||
with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
|
||||
wm = json.load(f)["weight_map"]
|
||||
G = lambda k: P(k, wm, MODEL).to(DEV)
|
||||
emb = G("model.embed_tokens.weight")
|
||||
|
||||
num_tokens = 8
|
||||
token_ids = torch.randint(1, 1000, (num_tokens,), dtype=torch.long, device=DEV)
|
||||
hidden = emb[token_ids]
|
||||
|
||||
# Test just layers 0, 2, 60 (one of each type)
|
||||
# For each layer, do attention only (skip MoE to save memory)
|
||||
for layer_id in [0, 2, 60]:
|
||||
p = f"model.layers.{layer_id}"; a = f"{p}.self_attn"
|
||||
cr = 128 if layer_id == 0 else (0 if layer_id == 60 else 4)
|
||||
lt = f"C{cr}A" if cr > 1 else "SWA"
|
||||
|
||||
anorm = G(f"{p}.input_layernorm.weight")
|
||||
qn = G(f"{a}.q_a_norm.weight"); kvn = G(f"{a}.kv_norm.weight")
|
||||
woa = G(f"{a}.o_a_proj.weight")
|
||||
fnorm = G(f"{p}.post_attention_layernorm.weight")
|
||||
|
||||
qa_w = G(f"{a}.q_a_proj.weight"); qa_sf = G(f"{a}.q_a_proj.weight_scale"); qa_gs = G(f"{a}.q_a_proj.weight_scale_2")
|
||||
qb_w = G(f"{a}.q_b_proj.weight"); qb_sf = G(f"{a}.q_b_proj.weight_scale"); qb_gs = G(f"{a}.q_b_proj.weight_scale_2")
|
||||
kv_w = G(f"{a}.kv_proj.weight"); kv_sf = G(f"{a}.kv_proj.weight_scale"); kv_gs = G(f"{a}.kv_proj.weight_scale_2")
|
||||
wob_w = G(f"{a}.o_b_proj.weight"); wob_sf = G(f"{a}.o_b_proj.weight_scale"); wob_gs = G(f"{a}.o_b_proj.weight_scale_2")
|
||||
|
||||
r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
|
||||
r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
|
||||
r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])
|
||||
r_wob = make_runner(wob_w, wob_sf, wob_gs, OG*OL, wob_w.shape[0])
|
||||
cos_sin = build_cos_sin(max_pos=4096).to(DEV)
|
||||
woa_3d = woa.view(OG, OL, HPG * HD)
|
||||
|
||||
# Shared expert
|
||||
m = f"{p}.mlp"
|
||||
se_gate_w = G(f"{m}.shared_experts.gate_proj.weight"); se_gate_sf = G(f"{m}.shared_experts.gate_proj.weight_scale"); se_gate_gs = G(f"{m}.shared_experts.gate_proj.weight_scale_2")
|
||||
se_up_w = G(f"{m}.shared_experts.up_proj.weight"); se_up_sf = G(f"{m}.shared_experts.up_proj.weight_scale"); se_up_gs = G(f"{m}.shared_experts.up_proj.weight_scale_2")
|
||||
se_down_w = G(f"{m}.shared_experts.down_proj.weight"); se_down_sf = G(f"{m}.shared_experts.down_proj.weight_scale"); se_down_gs = G(f"{m}.shared_experts.down_proj.weight_scale_2")
|
||||
|
||||
r_se_gate = make_runner(se_gate_w, se_gate_sf, se_gate_gs, H, se_gate_w.shape[0])
|
||||
r_se_up = make_runner(se_up_w, se_up_sf, se_up_gs, H, se_up_w.shape[0])
|
||||
r_se_down = make_runner(se_down_w, se_down_sf, se_down_gs, INTERMEDIATE, se_down_w.shape[0])
|
||||
|
||||
positions = torch.arange(num_tokens, dtype=torch.int64, device=DEV)
|
||||
|
||||
with torch.no_grad():
|
||||
# Attention
|
||||
normed = rms(hidden, anorm, EPS)
|
||||
qa = r_qa.run(normed); kv = r_kv.run(normed)
|
||||
qa_n = rms(qa, qn, EPS); kv_n = rms(kv, kvn, EPS)
|
||||
q = r_qb.run(qa_n).view(num_tokens, NH, HD)
|
||||
q_rope = apply_gptj_rope(q, positions, cos_sin, NOPE, ROPE)
|
||||
kv_rope = apply_gptj_rope(kv_n.unsqueeze(1), positions, cos_sin, NOPE, ROPE).squeeze(1)
|
||||
|
||||
o_attn = causal_prefill_attention(q_rope, kv_rope, SCALE)
|
||||
o_inv = apply_inv_gptj_rope(o_attn, positions, cos_sin, NOPE, ROPE)
|
||||
o_grouped = o_inv.reshape(num_tokens, OG, HPG * HD).permute(1, 0, 2)
|
||||
z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(num_tokens, OG * OL)
|
||||
attn_out = r_wob.run(z)
|
||||
hidden = hidden + attn_out
|
||||
|
||||
# Shared expert MoE
|
||||
fnormed = rms(hidden, fnorm, EPS)
|
||||
gate_out = r_se_gate.run(fnormed)
|
||||
up_out = r_se_up.run(fnormed)
|
||||
activated = F.silu(gate_out) * up_out
|
||||
se_out = r_se_down.run(activated)
|
||||
hidden = hidden + se_out
|
||||
|
||||
attn_nan = torch.isnan(attn_out).any().item()
|
||||
moe_nan = torch.isnan(se_out).any().item()
|
||||
hs_nan = torch.isnan(hidden).any().item()
|
||||
print(f" Layer {layer_id} ({lt}): attn_nan={attn_nan} moe_nan={moe_nan} hidden_nan={hs_nan} amax={hidden.amax():.4f}")
|
||||
|
||||
if hs_nan:
|
||||
print(f" NaN detected at layer {layer_id}! Stopping.")
|
||||
break
|
||||
|
||||
del r_qa, r_qb, r_kv, r_wob, r_se_gate, r_se_up, r_se_down
|
||||
torch.cuda.empty_cache()
|
||||
_cache.clear()
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 70)
|
||||
print(" DeepSeek-V4 Full Layer NaN Test")
|
||||
print(" Tests attention + MoE to find where NaN originates")
|
||||
print("=" * 70)
|
||||
|
||||
print("\n=== Test 1: Single full layer (attention + MoE) ===")
|
||||
test_full_layer(layer_id=2, num_tokens=8, num_moe_experts=16)
|
||||
|
||||
print("\n=== Test 2: Multi-layer chain (attention + shared expert only) ===")
|
||||
test_multi_layer()
|
||||
|
||||
print(f"\n{'='*70}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user