diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py index e4462735..aad456a3 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/kernel.py @@ -28,7 +28,10 @@ def cutlass_nvfp4_blockscaled_gemm( """Single NVFP4 block-scaled GEMM using CUTLASS.""" if not _CUTLASS_AVAILABLE: raise RuntimeError("CUTLASS NVFP4 GEMM extension not available") - return _C.forward(A_packed, SFA, B_packed, SFB, M, N, K, alpha) + result = _C.forward(A_packed, SFA, B_packed, SFB, M, N, K, alpha) + if result < 0: + raise RuntimeError(f"CUTLASS NVFP4 GEMM failed with code {result} — can_implement may have rejected the problem (M={M} N={N} K={K})") + return result def cutlass_grouped_nvfp4_gemm( @@ -75,6 +78,15 @@ def cutlass_grouped_nvfp4_gemm( M_expert = token_indices.shape[0] + # DEBUG: verify data going into GEMM + if e == 0: + print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} " + f"w shape={expert_w.shape} w_sf shape={expert_w_sf.shape} " + f"w absmax={expert_w.view(torch.int8).abs().max().item()} " + f"w_sf range=[{expert_w_sf.to(torch.float32).min().item():.4e}, " + f"{expert_w_sf.to(torch.float32).max().item():.4e}] " + f"w_sf nonzero_frac={(expert_w_sf.view(torch.uint8) != 0).float().mean().item():.4f}") + # Run CUTLASS NVFP4 block-scaled GEMM expert_out = cutlass_nvfp4_blockscaled_gemm( expert_x, expert_x_sf,