From 2ee4e267729f8647f3c3cd7cc8fe80c5076bdb6e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 10:18:36 +0000 Subject: [PATCH] Fix: remove compile-time SM100 guard from pytorch binding, use runtime check instead --- .../cutlass_nvfp4_gemm/pytorch_binding.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp index b4e33d33..c9b3f4ba 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp @@ -3,8 +3,6 @@ #include #include -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - extern "C" int cutlass_nvfp4_gemm_run( const void* A_ptr, const void* SFA_ptr, const void* B_ptr, const void* SFB_ptr, @@ -14,8 +12,6 @@ extern "C" int cutlass_nvfp4_gemm_run( cudaStream_t stream ); -#endif - torch::Tensor cutlass_nvfp4_gemm_forward( torch::Tensor A_packed, // (M, K_half) int8 packed E2M1 torch::Tensor SFA, // UE4M3 block scales for A @@ -23,12 +19,15 @@ torch::Tensor cutlass_nvfp4_gemm_forward( torch::Tensor SFB, // UE4M3 block scales for B int64_t M, int64_t N, int64_t K // K in FP4 elements (actual K = K * 2 since 2 FP4 per byte) ) { -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + // Runtime check for SM100 + auto props = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(props->major >= 10, "CUTLASS NVFP4 GEMM requires SM100 (Blackwell) architecture, got SM", props->major, props->minor); + auto D = torch::empty({M, N}, torch::dtype(torch::kBF16).device(A_packed.device())); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - cutlass_nvfp4_gemm_run( + int rc = cutlass_nvfp4_gemm_run( A_packed.data_ptr(), SFA.data_ptr(), B_packed.data_ptr(), SFB.data_ptr(), D.data_ptr(), @@ -37,10 +36,9 @@ torch::Tensor cutlass_nvfp4_gemm_forward( stream ); + TORCH_CHECK(rc == 0, "CUTLASS NVFP4 GEMM failed with error code ", rc); + return D; -#else - TORCH_CHECK(false, "CUTLASS NVFP4 GEMM requires SM100 (Blackwell) architecture"); -#endif } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {