[Kernel] Optimize SM120 CUTLASS blockwise FP8 GEMM (#37970)

Signed-off-by: Necofish <liuxiangyang@mail.ustc.edu.cn>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Necofish
2026-03-25 23:20:04 +08:00
committed by GitHub
parent 4a76ad12e0
commit e7221180e1

View File

@@ -110,6 +110,33 @@ struct cutlass_3x_gemm_fp8_blockwise {
struct GemmKernel : public KernelType {};
};
// Tile configurations for different M ranges
// Default SM120 blockwise FP8 configuration, intended for the large-M case
// (M > 256). It pairs a 128x128x128 tile and a 1x1x1 cluster with the
// auto-selected kernel and epilogue schedules; per the original author's note,
// the Auto kernel schedule corresponds to the Cooperative schedule here.
template <typename OutType>
struct sm120_blockwise_fp8_config_default {
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _1, _1>;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;

  // The (1, 128, 128) arguments are the scale granularity. They describe the
  // block size the inputs were actually quantized with, so they must not be
  // changed to follow the tile shape.
  using Gemm =
      cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128, TileShape,
                                    ClusterShape, EpilogueSchedule,
                                    KernelSchedule>;
};
// Small-M SM120 blockwise FP8 configuration, intended for M in [1, 256].
// It halves the tile's M extent (64x128x128) and switches the kernel to the
// SM120 blockwise Pingpong schedule, while the epilogue schedule stays on Auto.
template <typename OutType>
struct sm120_blockwise_fp8_config_M64 {
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_1, _1, _1>;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120;

  // The scale granularity remains (1, 128, 128) even though the tile is
  // smaller: it has to match how the data was actually quantized.
  using Gemm =
      cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128, TileShape,
                                    ClusterShape, EpilogueSchedule,
                                    KernelSchedule>;
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
@@ -174,11 +201,15 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
// TODO: better heuristics
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, 128, 128, Shape<_128, _128, _128>,
Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto,
cutlass::gemm::collective::KernelScheduleAuto>>(
int M = a.size(0);
if (M <= 256) {
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
return cutlass_gemm_caller_blockwise<Gemm>(
out, a, b, a_scales, b_scales);
}
// M > 256: use default 128x128x128 config with Cooperative (Auto) schedule
using Gemm = typename sm120_blockwise_fp8_config_default<OutType>::Gemm;
return cutlass_gemm_caller_blockwise<Gemm>(
out, a, b, a_scales, b_scales);
}