[Bugfix][Build/CI] Fix sparse CUTLASS compilation on CUDA [12.0, 12.2) (#11311)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
committed by
GitHub
parent
fdea8ec167
commit
5a9da2e6e9
@@ -5,7 +5,18 @@
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
|
||||
// sparse CUTLASS kernels need at least
|
||||
// CUDA 12.2 and SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
return CUDA_VERSION >= 12020 && cuda_device_capability >= 90;
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& e,
|
||||
@@ -43,7 +54,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
|
||||
bias);
|
||||
|
||||
Reference in New Issue
Block a user