debug empty output
This commit is contained in:
@@ -9,7 +9,7 @@ services:
|
||||
- OMP_NUM_THREADS=128
|
||||
- CUDA_LAUNCH_BLOCKING=0
|
||||
- TORCH_SHOW_CPP_STACKTRACES=0
|
||||
- MEGA_MOE_DEBUG=0
|
||||
- MEGA_MOE_DEBUG=1
|
||||
- MEGA_MOE_STATIC=0
|
||||
- NVFP4_DEBUG=0
|
||||
- NVFP4_DEBUG_SYNC=0
|
||||
|
||||
@@ -80,13 +80,21 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
) # (M_expert, N) bfloat16
|
||||
|
||||
# Check for CUDA errors after each expert GEMM
|
||||
err = torch.cuda.current_stream().synchronize()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
# Validate output
|
||||
# Hard-fail on NaN/Inf — silent skip was hiding bugs
|
||||
if torch.isnan(expert_out).any() or torch.isinf(expert_out).any():
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[cutlass_grouped_gemm] WARNING: expert {e} produced NaN/Inf, skipping")
|
||||
continue
|
||||
raise RuntimeError(
|
||||
f"expert {e} of {num_experts}: GEMM emitted NaN/Inf. "
|
||||
f"M={M_expert} N={N} K={K} | "
|
||||
f"x abs range [{expert_x.view(torch.int8).abs().max().item()}], "
|
||||
f"x_sf range [{expert_x_sf.to(torch.float32).min().item():.4e}, "
|
||||
f"{expert_x_sf.to(torch.float32).max().item():.4e}], "
|
||||
f"w_sf range [{expert_w_sf.to(torch.float32).min().item():.4e}, "
|
||||
f"{expert_w_sf.to(torch.float32).max().item():.4e}], "
|
||||
f"x_sf nan_frac={torch.isnan(expert_x_sf.to(torch.float32)).float().mean().item():.4f}, "
|
||||
f"w_sf nan_frac={torch.isnan(expert_w_sf.to(torch.float32)).float().mean().item():.4f}"
|
||||
)
|
||||
|
||||
# Scatter back with routing weights
|
||||
for t_idx, token_idx in enumerate(token_indices):
|
||||
|
||||
@@ -293,18 +293,36 @@ def nvfp4_mega_moe_full(
|
||||
f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} "
|
||||
f"l1_w={l1_w.shape} l2_w={l2_w.shape}")
|
||||
|
||||
# NaN-trace: check activation scales at L1 input
|
||||
if MEGA_MOE_DEBUG:
|
||||
x_sf_f32 = x_sf.to(torch.float32)
|
||||
print(f"[L1-in] x_sf nan={torch.isnan(x_sf_f32).any().item()} "
|
||||
f"inf={torch.isinf(x_sf_f32).any().item()} "
|
||||
f"min={x_sf_f32.min().item():.4e} max={x_sf_f32.max().item():.4e}")
|
||||
|
||||
# Step 2: L1 GEMM (native NVFP4 block-scaled MMA)
|
||||
l1_output = nvfp4_mega_moe_l1(
|
||||
x_fp4, x_sf, l1_w, l1_sf,
|
||||
topk_ids_local, topk_weights, num_experts_per_rank,
|
||||
)
|
||||
|
||||
# NaN-trace: check L1 output
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[L1-out] nan={torch.isnan(l1_output).any().item()} "
|
||||
f"inf={torch.isinf(l1_output).any().item()} "
|
||||
f"abs_max={l1_output.abs().max().item():.4e}")
|
||||
|
||||
# Step 3: SiLU + Mul
|
||||
gate, up = l1_output.chunk(2, dim=-1)
|
||||
activated = torch.nn.functional.silu(gate) * up
|
||||
if activation_clamp is not None:
|
||||
activated = activated.clamp(max=activation_clamp)
|
||||
|
||||
# NaN-trace: check SiLU output
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[silu] nan={torch.isnan(activated).any().item()} "
|
||||
f"abs_max={activated.abs().max().item():.4e}")
|
||||
|
||||
# Step 4: Quantize L1 output → FP4
|
||||
l1_fp4, l1_sf_out = stage_activation(activated)
|
||||
|
||||
@@ -314,5 +332,10 @@ def nvfp4_mega_moe_full(
|
||||
topk_ids_local, topk_weights, num_experts_per_rank,
|
||||
)
|
||||
|
||||
# NaN-trace: check L2 output
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[L2-out] nan={torch.isnan(l2_output).any().item()} "
|
||||
f"abs_max={l2_output.abs().max().item():.4e}")
|
||||
|
||||
# Step 6: Write to output (caller handles cross-rank all-reduce)
|
||||
y.copy_(l2_output)
|
||||
|
||||
Reference in New Issue
Block a user