[Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support (#10995)
Co-authored-by: Faraz Shahsavan <faraz.shahsavan@gmail.com> Co-authored-by: ilmarkov <markovilya197@gmail.com> Co-authored-by: Rahul Tuli <rahul@neuralmagic.com> Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
This commit is contained in:
163
csrc/sparse/cutlass/sparse_compressor_c3x.cu
Normal file
163
csrc/sparse/cutlass/sparse_compressor_c3x.cu
Normal file
@@ -0,0 +1,163 @@
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include "sparse_scaled_mm_c3x.cuh"
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/transform/device/transform_universal_adapter.hpp"
|
||||
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
using namespace vllm;
|
||||
|
||||
/// Make A structured sparse by replacing elements with 0 and compress it
|
||||
template <typename ElementA_, typename ElementAcc_>
|
||||
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
torch::Tensor const& a) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
|
||||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(a.dim() == 2)
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
|
||||
TORCH_CHECK(a.stride(1) == 1)
|
||||
|
||||
int m = a.size(0);
|
||||
int k = a.size(1);
|
||||
|
||||
// Sparse kernel setup; this kernel is not used for matmul,
|
||||
// but just for setting up the compressor utility
|
||||
// A matrix configuration
|
||||
using ElementA = ElementA_;
|
||||
using LayoutTagA = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
// B matrix configuration
|
||||
using ElementB = ElementA;
|
||||
using LayoutTagB = cutlass::layout::ColumnMajor;
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
// C/D matrix configuration
|
||||
using ElementC = float;
|
||||
using LayoutTagC = cutlass::layout::ColumnMajor;
|
||||
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = ElementAcc_;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using TileShapeRef = Shape<_128, _128, _64>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule = typename std::conditional<
|
||||
std::is_same_v<ElementA, cutlass::float_e4m3_t>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
|
||||
cutlass::gemm::KernelTmaWarpSpecialized>::type;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
|
||||
using ProblemShape = Shape<int, int, int, int>;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
|
||||
AlignmentC, ElementC, LayoutTagC, AlignmentC,
|
||||
EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
|
||||
LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
|
||||
using StrideE = StrideA;
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
|
||||
// The n (=1) dimension does not matter for the compressor
|
||||
typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
|
||||
|
||||
using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;
|
||||
|
||||
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
|
||||
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
|
||||
// Offline compressor kernel
|
||||
using CompressorUtility =
|
||||
cutlass::transform::kernel::StructuredSparseCompressorUtility<
|
||||
ProblemShape, ElementA, LayoutTagA, SparseConfig>;
|
||||
|
||||
using CompressorKernel =
|
||||
cutlass::transform::kernel::StructuredSparseCompressor<
|
||||
ProblemShape, ElementA, LayoutTagA, SparseConfig,
|
||||
cutlass::arch::Sm90>;
|
||||
|
||||
using Compressor =
|
||||
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
|
||||
|
||||
auto [M, N, K, L] = prob_shape;
|
||||
|
||||
StrideA stride_A;
|
||||
stride_A =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
|
||||
CompressorUtility compressor_utility(prob_shape, stride_A);
|
||||
|
||||
int ME = compressor_utility.get_metadata_m_physical();
|
||||
int KE = compressor_utility.get_metadata_k_physical();
|
||||
int KC = compressor_utility.get_tensorA_k_physical();
|
||||
|
||||
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
|
||||
|
||||
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
|
||||
auto a_meta_ptr = static_cast<typename Gemm::CollectiveMainloop::ElementE*>(
|
||||
a_meta.data_ptr());
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = 0;
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
typename Compressor::Arguments arguments{
|
||||
prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};
|
||||
|
||||
Compressor compressor_op;
|
||||
size_t workspace_size = Compressor::get_workspace_size(arguments);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
CUTLASS_CHECK(compressor_op.can_implement(arguments));
|
||||
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
|
||||
CUTLASS_CHECK(compressor_op.run());
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
|
||||
torch::Tensor const& a) {
|
||||
if (a.dtype() == torch::kBFloat16) {
|
||||
return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
|
||||
a);
|
||||
} else if (a.dtype() == torch::kFloat16) {
|
||||
return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
|
||||
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
|
||||
a);
|
||||
} else if (a.dtype() == torch::kInt8) {
|
||||
return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
Reference in New Issue
Block a user