226 lines
8.5 KiB
Python
226 lines
8.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
DeepSeek-V4 MoE NaN Reproduction Test

Finds where NaN originates in the MoE forward pass.
Tests individual experts (gate+up+down) with the CuTeDSL NVFP4 linear runner.
Then tests the grouped GEMM MoE runner with stacked weights.

Key insight: DeepSeek-V4 is a MegaMoE with 384 experts.
The NaN might come from:
  1. Weight loading / quantization
  2. Activation quantization (quantize_activation_nvfp4)
  3. The grouped GEMM kernel
  4. The combine/scatter step

Usage (on B200):
  cd /root/nvfp4-megamoe-kernel
  PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_moe_nan_b200.py
|
|
"""
|
|
|
|
import sys, os, json, torch, torch.nn.functional as F
|
|
from safetensors import safe_open
|
|
|
|
# Kernel repository checkout; put on sys.path so the function-local
# `cutedsl` imports below can resolve.
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)

# Directory holding the checkpoint shards and model.safetensors.index.json.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Device every loaded tensor and synthetic input is moved to / created on.
DEV = "cuda:0"

# Input width given to the gate/up runners (in_features).
H = 7168
INTERMEDIATE = 3072  # DeepSeek-V4 MoE intermediate
# NOTE(review): NUM_EXPERTS and TOPK are not referenced by the code visible
# in this file; presumably kept for reference -- confirm before removing.
NUM_EXPERTS = 384
TOPK = 6
# Epsilon passed to rms() when normalising the synthetic hidden states.
EPS = 1e-6

# 16-entry value table; presumably indexed by the 4-bit FP4 (E2M1) code with
# the upper half being the negated lower half -- TODO confirm against users.
# Entry 8 is written as -0.0: the integer literal `-0` is just 0 and would
# silently drop the sign bit of negative zero.
E2M1 = torch.tensor(
    [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ],
    dtype=torch.float32,
)
|
|
|
|
# Tensors already read from disk, keyed by tensor name.
_cache = {}


def P(k, wm, md):
    """Return the tensor named *k*, reading it from its shard on first use.

    Args:
        k: Tensor name; used both as the cache key and as the name passed
            to the shard's ``get_tensor``.
        wm: Mapping from tensor name to the shard file name that holds it.
        md: Directory that the shard file names are joined onto.

    Returns:
        The object returned by ``get_tensor`` (opened with ``framework="pt"``),
        memoised in the module-level ``_cache``.

    Raises:
        KeyError: If *k* is neither cached nor present in *wm*.
    """
    if k not in _cache:
        shard_path = os.path.join(md, wm[k])
        with safe_open(shard_path, framework="pt") as shard:
            loaded = shard.get_tensor(k)
        _cache[k] = loaded
    return _cache[k]
|
|
|
|
def rms(x, w, eps=1e-6):
    """Apply RMS normalisation over the last dimension and scale by *w*.

    Args:
        x: Input tensor; the mean of squares is taken over its last axis.
        w: Weight tensor multiplied onto the normalised values.
        eps: Value added to the mean square before the reciprocal sqrt.

    Returns:
        A tensor cast back to ``x.dtype``.
    """
    mean_sq = x.float().pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    # The product is formed from `x` itself (not a pre-cast copy) and then
    # cast to float32 before the weight is applied.
    normalised = (x * inv_rms).float()
    weighted = w.float() * normalised
    return weighted.to(x.dtype)
|
|
|
|
def make_runner(w, sf, gs_t, inf, outf):
    """Build an initialised ``CuTeDSLNvfp4Linear`` runner for one weight.

    Args:
        w: 2-D weight tensor; reinterpreted as ``torch.float4_e2m1fn_x2`` and
            transposed (presumably packed FP4 bytes -- TODO confirm).
        sf: 2-D scale tensor; converted to ``torch.float8_e4m3fn`` if it is
            not already, then transposed.
        gs_t: Scale tensor reduced to a Python scalar: its max when it has
            more than one element, otherwise its single value.
        inf: Passed to the runner as ``in_features``.
        outf: Passed to the runner as ``out_features``.

    Returns:
        The runner after ``finalize_weights()`` and ``_ensure_initialized()``.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    packed = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

    scales = sf
    if scales.dtype != torch.float8_e4m3fn:
        scales = scales.to(torch.float8_e4m3fn)
    scales = scales.permute(1, 0).contiguous()

    if gs_t.numel() > 1:
        global_scale = gs_t.max().item()
    else:
        global_scale = gs_t.item()

    runner = CuTeDSLNvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=8192,
        device=str(w.device),
    )
    # The runner takes its weights as single-element lists on these attributes.
    runner.fp4 = [packed]
    runner.sf = [scales]
    runner.gs = [global_scale]
    runner.finalize_weights()
    runner._ensure_initialized()
    return runner
