Files
nvfp4-megamoe-kernel/tests/test_moe_nan_b200.py

226 lines
8.5 KiB
Python

#!/usr/bin/env python3
"""
DeepSeek-V4 MoE NaN Reproduction Test
Finds where NaN originates in the MoE forward pass.
Tests individual experts (gate+up+down) with the CuTeDSL NVFP4 linear runner,
then checks activation quantization. The grouped GEMM MoE runner test with
stacked weights is currently a placeholder that only prints the intended shapes.
Key insight: DeepSeek-V4 is a MegaMoE with 384 experts.
The NaN might come from:
1. Weight loading / quantization
2. Activation quantization (quantize_activation_nvfp4)
3. The grouped GEMM kernel
4. The combine/scatter step
Usage (on B200):
cd /root/nvfp4-megamoe-kernel
PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_moe_nan_b200.py
"""
import sys, os, json, torch, torch.nn.functional as F
from safetensors import safe_open
REPO = "/root/nvfp4-megamoe-kernel"
# Make the repo's `cutedsl` package importable; every cutedsl import in this
# file is function-local, so this runs before any of them.
sys.path.insert(0, REPO)
# NVFP4 checkpoint directory containing model.safetensors.index.json and shards.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"
H = 7168  # hidden size: in_features of the expert gate/up projections
INTERMEDIATE = 3072  # DeepSeek-V4 MoE intermediate (expected in_features of down_proj)
NUM_EXPERTS = 384  # routed experts per MoE layer
TOPK = 6  # experts selected per token (not referenced in the code shown here)
EPS = 1e-6  # RMSNorm epsilon passed to rms()
# E2M1 (FP4) decode table indexed by the 4-bit code: codes 0-7 are the
# non-negative magnitudes, codes 8-15 the same magnitudes with the sign bit set.
# Code 8 is written as the float -0.0: the integer literal -0 is just 0, which
# would silently drop the negative-zero encoding.
E2M1 = torch.tensor(
    [0., .5, 1., 1.5, 2., 3., 4., 6., -0., -.5, -1., -1.5, -2., -3., -4., -6.],
    dtype=torch.float32,
)
# Tensors already loaded by P(), keyed by tensor name (host copies).
_cache = {}
def P(k, wm, md):
    """Return tensor ``k`` from the safetensors shard named for it in ``wm``.

    ``wm`` is the ``weight_map`` of the checkpoint index (tensor name -> shard
    file name, relative to the model directory ``md``). Each loaded tensor is
    kept in the module-level ``_cache`` dict, so repeated lookups of the same
    key skip the disk read.
    """
    try:
        return _cache[k]
    except KeyError:
        pass
    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        tensor = shard.get_tensor(k)
    _cache[k] = tensor
    return tensor
def rms(x, w, eps=1e-6):
    """RMS-normalise ``x`` over its last dimension and scale by ``w``.

    The mean square and the weight multiply are computed in float32; the
    result is cast back to ``x.dtype``.
    """
    inv_rms = torch.rsqrt(torch.pow(x.float(), 2).mean(dim=-1, keepdim=True) + eps)
    normalized = (x * inv_rms).float()
    return (w.float() * normalized).to(x.dtype)
