[Kernel] Tuned int8 kernels for Ada Lovelace (#6848)

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath
2024-07-29 22:24:58 -04:00
committed by GitHub
parent 61a97c32f6
commit af647fb8b3
4 changed files with 395 additions and 43 deletions

View File

@@ -4,7 +4,8 @@
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
@@ -98,25 +99,17 @@ template <template <typename, typename> typename Epilogue,
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
int8_t, cutlass::bfloat16_t, Epilogue,
TileShape, WarpShape, InstructionShape, 5>>(
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
int8_t, cutlass::half_t, Epilogue, TileShape,
WarpShape, InstructionShape, 5>>(
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
@@ -124,13 +117,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
return vllm::cutlass_gemm_sm89_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}