|
|
|
|
|
|
def test_single_expert(layer_id=2, expert_id=0):
    """Test a single expert's gate+up+down with CuTeDSL NVFP4 linear.

    Loads the gate/up/down projection weights and scales of one expert,
    prints basic diagnostics about them, builds one runner per projection
    and pushes embeddings of random token ids through
    ``down(silu(gate(x)) * up(x))`` for several token counts, reporting the
    first stage at which NaN appears.

    Args:
        layer_id: Index used in the ``model.layers.{layer_id}`` tensor prefix.
        expert_id: Index used in the ``.mlp.experts.{expert_id}`` prefix.

    Returns:
        None. All findings are printed.
    """
    torch.cuda.set_device(0)
    torch.cuda.empty_cache()
    _cache.clear()

    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]
    G = lambda k: P(k, wm, MODEL).to(DEV)

    p = f"model.layers.{layer_id}"
    m = f"{p}.mlp"
    e = f"{m}.experts.{expert_id}"

    # Embedding table and norm weight are used to build the synthetic input.
    emb = G("model.embed_tokens.weight")
    fnorm = G(f"{p}.post_attention_layernorm.weight")

    # Load expert weights
    gate_w = G(f"{e}.gate_proj.weight")
    gate_sf = G(f"{e}.gate_proj.weight_scale")
    gate_gs = G(f"{e}.gate_proj.weight_scale_2")
    up_w = G(f"{e}.up_proj.weight")
    up_sf = G(f"{e}.up_proj.weight_scale")
    up_gs = G(f"{e}.up_proj.weight_scale_2")
    down_w = G(f"{e}.down_proj.weight")
    down_sf = G(f"{e}.down_proj.weight_scale")
    down_gs = G(f"{e}.down_proj.weight_scale_2")

    print(f" Expert {expert_id}:")
    print(f" gate: shape={gate_w.shape} dtype={gate_w.dtype} sf_shape={gate_sf.shape} gs={gate_gs.tolist()}")
    print(f" up: shape={up_w.shape} dtype={up_w.dtype}")
    print(f" down: shape={down_w.shape} dtype={down_w.dtype}")
    print(f" gate NaN: {torch.isnan(gate_w.float()).any()}")
    print(f" gate_gs NaN: {torch.isnan(gate_gs).any()}")
    print(f" gate input_scale: exists={f'{e}.gate_proj.input_scale' in wm}")

    # Check for zero or extreme gs values
    for name, gs in [("gate", gate_gs), ("up", up_gs), ("down", down_gs)]:
        if gs.numel() > 0:
            print(f" {name} gs: min={gs.min().item():.6f} max={gs.max().item():.6f}")
            if gs.min().item() == 0:
                print(f" WARNING: {name} gs has zero value — will cause division by zero!")

    r_gate = make_runner(gate_w, gate_sf, gate_gs, H, gate_w.shape[0])
    r_up = make_runner(up_w, up_sf, up_gs, H, up_w.shape[0])
    r_down = make_runner(down_w, down_sf, down_gs, INTERMEDIATE, down_w.shape[0])

    # Test with various token counts
    for num_tokens in [1, 4, 8, 16]:
        token_ids = torch.randint(1, 1000, (num_tokens,), dtype=torch.long, device=DEV)
        hidden = emb[token_ids]
        normed = rms(hidden, fnorm, EPS)

        with torch.no_grad():
            gate_out = r_gate.run(normed)
            up_out = r_up.run(normed)

        # Check gate and up
        gate_nan = torch.isnan(gate_out).any().item()
        up_nan = torch.isnan(up_out).any().item()

        if gate_nan or up_nan:
            print(f" {num_tokens} tokens: gate NaN={gate_nan} up NaN={up_nan}")
            # Find which row has NaN. Both projections are reported: listing
            # only the gate rows prints an empty list when the NaN is in the
            # up projection, which hides where it came from.
            gate_nan_rows = torch.isnan(gate_out).any(dim=1).nonzero().flatten().tolist()
            print(f" Gate NaN rows: {gate_nan_rows}")
            up_nan_rows = torch.isnan(up_out).any(dim=1).nonzero().flatten().tolist()
            print(f" Up NaN rows: {up_nan_rows}")
            continue

        # SiLU activation
        activated = F.silu(gate_out) * up_out
        act_nan = torch.isnan(activated).any().item()

        if act_nan:
            print(f" {num_tokens} tokens: NaN after SiLU activation!")
            continue

        # Inference only: keep the down projection out of autograd as well.
        with torch.no_grad():
            down_out = r_down.run(activated)
        down_nan = torch.isnan(down_out).any().item()

        if down_nan:
            print(f" {num_tokens} tokens: down NaN={down_nan}")
            continue

        print(f" {num_tokens} tokens: amax={down_out.amax():.4f} OK")

    # Drop the runners before asking the allocator to release cached blocks.
    del r_gate, r_up, r_down
    torch.cuda.empty_cache()
