Fix: remove compile-time SM100 guard from PyTorch binding; use a runtime check instead

This commit is contained in:
2026-05-14 10:18:36 +00:00
parent 540e68593f
commit 2ee4e26772

View File

@@ -3,8 +3,6 @@
#include <torch/extension.h>
#include <cuda_runtime.h>
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
extern "C" int cutlass_nvfp4_gemm_run(
const void* A_ptr, const void* SFA_ptr,
const void* B_ptr, const void* SFB_ptr,
@@ -14,8 +12,6 @@ extern "C" int cutlass_nvfp4_gemm_run(
cudaStream_t stream
);
#endif
torch::Tensor cutlass_nvfp4_gemm_forward(
torch::Tensor A_packed, // (M, K_half) int8 packed E2M1
torch::Tensor SFA, // UE4M3 block scales for A
@@ -23,12 +19,15 @@ torch::Tensor cutlass_nvfp4_gemm_forward(
torch::Tensor SFB, // UE4M3 block scales for B
int64_t M, int64_t N, int64_t K // K in FP4 elements (actual K = K * 2 since 2 FP4 per byte)
) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Runtime check for SM100
auto props = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(props->major >= 10, "CUTLASS NVFP4 GEMM requires SM100 (Blackwell) architecture, got SM", props->major, props->minor);
auto D = torch::empty({M, N}, torch::dtype(torch::kBF16).device(A_packed.device()));
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cutlass_nvfp4_gemm_run(
int rc = cutlass_nvfp4_gemm_run(
A_packed.data_ptr(), SFA.data_ptr(),
B_packed.data_ptr(), SFB.data_ptr(),
D.data_ptr(),
@@ -37,10 +36,9 @@ torch::Tensor cutlass_nvfp4_gemm_forward(
stream
);
TORCH_CHECK(rc == 0, "CUTLASS NVFP4 GEMM failed with error code ", rc);
return D;
#else
TORCH_CHECK(false, "CUTLASS NVFP4 GEMM requires SM100 (Blackwell) architecture");
#endif
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {