[feat]: add SM100 support for cutlass FP8 groupGEMM (#20447)

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Duncan Moss
2025-07-22 07:27:12 -07:00
committed by GitHub
parent 4fb56914c5
commit 2c8db17cfd
8 changed files with 255 additions and 32 deletions

View File

@@ -18,7 +18,6 @@ using ProblemShape =
cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
using ElementAccumulator = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using LayoutA = cutlass::layout::RowMajor;
@@ -33,7 +32,7 @@ using LayoutD_Transpose =
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;
template <typename ElementAB_, typename ElementC_,
template <typename ElementAB_, typename ElementC_, typename ArchTag_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule, bool swap_ab_ = false>
@@ -43,6 +42,7 @@ struct cutlass_3x_group_gemm {
using ElementC = void;
using ElementD = ElementC_;
using ElementAccumulator = float;
using ArchTag = ArchTag_;
using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
@@ -77,7 +77,7 @@ struct cutlass_3x_group_gemm {
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
Stages, KernelSchedule>::CollectiveOp>;
using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {};
@@ -156,9 +156,14 @@ void cutlass_group_gemm_caller(
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())};
int device_id = a_tensors.device().index();
static const cutlass::KernelHardwareInfo hw_info{
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
device_id)};
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
epilogue_args};
epilogue_args, hw_info};
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
GemmOp gemm_op;

View File

@@ -0,0 +1,140 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"
using namespace cute;
namespace {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm100;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
// M in [1,64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm100;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_N8192 {
// N in [8192, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm100;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmM64 = typename sm100_fp8_config_M64<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
uint32_t const m = a_tensors.size(0);
uint32_t const n = out_tensors.size(1);
if (m <= 64) {
cutlass_group_gemm_caller<Cutlass3xGemmM64>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}
} // namespace
void dispatch_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}
void cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
}

View File

@@ -21,10 +21,11 @@ struct sm90_fp8_config_default {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
@@ -38,10 +39,12 @@ struct sm90_fp8_config_M4 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, true>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};
template <typename InType, typename OutType,
@@ -55,10 +58,12 @@ struct sm90_fp8_config_M64 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, true>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};
template <typename InType, typename OutType,
@@ -72,10 +77,11 @@ struct sm90_fp8_config_K8192 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
@@ -89,10 +95,11 @@ struct sm90_fp8_config_N8192 {
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
using ArchTag = cutlass::arch::Sm90;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType>
@@ -112,9 +119,6 @@ void run_cutlass_moe_mm_sm90(
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<

View File

@@ -190,4 +190,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
}