def make_runner(w, sf, gs_t, inf, outf, max_num_tokens=8192):
    """Build an initialised CuTeDSL NVFP4 linear runner for one weight matrix.

    Args:
        w: packed FP4 weight; reinterpreted with ``view`` as
            ``torch.float4_e2m1fn_x2`` and transposed. Presumably the
            checkpoint's 1-byte-per-pair storage, (out, in // 2) — TODO confirm.
        sf: block scale factors; cast to ``float8_e4m3fn`` unless already in
            that dtype, then transposed.
        gs_t: ``weight_scale_2`` tensor, reduced to one Python float via max().
        inf: in_features passed to the runner.
        outf: out_features passed to the runner.
        max_num_tokens: largest token batch the runner is sized for.

    Returns:
        The runner after ``finalize_weights()`` and ``_ensure_initialized()``.

    Raises:
        ValueError: if ``gs_t`` is empty (there is no global scale to use).
    """
    if gs_t.numel() == 0:
        raise ValueError("make_runner: global scale tensor gs_t is empty")
    # max() of a single element is that element, so this covers the scalar and
    # multi-element cases alike. NOTE(review): with several entries the largest
    # one is used for the whole matrix — confirm that is intended.
    gs = gs_t.max().item()

    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    # NOTE(review): .to() converts values; if sf were stored as raw uint8 bit
    # patterns, a .view(torch.float8_e4m3fn) would be needed instead.
    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    s = s.permute(1, 0).contiguous()

    r = CuTeDSLNvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=max_num_tokens,
        device=str(w.device),
    )
    r.fp4 = [fp4]
    r.sf = [s]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r
def test_single_expert(layer_id=2, expert_id=0, seed=0):
    """Run one routed expert's gate/up/down through the CuTeDSL NVFP4 linear runner.

    Loads the expert's ``weight``, ``weight_scale`` and ``weight_scale_2``
    tensors for gate/up/down from the checkpoint and prints their shapes and
    non-finite counts. It then pushes RMS-normed token embeddings through gate,
    up, ``SiLU(gate) * up`` and down for 1, 4, 8 and 16 tokens. For each batch
    size it reports the first stage whose output holds NaN *or* Inf: an Inf
    upstream typically becomes NaN a stage later, so flagging it early
    pinpoints the origin. If every stage is finite, it prints the absolute max
    of the down output.

    Args:
        layer_id: decoder layer whose ``mlp.experts`` are probed.
        expert_id: routed expert index within that layer.
        seed: seed for the random token ids so a failing case can be replayed;
            ``None`` draws from the global torch RNG instead.
    """
    torch.cuda.set_device(0)
    torch.cuda.empty_cache()
    _cache.clear()
    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def load(key):
        # Cached host tensor -> device copy.
        return P(key, wm, MODEL).to(DEV)

    def nonfinite(t):
        # (NaN count, Inf count). Upcast first so low-precision tensors (e.g.
        # fp8 scale factors) are checked on their decoded values.
        tf = t.float()
        return int(torch.isnan(tf).sum()), int(torch.isinf(tf).sum())

    p = f"model.layers.{layer_id}"
    e = f"{p}.mlp.experts.{expert_id}"
    # The embedding table is large: keep it on the host and move only the
    # gathered rows to the device.
    emb = P("model.embed_tokens.weight", wm, MODEL)
    fnorm = load(f"{p}.post_attention_layernorm.weight")

    # name -> (packed weight, weight_scale, weight_scale_2)
    proj = {
        name: (
            load(f"{e}.{name}_proj.weight"),
            load(f"{e}.{name}_proj.weight_scale"),
            load(f"{e}.{name}_proj.weight_scale_2"),
        )
        for name in ("gate", "up", "down")
    }
    gate_w, gate_sf, gate_gs = proj["gate"]
    up_w, up_sf, up_gs = proj["up"]
    down_w, down_sf, down_gs = proj["down"]

    print(f" Expert {expert_id}:")
    print(f" gate: shape={gate_w.shape} dtype={gate_w.dtype} sf_shape={gate_sf.shape} gs={gate_gs.tolist()}")
    print(f" up: shape={up_w.shape} dtype={up_w.dtype}")
    print(f" down: shape={down_w.shape} dtype={down_w.dtype}")
    print(f" gate input_scale: exists={f'{e}.gate_proj.input_scale' in wm}")
    for name, (w, sf, gs) in proj.items():
        # A NaN check on packed FP4 bytes (presumably integer storage) can never
        # fire, so only check the weight itself when it is a float dtype. The
        # scale tensors are where a NaN/Inf can actually live in the checkpoint.
        if w.is_floating_point():
            print(f" {name} weight NaN/Inf: {nonfinite(w)}")
        print(f" {name} sf NaN/Inf: {nonfinite(sf)}  gs NaN/Inf: {nonfinite(gs)}")
        if gs.numel() > 0:
            print(f" {name} gs: min={gs.min().item():.6f} max={gs.max().item():.6f}")
            if gs.min().item() == 0:
                print(f" WARNING: {name} gs has zero value — will cause division by zero!")

    # SiLU(gate) * up has the gate output's width, so down's in_features must
    # match it; derive it rather than trusting the INTERMEDIATE constant.
    inter = gate_w.shape[0]
    if inter != INTERMEDIATE:
        print(f" NOTE: gate out_features={inter} differs from INTERMEDIATE={INTERMEDIATE}")

    gen = None if seed is None else torch.Generator().manual_seed(seed)
    r_gate = r_up = r_down = None
    try:
        r_gate = make_runner(gate_w, gate_sf, gate_gs, H, gate_w.shape[0])
        r_up = make_runner(up_w, up_sf, up_gs, H, up_w.shape[0])
        r_down = make_runner(down_w, down_sf, down_gs, inter, down_w.shape[0])
        for num_tokens in (1, 4, 8, 16):
            token_ids = torch.randint(1, 1000, (num_tokens,), dtype=torch.long, generator=gen)
            hidden = emb[token_ids].to(DEV)
            with torch.no_grad():
                normed = rms(hidden, fnorm, EPS)
                gate_out = r_gate.run(normed)
                up_out = r_up.run(normed)
                gate_bad, up_bad = nonfinite(gate_out), nonfinite(up_out)
                if any(gate_bad) or any(up_bad):
                    print(f" {num_tokens} tokens: gate NaN/Inf={gate_bad} up NaN/Inf={up_bad}")
                    for label, out in (("Gate", gate_out), ("Up", up_out)):
                        rows = (~torch.isfinite(out.float())).any(dim=-1).nonzero().flatten().tolist()
                        if rows:
                            print(f" {label} non-finite rows: {rows}")
                    continue
                activated = F.silu(gate_out) * up_out
                if any(nonfinite(activated)):
                    print(f" {num_tokens} tokens: NaN/Inf after SiLU activation!")
                    continue
                down_out = r_down.run(activated)
                down_bad = nonfinite(down_out)
                if any(down_bad):
                    print(f" {num_tokens} tokens: down NaN/Inf={down_bad}")
                    continue
                # Absolute max: a plain amax() would hide large negative outputs.
                absmax = down_out.float().abs().amax().item()
                print(f" {num_tokens} tokens: absmax={absmax:.4f} OK")
    finally:
        # Release the runners and cached GPU memory even if a stage raised.
        del r_gate, r_up, r_down
        torch.cuda.empty_cache()
