add cutlass support for blackwell fp8 gemm (#13798)

This commit is contained in:
kushanam
2025-03-04 07:55:07 -08:00
committed by GitHub
parent b3cf368d79
commit f89978ad7c
11 changed files with 272 additions and 65 deletions

View File

@@ -29,6 +29,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
#endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@@ -86,7 +91,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
if (cuda_device_capability >= 90) {
if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
return CUDA_VERSION >= 12000;
}
#endif
@@ -120,10 +125,22 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if (version_num >= 90) {
#if defined CUDA_VERSION && CUDA_VERSION < 12080
if (version_num >= 90 && version_num < 100) {
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
}
#else
if (version_num >= 90 && version_num < 100) {
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
} else if (version_num >= 100) {
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X