[Build/BugFix] Fix hopper 12.8 build (#14354)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2025-03-08 03:11:56 -05:00
committed by GitHub
parent be0b399d74
commit 7caff01a7b
4 changed files with 98 additions and 75 deletions

View File

@@ -23,12 +23,15 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
@@ -60,7 +63,7 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
@@ -121,26 +124,21 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined CUDA_VERSION && CUDA_VERSION < 12080
if (version_num >= 90 && version_num < 100) {
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
}
#else
if (version_num >= 90 && version_num < 100) {
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
} else if (version_num >= 100) {
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
if (version_num >= 100) {
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#endif
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90 && version_num < 100) {
// Hopper
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
@@ -211,7 +209,7 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;