add cutlass support for blackwell fp8 gemm (#13798)

This commit is contained in:
kushanam
2025-03-04 07:55:07 -08:00
committed by GitHub
parent b3cf368d79
commit f89978ad7c
11 changed files with 272 additions and 65 deletions

View File

@@ -16,6 +16,7 @@
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
@@ -64,22 +65,28 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementC = typename Gemm::ElementC;
using ElementD = typename Gemm::ElementD;
using GemmKernel = typename Gemm::GemmKernel;
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
using StrideB = cute::Stride<int64_t, cute::Int<1>, int64_t>;
using StrideC = typename Gemm::StrideC;
StrideA a_stride{lda, cute::Int<1>{}, 0};
StrideB b_stride{ldb, cute::Int<1>{}, 0};
StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = StrideC;
using StrideAux = StrideC;
typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
auto [M, N, K, L] = prob_shape;
StrideA a_stride =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
StrideB b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
StrideC c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
StrideD d_stride =
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
StrideAux aux_stride = d_stride;
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
@@ -87,10 +94,11 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
b_stride};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
// auto d_ptr = static_cast<ElementC*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args(
std::forward<EpilogueArgs>(epilogue_params)...),
c_ptr, c_stride, c_ptr, c_stride};
c_ptr, c_stride, c_ptr, d_stride};
cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args);