[Kernel] Add swapAB support for SM120 CUTLASS blockwise FP8 GEMM (#38325)

This commit is contained in:
Necofish
2026-04-03 21:49:59 +08:00
committed by GitHub
parent cb10b7e80b
commit 580090db6b

View File

@@ -26,8 +26,10 @@ using namespace cute;
template <class OutType, int ScaleGranularityM,
int ScaleGranularityN, int ScaleGranularityK,
class MmaTileShape, class ClusterShape,
class EpilogueScheduler, class MainloopScheduler>
class EpilogueScheduler, class MainloopScheduler,
bool swap_ab_ = false>
struct cutlass_3x_gemm_fp8_blockwise {
static constexpr bool swap_ab = swap_ab_;
using ElementAB = cutlass::float_e4m3_t;
using ElementA = ElementAB;
@@ -55,9 +57,13 @@ struct cutlass_3x_gemm_fp8_blockwise {
using ElementCompute = float;
using ElementBlockScale = float;
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
using ScaleConfig = conditional_t<swap_ab,
cutlass::detail::Sm120BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
cute::UMMA::Major::K, cute::UMMA::Major::MN>,
cutlass::detail::Sm120BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
@@ -78,17 +84,32 @@ struct cutlass_3x_gemm_fp8_blockwise {
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
AlignmentC,
ElementD,
LayoutD,
conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
AlignmentD,
EpilogueScheduler,
DefaultOperation
>::CollectiveOp;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using CollectiveMainloop =
using CollectiveMainloop = conditional_t<swap_ab,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementB,
cute::tuple<LayoutB_Transpose, LayoutSFA>,
AlignmentB,
ElementA,
cute::tuple<LayoutA_Transpose, LayoutSFB>,
AlignmentA,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
@@ -103,7 +124,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp;
>::CollectiveOp>;
// SM12x family to support both SM120 (RTX 5090) and SM121 (DGX Spark)
using KernelType = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
@@ -115,7 +136,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
// Tile configurations for different M ranges
template <typename OutType>
struct sm120_blockwise_fp8_config_default {
// M > 256: use 128x128x128 tile with Cooperative (Auto) schedule
// use 128x128x128 tile with Cooperative (Auto) schedule
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _128>;
@@ -127,8 +148,8 @@ struct sm120_blockwise_fp8_config_default {
};
template <typename OutType>
struct sm120_blockwise_fp8_config_M64 {
// M in [1, 256]: use 64x128x128 tile with Pingpong schedule
struct sm120_blockwise_fp8_config_pingpong {
// use 64x128x128 tile with Pingpong schedule
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _128, _128>;
@@ -139,11 +160,24 @@ struct sm120_blockwise_fp8_config_M64 {
EpilogueSchedule, KernelSchedule>;
};
template <typename OutType>
struct sm120_blockwise_fp8_config_swapab {
// use 128x32x128 tile with Cooperative schedule
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _32, _128>;
using ClusterShape = Shape<_1, _1, _1>;
using Gemm = cutlass_3x_gemm_fp8_blockwise<
OutType, 128, 1, 128, TileShape, ClusterShape,
EpilogueSchedule, KernelSchedule, true>;
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
static constexpr bool swap_ab = Gemm::swap_ab;
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -167,11 +201,13 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
LayoutSFA layout_SFA =
LayoutSFA layout_SFA = swap_ab ?
ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB =
LayoutSFB layout_SFB = swap_ab ?
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
@@ -180,15 +216,24 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{};
mainloop_args.ptr_A = a_ptr;
mainloop_args.dA = a_stride;
mainloop_args.ptr_B = b_ptr;
mainloop_args.dB = b_stride;
mainloop_args.ptr_SFA = a_scales_ptr;
mainloop_args.layout_SFA = layout_SFA;
mainloop_args.ptr_SFB = b_scales_ptr;
mainloop_args.layout_SFB = layout_SFB;
auto prob_shape = cute::make_shape(m, n, k, 1);
if (swap_ab) {
mainloop_args.ptr_A = b_ptr;
mainloop_args.dA = b_stride;
mainloop_args.ptr_B = a_ptr;
mainloop_args.dB = a_stride;
mainloop_args.ptr_SFA = b_scales_ptr;
mainloop_args.ptr_SFB = a_scales_ptr;
} else {
mainloop_args.ptr_A = a_ptr;
mainloop_args.dA = a_stride;
mainloop_args.ptr_B = b_ptr;
mainloop_args.dB = b_stride;
mainloop_args.ptr_SFA = a_scales_ptr;
mainloop_args.ptr_SFB = b_scales_ptr;
}
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
@@ -204,15 +249,26 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
int M = a.size(0);
if (M <= 256) {
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
// more heuristic tuning can be done here by checking N/K dimensions as well
bool swap_ab = (M <= 64) || (M % 4 != 0);
if (!swap_ab) {
if (M <= 256) {
using Gemm = typename sm120_blockwise_fp8_config_pingpong<OutType>::Gemm;
return cutlass_gemm_caller_blockwise<Gemm>(
out, a, b, a_scales, b_scales);
}
// M > 256: use default 128x128x128 config with Cooperative (Auto) schedule
using Gemm = typename sm120_blockwise_fp8_config_default<OutType>::Gemm;
return cutlass_gemm_caller_blockwise<Gemm>(
out, a, b, a_scales, b_scales);
} else {
// Swap A/B for small M to improve performance
// Use TILE_N=32 as the minimum compatible tile size.
using Gemm = typename sm120_blockwise_fp8_config_swapab<OutType>::Gemm;
return cutlass_gemm_caller_blockwise<Gemm>(
out, a, b, a_scales, b_scales);
}
// M > 256: use default 128x128x128 config with Cooperative (Auto) schedule
using Gemm = typename sm120_blockwise_fp8_config_default<OutType>::Gemm;
return cutlass_gemm_caller_blockwise<Gemm>(
out, a, b, a_scales, b_scales);
}
} // namespace vllm