|
|
|
|
|
|
def test_quantize_activation():
    """Test the activation quantization used by the MoE grouped GEMM.

    For several token counts, quantizes a random bfloat16 activation with
    ``quantize_activation_nvfp4`` and prints whether each floating-point
    output contains NaN. Any exception from the call or the checks is
    printed rather than raised, so the remaining token counts still run.
    """
    from cutedsl.nvfp4_linear import quantize_activation_nvfp4

    torch.cuda.set_device(0)

    for num_tokens in [1, 4, 8, 16]:
        # Random stand-in for the post-activation (SiLU * up) tensor.
        x = torch.randn(num_tokens, INTERMEDIATE, dtype=torch.bfloat16, device=DEV)

        # NOTE(review): the return type of quantize_activation_nvfp4 is not
        # visible here; both a tuple and a single tensor are handled.
        try:
            result = quantize_activation_nvfp4(x, num_tokens)
            if not isinstance(result, tuple):
                print(f" {num_tokens} tokens: quantize result NaN={torch.isnan(result).any()}")
                continue
            for idx, part in enumerate(result):
                # Only floating-point outputs are checked for NaN.
                if part is None or not part.is_floating_point():
                    continue
                print(f" {num_tokens} tokens: quantize result[{idx}] NaN={torch.isnan(part).any()}")
        except Exception as exc:
            print(f" {num_tokens} tokens: quantize failed: {exc}")

    print()
|
|
|
|
|
|
def test_grouped_gemm_shapes():
    """Test the CuTeDSL grouped GEMM with MegaMoE-like shapes.

    Currently a placeholder: it imports the grouped GEMM entry point,
    selects CUDA device 0, prints the weight shapes a small configuration
    would use and then reports that the kernel run is skipped.
    """
    # Imported but not called; kept so an unavailable kernel module still
    # surfaces as an import error here.
    from cutedsl.moe import run_nvfp4_grouped_gemm

    torch.cuda.set_device(0)

    # Small configuration: 4 experts, 8 tokens.
    n_experts, n_tokens = 4, 8
    hidden = 512  # Small for testing
    inter = 256

    # Weight tensor shapes (the last axis is halved for both layers):
    #   L1: (num_experts, 2*intermediate_size, hidden_size//2) fp4
    #   L2: (num_experts, hidden_size, intermediate_size//2) fp4
    l1_shape = (n_experts, 2 * inter, hidden // 2)
    l2_shape = (n_experts, hidden, inter // 2)

    print(" Testing grouped GEMM with:")
    print(f" num_experts={n_experts}, num_tokens={n_tokens}")
    print(f" l1_shape={l1_shape}, l2_shape={l2_shape}")

    # No kernel is launched below; only the reason for skipping is printed.
    for line in (
        " (Skipping — requires proper weight packing from vLLM model loader)",
        " The CuTeDSL grouped GEMM needs weights in a specific packed format",
        " that the vLLM model loader creates during model initialization.",
    ):
        print(line)
    print()
|
|
|
|
|
|
def main():
    """Run the three diagnostic stages in order and print a summary.

    Stage 1 runs ``test_single_expert`` on layer 2 for a few expert ids,
    clearing the tensor cache after each one; stage 2 checks activation
    quantization; stage 3 prints the grouped GEMM shape information.
    """
    rule = "=" * 70
    expert_ids = (0, 1, 100, 383)

    print(rule)
    print(" DeepSeek-V4 MoE NaN Reproduction Test")
    print(" Finds where NaN originates in the MoE forward pass")
    print(rule)

    print("\n=== Test 1: Single expert gate+up+down ===")
    for expert_id in expert_ids:
        test_single_expert(layer_id=2, expert_id=expert_id)
        # Drop the cached tensors before moving on to the next expert.
        _cache.clear()

    print("\n=== Test 2: Activation quantization ===")
    test_quantize_activation()

    print("\n=== Test 3: Grouped GEMM shapes ===")
    test_grouped_gemm_shapes()

    print("\n" + rule)
    print(" Summary: If single experts produce NaN, the issue is in weight")
    print(" loading or the CuTeDSL NVFP4 linear kernel. If they're fine,")
    print(" the NaN comes from the grouped GEMM or the combine step.")
    print(rule)


# Run the diagnostics only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|