[Kernel] Optimize SM120 CUTLASS blockwise FP8 GEMM (#37970)
Signed-off-by: Necofish <liuxiangyang@mail.ustc.edu.cn> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -110,6 +110,33 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
// Tile configurations for different M ranges
|
||||
template <typename OutType>
|
||||
struct sm120_blockwise_fp8_config_default {
|
||||
// M > 256: use 128x128x128 tile with Cooperative (Auto) schedule
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
// ScaleGranularity must match the actual quantization block size (1, 128, 128)
|
||||
using Gemm = cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, 128, 128, TileShape, ClusterShape,
|
||||
EpilogueSchedule, KernelSchedule>;
|
||||
};
|
||||
|
||||
template <typename OutType>
|
||||
struct sm120_blockwise_fp8_config_M64 {
|
||||
// M in [1, 256]: use 64x128x128 tile with Pingpong schedule
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
// ScaleGranularity stays (1, 128, 128) to match actual quantization data
|
||||
using Gemm = cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, 128, 128, TileShape, ClusterShape,
|
||||
EpilogueSchedule, KernelSchedule>;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
@@ -174,11 +201,15 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
// TODO: better heuristics
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, 128, 128, Shape<_128, _128, _128>,
|
||||
Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto>>(
|
||||
int M = a.size(0);
|
||||
if (M <= 256) {
|
||||
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
|
||||
return cutlass_gemm_caller_blockwise<Gemm>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
// M > 256: use default 128x128x128 config with Cooperative (Auto) schedule
|
||||
using Gemm = typename sm120_blockwise_fp8_config_default<OutType>::Gemm;
|
||||
return cutlass_gemm_caller_blockwise<Gemm>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user