diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh index f255b27a1..0701cb586 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh @@ -110,6 +110,33 @@ struct cutlass_3x_gemm_fp8_blockwise { struct GemmKernel : public KernelType {}; }; +// Tile configurations for different M ranges +template +struct sm120_blockwise_fp8_config_default { + // M > 256: use 128x128x128 tile with Cooperative (Auto) schedule + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + // ScaleGranularity must match the actual quantization block size (1, 128, 128) + using Gemm = cutlass_3x_gemm_fp8_blockwise< + OutType, 1, 128, 128, TileShape, ClusterShape, + EpilogueSchedule, KernelSchedule>; +}; + +template +struct sm120_blockwise_fp8_config_M64 { + // M in [1, 256]: use 64x128x128 tile with Pingpong schedule + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + // ScaleGranularity stays (1, 128, 128) to match actual quantization data + using Gemm = cutlass_3x_gemm_fp8_blockwise< + OutType, 1, 128, 128, TileShape, ClusterShape, + EpilogueSchedule, KernelSchedule>; +}; + template void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, @@ -174,11 +201,15 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { - // TODO: better heuristics - cutlass_gemm_caller_blockwise, - Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto, - cutlass::gemm::collective::KernelScheduleAuto>>( + int M = a.size(0); + if (M <= 256) { + using Gemm = typename sm120_blockwise_fp8_config_M64::Gemm; + return cutlass_gemm_caller_blockwise( + out, a, b, a_scales, b_scales); + } + // M > 256: use default 128x128x128 config with Cooperative (Auto) schedule + using Gemm = typename sm120_blockwise_fp8_config_default::Gemm; + return cutlass_gemm_caller_blockwise( out, a, b, a_scales, b_scales); }