diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 093af80a..63ab1c21 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -128,8 +128,7 @@ def nvfp4_mega_moe_l1( topk_ids, topk_weights, alpha=alpha, ) - if MEGA_MOE_DEBUG: - print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}") + print(f"[L1-GEMM-OUT] amax={output.abs().max().item():.4e} mean={output.float().mean().item():.4e} nonzero_frac={(output != 0).float().mean().item():.4f}") return output # (num_tokens, 6144) bfloat16 @@ -328,6 +327,12 @@ def nvfp4_mega_moe_full( experts_start_idx = symm_buffer.experts_start_idx topk_ids_local = topk_ids - experts_start_idx + # Routing diagnostic (ungated — needed to diagnose zero-GEMM on specific ranks) + local_topk = (topk_ids >= experts_start_idx) & (topk_ids < experts_start_idx + num_experts_per_rank) + tokens_routed_locally = local_topk.any(dim=-1).sum().item() + print(f"[ROUTING] tokens_routed_local={tokens_routed_locally}/{topk_ids.shape[0]} " + f"unique_local_experts={local_topk.long().sum().item()}") + if MEGA_MOE_DEBUG: print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} " f"topk_ids={topk_ids.shape} topk_ids range: {topk_ids.min().item()}-{topk_ids.max().item()} "