debug empty output

This commit is contained in:
2026-05-14 22:13:32 +00:00
parent 09d1307d78
commit ce4c4b6fcb
3 changed files with 37 additions and 6 deletions

View File

@@ -9,7 +9,7 @@ services:
- OMP_NUM_THREADS=128
- CUDA_LAUNCH_BLOCKING=0
- TORCH_SHOW_CPP_STACKTRACES=0
- MEGA_MOE_DEBUG=0
- MEGA_MOE_DEBUG=1
- MEGA_MOE_STATIC=0
- NVFP4_DEBUG=0
- NVFP4_DEBUG_SYNC=0

View File

@@ -80,13 +80,21 @@ def cutlass_grouped_nvfp4_gemm(
) # (M_expert, N) bfloat16
# Check for CUDA errors after each expert GEMM
err = torch.cuda.current_stream().synchronize()
torch.cuda.current_stream().synchronize()
# Validate output
# Hard-fail on NaN/Inf — silent skip was hiding bugs
if torch.isnan(expert_out).any() or torch.isinf(expert_out).any():
if MEGA_MOE_DEBUG:
print(f"[cutlass_grouped_gemm] WARNING: expert {e} produced NaN/Inf, skipping")
continue
raise RuntimeError(
f"expert {e} of {num_experts}: GEMM emitted NaN/Inf. "
f"M={M_expert} N={N} K={K} | "
f"x abs range [{expert_x.view(torch.int8).abs().max().item()}], "
f"x_sf range [{expert_x_sf.to(torch.float32).min().item():.4e}, "
f"{expert_x_sf.to(torch.float32).max().item():.4e}], "
f"w_sf range [{expert_w_sf.to(torch.float32).min().item():.4e}, "
f"{expert_w_sf.to(torch.float32).max().item():.4e}], "
f"x_sf nan_frac={torch.isnan(expert_x_sf.to(torch.float32)).float().mean().item():.4f}, "
f"w_sf nan_frac={torch.isnan(expert_w_sf.to(torch.float32)).float().mean().item():.4f}"
)
# Scatter back with routing weights
for t_idx, token_idx in enumerate(token_indices):

View File

@@ -293,18 +293,36 @@ def nvfp4_mega_moe_full(
f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} "
f"l1_w={l1_w.shape} l2_w={l2_w.shape}")
# NaN-trace: check activation scales at L1 input
if MEGA_MOE_DEBUG:
x_sf_f32 = x_sf.to(torch.float32)
print(f"[L1-in] x_sf nan={torch.isnan(x_sf_f32).any().item()} "
f"inf={torch.isinf(x_sf_f32).any().item()} "
f"min={x_sf_f32.min().item():.4e} max={x_sf_f32.max().item():.4e}")
# Step 2: L1 GEMM (native NVFP4 block-scaled MMA)
l1_output = nvfp4_mega_moe_l1(
x_fp4, x_sf, l1_w, l1_sf,
topk_ids_local, topk_weights, num_experts_per_rank,
)
# NaN-trace: check L1 output
if MEGA_MOE_DEBUG:
print(f"[L1-out] nan={torch.isnan(l1_output).any().item()} "
f"inf={torch.isinf(l1_output).any().item()} "
f"abs_max={l1_output.abs().max().item():.4e}")
# Step 3: SiLU + Mul
gate, up = l1_output.chunk(2, dim=-1)
activated = torch.nn.functional.silu(gate) * up
if activation_clamp is not None:
activated = activated.clamp(max=activation_clamp)
# NaN-trace: check SiLU output
if MEGA_MOE_DEBUG:
print(f"[silu] nan={torch.isnan(activated).any().item()} "
f"abs_max={activated.abs().max().item():.4e}")
# Step 4: Quantize L1 output → FP4
l1_fp4, l1_sf_out = stage_activation(activated)
@@ -314,5 +332,10 @@ def nvfp4_mega_moe_full(
topk_ids_local, topk_weights, num_experts_per_rank,
)
# NaN-trace: check L2 output
if MEGA_MOE_DEBUG:
print(f"[L2-out] nan={torch.isnan(l2_output).any().item()} "
f"abs_max={l2_output.abs().max().item():.4e}")
# Step 6: Write to output (caller handles cross-rank all-reduce)
y.copy_(l2_output)