[Bugfix][Build/CI] Fix sparse CUTLASS compilation on CUDA [12.0, 12.2) (#11311)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith
2024-12-18 21:43:30 -05:00
committed by GitHub
parent fdea8ec167
commit 5a9da2e6e9
12 changed files with 89 additions and 20 deletions

View File

@@ -5,7 +5,18 @@
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
// sparse CUTLASS kernels need at least
// CUDA 12.2 and SM90 (Hopper)
#if defined CUDA_VERSION
return CUDA_VERSION >= 12020 && cuda_device_capability >= 90;
#endif
return false;
}
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& e,
@@ -43,7 +54,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
bias);