def test_quantize_activation():
    """Check ``quantize_activation_nvfp4`` outputs for NaN/Inf at several batch sizes.

    For 1, 4, 8 and 16 tokens, quantizes a standard-normal bf16 activation of
    width ``INTERMEDIATE``. This stands in for the ``SiLU(gate) * up``
    activation fed to the down projection; it is not a recorded activation.
    It prints NaN/Inf flags for every floating-point tensor returned. If the
    quantize call itself fails, the failure is reported and the next batch
    size is tried.
    """
    from cutedsl.nvfp4_linear import quantize_activation_nvfp4
    torch.cuda.set_device(0)

    def flags(t):
        # Upcast so the check runs on decoded values of low-precision outputs.
        tf = t.float()
        return f"NaN={bool(torch.isnan(tf).any())} Inf={bool(torch.isinf(tf).any())}"

    for num_tokens in (1, 4, 8, 16):
        x = torch.randn(num_tokens, INTERMEDIATE, dtype=torch.bfloat16, device=DEV)
        # Only the kernel call is guarded, so a problem while *inspecting* the
        # result is not misreported as a quantization failure.
        # NOTE(review): the return structure (guessed as "(x_sf, x_gs) or
        # similar") is not visible here — confirm against cutedsl.
        try:
            result = quantize_activation_nvfp4(x, num_tokens)
        except Exception as e:  # best-effort probe: report and move on
            print(f" {num_tokens} tokens: quantize failed: {e}")
            continue

        if isinstance(result, tuple):
            labelled = [(f"result[{i}]", r) for i, r in enumerate(result)]
        else:
            labelled = [("result", result)]

        for label, r in labelled:
            if r is None:
                continue
            if not isinstance(r, torch.Tensor):
                print(f" {num_tokens} tokens: quantize {label} is {type(r).__name__}={r!r} (not checked)")
                continue
            if not r.is_floating_point():
                continue  # integer payloads (e.g. packed bytes) cannot hold NaN
            try:
                status = flags(r)
            except (RuntimeError, TypeError) as e:  # e.g. .float() unsupported for this dtype
                status = f"check failed for {r.dtype}: {e}"
            print(f" {num_tokens} tokens: quantize {label} {status}")
    print()
def test_grouped_gemm_shapes():
    """Placeholder for a CuTeDSL grouped-GEMM NaN test with small MegaMoE-like shapes.

    No kernel is launched yet. The function:

    * imports ``run_nvfp4_grouped_gemm``, so a missing ``cutedsl.moe`` entry
      point raises ImportError here;
    * selects CUDA device 0;
    * prints the toy problem it would run (4 experts, 8 tokens, hidden 512,
      intermediate 256) together with the packed FP4 weight shapes;
    * explains that the run is skipped because the kernel needs weights in the
      packed format produced by the vLLM model loader.
    """
    # Import kept only as a smoke check that the grouped-GEMM entry point exists.
    from cutedsl.moe import run_nvfp4_grouped_gemm  # noqa: F401
    torch.cuda.set_device(0)
    # Intended toy problem: 4 experts, 8 tokens, top-2 routing.
    num_experts = 4
    num_tokens = 8
    hidden_size = 512  # small for testing
    intermediate_size = 256
    # The // 2 on the reduction dimension reflects two FP4 values per byte.
    # L1: (num_experts, 2*intermediate_size, hidden_size//2) fp4 (gate+up)
    # L2: (num_experts, hidden_size, intermediate_size//2) fp4 (down)
    l1_shape = (num_experts, 2 * intermediate_size, hidden_size // 2)
    l2_shape = (num_experts, hidden_size, intermediate_size // 2)
    print(" Testing grouped GEMM with:")
    print(f" num_experts={num_experts}, num_tokens={num_tokens}")
    print(f" l1_shape={l1_shape}, l2_shape={l2_shape}")
    print(" (Skipping — requires proper weight packing from vLLM model loader)")
    print(" The CuTeDSL grouped GEMM needs weights in a specific packed format")
    print(" that the vLLM model loader creates during model initialization.")
    print()
def main():
    """Run the three diagnostic stages in order and print how to read the results."""
    rule = "=" * 70
    for line in (
        rule,
        " DeepSeek-V4 MoE NaN Reproduction Test",
        " Finds where NaN originates in the MoE forward pass",
        rule,
    ):
        print(line)

    print("\n=== Test 1: Single expert gate+up+down ===")
    # First, second, a middle and the last routed expert of layer 2.
    for eid in (0, 1, 100, NUM_EXPERTS - 1):
        test_single_expert(layer_id=2, expert_id=eid)
        _cache.clear()  # drop cached host tensors before the next expert

    print("\n=== Test 2: Activation quantization ===")
    test_quantize_activation()

    print("\n=== Test 3: Grouped GEMM shapes ===")
    test_grouped_gemm_shapes()

    print("\n" + rule)
    print(" Summary: If single experts produce NaN, the issue is in weight")
    print(" loading or the CuTeDSL NVFP4 linear kernel. If they're fine,")
    print(" the NaN comes from the grouped GEMM or the combine step.")
    print(rule)
if __name__ == "__main__":
    # Run the diagnostics only when executed as a script, not on import.
    main()