[Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support (#10995)

Co-authored-by: Faraz Shahsavan <faraz.shahsavan@gmail.com>
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
This commit is contained in:
Dipika Sikka
2024-12-18 09:57:16 -05:00
committed by GitHub
parent f04e407e6b
commit 60508ffda9
30 changed files with 2365 additions and 117 deletions

View File

@@ -3,6 +3,8 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
@@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return false;
}
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,