diff --git a/docker-compose.yml b/docker-compose.yml index a9c2c547..a110a96a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,7 +9,7 @@ services: - OMP_NUM_THREADS=128 - CUDA_LAUNCH_BLOCKING=0 - TORCH_SHOW_CPP_STACKTRACES=0 - - MEGA_MOE_DEBUG=0 + - MEGA_MOE_DEBUG=1 - MEGA_MOE_STATIC=0 - NVFP4_DEBUG=0 - NVFP4_DEBUG_SYNC=0 diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index 5f8315aa..26fbffa6 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -80,13 +80,21 @@ def cutlass_grouped_nvfp4_gemm( ) # (M_expert, N) bfloat16 # Check for CUDA errors after each expert GEMM - err = torch.cuda.current_stream().synchronize() + torch.cuda.current_stream().synchronize() - # Validate output + # Hard-fail on NaN/Inf — silent skip was hiding bugs if torch.isnan(expert_out).any() or torch.isinf(expert_out).any(): - if MEGA_MOE_DEBUG: - print(f"[cutlass_grouped_gemm] WARNING: expert {e} produced NaN/Inf, skipping") - continue + raise RuntimeError( + f"expert {e} of {num_experts}: GEMM emitted NaN/Inf. " + f"M={M_expert} N={N} K={K} | " + f"x abs range [{expert_x.view(torch.int8).abs().max().item()}], " + f"x_sf range [{expert_x_sf.to(torch.float32).min().item():.4e}, " + f"{expert_x_sf.to(torch.float32).max().item():.4e}], " + f"w_sf range [{expert_w_sf.to(torch.float32).min().item():.4e}, " + f"{expert_w_sf.to(torch.float32).max().item():.4e}], " + f"x_sf nan_frac={torch.isnan(expert_x_sf.to(torch.float32)).float().mean().item():.4f}, " + f"w_sf nan_frac={torch.isnan(expert_w_sf.to(torch.float32)).float().mean().item():.4f}" + ) # Scatter back with routing weights for t_idx, token_idx in enumerate(token_indices): diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 4e40fdaf..148db3ed 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -293,18 +293,36 @@ def nvfp4_mega_moe_full( f"local: {topk_ids_local.min().item()}-{topk_ids_local.max().item()} " f"l1_w={l1_w.shape} l2_w={l2_w.shape}") + # NaN-trace: check activation scales at L1 input + if MEGA_MOE_DEBUG: + x_sf_f32 = x_sf.to(torch.float32) + print(f"[L1-in] x_sf nan={torch.isnan(x_sf_f32).any().item()} " + f"inf={torch.isinf(x_sf_f32).any().item()} " + f"min={x_sf_f32.min().item():.4e} max={x_sf_f32.max().item():.4e}") + # Step 2: L1 GEMM (native NVFP4 block-scaled MMA) l1_output = nvfp4_mega_moe_l1( x_fp4, x_sf, l1_w, l1_sf, topk_ids_local, topk_weights, num_experts_per_rank, ) + # NaN-trace: check L1 output + if MEGA_MOE_DEBUG: + print(f"[L1-out] nan={torch.isnan(l1_output).any().item()} " + f"inf={torch.isinf(l1_output).any().item()} " + f"abs_max={l1_output.abs().max().item():.4e}") + # Step 3: SiLU + Mul gate, up = l1_output.chunk(2, dim=-1) activated = torch.nn.functional.silu(gate) * up if activation_clamp is not None: activated = activated.clamp(max=activation_clamp) + # NaN-trace: check SiLU output + if MEGA_MOE_DEBUG: + print(f"[silu] nan={torch.isnan(activated).any().item()} " + f"abs_max={activated.abs().max().item():.4e}") + # Step 4: Quantize L1 output → FP4 l1_fp4, l1_sf_out = stage_activation(activated) @@ -314,5 +332,10 @@ def nvfp4_mega_moe_full( topk_ids_local, topk_weights, num_experts_per_rank, ) + # NaN-trace: check L2 output + if MEGA_MOE_DEBUG: + print(f"[L2-out] nan={torch.isnan(l2_output).any().item()} " + f"abs_max={l2_output.abs().max().item():.4e}") + # Step 6: Write to output (caller handles cross-rank all-reduce) y.copy_(l2_output)