add cutlass support for blackwell fp8 gemm (#13798)
This commit is contained in:
@@ -29,6 +29,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
// Forward declaration of the SM100 scaled-matmul entry point; the definition
// is not part of this file view.
//
// Parameters (as declared):
//   c        - mutable tensor reference (presumably the output; confirm
//              against the definition).
//   a, b     - const tensor references (presumably the two GEMM operands).
//   a_scales - const tensor reference passed alongside `a`.
//   b_scales - const tensor reference passed alongside `b`.
//   bias     - optional tensor; callers may pass std::nullopt.
//
// The cutlass_scaled_mm dispatcher calls this with
// (c, a, b, a_scales, b_scales, bias) when version_num >= 100, and only on the
// preprocessor path where `CUDA_VERSION < 12080` does not hold.
// NOTE(review): the "sm100" suffix presumably refers to the Blackwell
// architecture -- confirm against the kernel implementation.
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
@@ -86,7 +91,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
// and at least SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 90) {
|
||||
if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
}
|
||||
#endif
|
||||
@@ -120,10 +125,22 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION < 12080
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#else
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
} else if (version_num >= 100) {
|
||||
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
|
||||
Reference in New Issue
Block a user