[Kernel] Add swapAB support for SM120 CUTLASS blockwise FP8 GEMM (#38325)
This commit is contained in:
@@ -26,8 +26,10 @@ using namespace cute;
|
||||
template <class OutType, int ScaleGranularityM,
|
||||
int ScaleGranularityN, int ScaleGranularityK,
|
||||
class MmaTileShape, class ClusterShape,
|
||||
class EpilogueScheduler, class MainloopScheduler>
|
||||
class EpilogueScheduler, class MainloopScheduler,
|
||||
bool swap_ab_ = false>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
static constexpr bool swap_ab = swap_ab_;
|
||||
using ElementAB = cutlass::float_e4m3_t;
|
||||
|
||||
using ElementA = ElementAB;
|
||||
@@ -55,9 +57,13 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using ElementCompute = float;
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
|
||||
using ScaleConfig = conditional_t<swap_ab,
|
||||
cutlass::detail::Sm120BlockwiseScaleConfig<
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
|
||||
cute::UMMA::Major::K, cute::UMMA::Major::MN>,
|
||||
cutlass::detail::Sm120BlockwiseScaleConfig<
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
|
||||
|
||||
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
@@ -78,17 +84,32 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
|
||||
AlignmentD,
|
||||
EpilogueScheduler,
|
||||
DefaultOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto;
|
||||
using CollectiveMainloop =
|
||||
using CollectiveMainloop = conditional_t<swap_ab,
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementB,
|
||||
cute::tuple<LayoutB_Transpose, LayoutSFA>,
|
||||
AlignmentB,
|
||||
ElementA,
|
||||
cute::tuple<LayoutA_Transpose, LayoutSFB>,
|
||||
AlignmentA,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduler
|
||||
>::CollectiveOp,
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
@@ -103,7 +124,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduler
|
||||
>::CollectiveOp;
|
||||
>::CollectiveOp>;
|
||||
|
||||
// SM12x family to support both SM120 (RTX 5090) and SM121 (DGX Spark)
|
||||
using KernelType = enable_sm120_family<cutlass::gemm::kernel::GemmUniversal<
|
||||
@@ -115,7 +136,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
// Tile configurations for different M ranges
|
||||
template <typename OutType>
|
||||
struct sm120_blockwise_fp8_config_default {
|
||||
// M > 256: use 128x128x128 tile with Cooperative (Auto) schedule
|
||||
// use 128x128x128 tile with Cooperative (Auto) schedule
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
@@ -127,8 +148,8 @@ struct sm120_blockwise_fp8_config_default {
|
||||
};
|
||||
|
||||
template <typename OutType>
|
||||
struct sm120_blockwise_fp8_config_M64 {
|
||||
// M in [1, 256]: use 64x128x128 tile with Pingpong schedule
|
||||
struct sm120_blockwise_fp8_config_pingpong {
|
||||
// use 64x128x128 tile with Pingpong schedule
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_64, _128, _128>;
|
||||
@@ -139,11 +160,24 @@ struct sm120_blockwise_fp8_config_M64 {
|
||||
EpilogueSchedule, KernelSchedule>;
|
||||
};
|
||||
|
||||
template <typename OutType>
|
||||
struct sm120_blockwise_fp8_config_swapab {
|
||||
// use 128x32x128 tile with Cooperative schedule
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _32, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Gemm = cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 128, 1, 128, TileShape, ClusterShape,
|
||||
EpilogueSchedule, KernelSchedule, true>;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
@@ -167,11 +201,13 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
|
||||
b_stride =
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
c_stride =
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
|
||||
cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
|
||||
|
||||
LayoutSFA layout_SFA =
|
||||
LayoutSFA layout_SFA = swap_ab ?
|
||||
ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
|
||||
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||
LayoutSFB layout_SFB =
|
||||
LayoutSFB layout_SFB = swap_ab ?
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||
@@ -180,15 +216,24 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
|
||||
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{};
|
||||
mainloop_args.layout_SFA = layout_SFA;
|
||||
mainloop_args.layout_SFB = layout_SFB;
|
||||
if (swap_ab) {
|
||||
mainloop_args.ptr_A = b_ptr;
|
||||
mainloop_args.dA = b_stride;
|
||||
mainloop_args.ptr_B = a_ptr;
|
||||
mainloop_args.dB = a_stride;
|
||||
mainloop_args.ptr_SFA = b_scales_ptr;
|
||||
mainloop_args.ptr_SFB = a_scales_ptr;
|
||||
} else {
|
||||
mainloop_args.ptr_A = a_ptr;
|
||||
mainloop_args.dA = a_stride;
|
||||
mainloop_args.ptr_B = b_ptr;
|
||||
mainloop_args.dB = b_stride;
|
||||
mainloop_args.ptr_SFA = a_scales_ptr;
|
||||
mainloop_args.layout_SFA = layout_SFA;
|
||||
mainloop_args.ptr_SFB = b_scales_ptr;
|
||||
mainloop_args.layout_SFB = layout_SFB;
|
||||
auto prob_shape = cute::make_shape(m, n, k, 1);
|
||||
}
|
||||
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
@@ -204,8 +249,12 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
int M = a.size(0);
|
||||
// more heuristic tuning can be done here by checking N/K dimensions as well
|
||||
bool swap_ab = (M <= 64) || (M % 4 != 0);
|
||||
|
||||
if (!swap_ab) {
|
||||
if (M <= 256) {
|
||||
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
|
||||
using Gemm = typename sm120_blockwise_fp8_config_pingpong<OutType>::Gemm;
|
||||
return cutlass_gemm_caller_blockwise<Gemm>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
@@ -213,6 +262,13 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||
using Gemm = typename sm120_blockwise_fp8_config_default<OutType>::Gemm;
|
||||
return cutlass_gemm_caller_blockwise<Gemm>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
// Swap A/B for small M to improve performance
|
||||
// Use TILE_N=32 as the minimum compatible tile size.
|
||||
using Gemm = typename sm120_blockwise_fp8_config_swapab<OutType>::Gemm;
|
||||
return cutlass_gemm_caller_blockwise<Gemm>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
Reference in New Issue
Block a user