diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh index 14de9b1e4..a9008ce44 100644 --- a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh @@ -26,8 +26,10 @@ using namespace cute; template + class EpilogueScheduler, class MainloopScheduler, + bool swap_ab_ = false> struct cutlass_3x_gemm_fp8_blockwise { + static constexpr bool swap_ab = swap_ab_; using ElementAB = cutlass::float_e4m3_t; using ElementA = ElementAB; @@ -55,9 +57,13 @@ struct cutlass_3x_gemm_fp8_blockwise { using ElementCompute = float; using ElementBlockScale = float; - using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig< + using ScaleConfig = conditional_t; + cute::UMMA::Major::K, cute::UMMA::Major::MN>, + cutlass::detail::Sm120BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, + cute::UMMA::Major::MN, cute::UMMA::Major::K>>; // layout_SFA and layout_SFB cannot be swapped since they are deduced. using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); @@ -78,17 +84,32 @@ struct cutlass_3x_gemm_fp8_blockwise { ElementAccumulator, ElementCompute, ElementC, - LayoutC, + conditional_t, AlignmentC, ElementD, - LayoutD, + conditional_t, AlignmentD, EpilogueScheduler, DefaultOperation >::CollectiveOp; using StageCountType = cutlass::gemm::collective::StageCountAuto; - using CollectiveMainloop = + using CollectiveMainloop = conditional_t, + AlignmentB, + ElementA, + cute::tuple, + AlignmentA, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp, typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, @@ -103,7 +124,7 @@ struct cutlass_3x_gemm_fp8_blockwise { ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, MainloopScheduler - >::CollectiveOp; + >::CollectiveOp>; // SM12x family to support both SM120 (RTX 5090) and SM121 (DGX Spark) using KernelType = enable_sm120_family struct sm120_blockwise_fp8_config_default { - // M > 256: use 128x128x128 tile with Cooperative (Auto) schedule + // use 128x128x128 tile with Cooperative (Auto) schedule using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using TileShape = Shape<_128, _128, _128>; @@ -127,8 +148,8 @@ struct sm120_blockwise_fp8_config_default { }; template -struct sm120_blockwise_fp8_config_M64 { - // M in [1, 256]: use 64x128x128 tile with Pingpong schedule +struct sm120_blockwise_fp8_config_pingpong { + // use 64x128x128 tile with Pingpong schedule using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; using TileShape = Shape<_64, _128, _128>; @@ -139,11 +160,24 @@ struct sm120_blockwise_fp8_config_M64 { EpilogueSchedule, KernelSchedule>; }; +template +struct sm120_blockwise_fp8_config_swapab { + // use 128x32x128 tile with Cooperative schedule + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _32, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using Gemm = cutlass_3x_gemm_fp8_blockwise< + OutType, 128, 1, 128, TileShape, ClusterShape, + EpilogueSchedule, KernelSchedule, true>; +}; + template void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, torch::stable::Tensor const& b_scales) { + static constexpr bool swap_ab = Gemm::swap_ab; using GemmKernel = typename Gemm::GemmKernel; using StrideA = typename Gemm::GemmKernel::StrideA; using StrideB = typename Gemm::GemmKernel::StrideB; @@ -167,11 +201,13 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); c_stride = - cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); + cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1)); - LayoutSFA layout_SFA = + LayoutSFA layout_SFA = swap_ab ? + ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) : ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); - LayoutSFB layout_SFB = + LayoutSFB layout_SFB = swap_ab ? + ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) : ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); auto a_ptr = static_cast(a.data_ptr()); @@ -180,15 +216,24 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te auto b_scales_ptr = static_cast(b_scales.data_ptr()); typename GemmKernel::MainloopArguments mainloop_args{}; - mainloop_args.ptr_A = a_ptr; - mainloop_args.dA = a_stride; - mainloop_args.ptr_B = b_ptr; - mainloop_args.dB = b_stride; - mainloop_args.ptr_SFA = a_scales_ptr; mainloop_args.layout_SFA = layout_SFA; - mainloop_args.ptr_SFB = b_scales_ptr; mainloop_args.layout_SFB = layout_SFB; - auto prob_shape = cute::make_shape(m, n, k, 1); + if (swap_ab) { + mainloop_args.ptr_A = b_ptr; + mainloop_args.dA = b_stride; + mainloop_args.ptr_B = a_ptr; + mainloop_args.dB = a_stride; + mainloop_args.ptr_SFA = b_scales_ptr; + mainloop_args.ptr_SFB = a_scales_ptr; + } else { + mainloop_args.ptr_A = a_ptr; + mainloop_args.dA = a_stride; + mainloop_args.ptr_B = b_ptr; + mainloop_args.dB = b_stride; + mainloop_args.ptr_SFA = a_scales_ptr; + mainloop_args.ptr_SFB = b_scales_ptr; + } + auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1); auto c_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ @@ -204,15 +249,26 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out, torch::stable::Tensor const& a_scales, torch::stable::Tensor const& b_scales) { int M = a.size(0); - if (M <= 256) { - using Gemm = typename sm120_blockwise_fp8_config_M64::Gemm; + // more heuristic tuning can be done here by checking N/K dimensions as well + bool swap_ab = (M <= 64) || (M % 4 != 0); + + if (!swap_ab) { + if (M <= 256) { + using Gemm = typename sm120_blockwise_fp8_config_pingpong::Gemm; + return cutlass_gemm_caller_blockwise( + out, a, b, a_scales, b_scales); + } + // M > 256: use default 128x128x128 config with Cooperative (Auto) schedule + using Gemm = typename sm120_blockwise_fp8_config_default::Gemm; + return cutlass_gemm_caller_blockwise( + out, a, b, a_scales, b_scales); + } else { + // Swap A/B for small M to improve performance + // Use TILE_N=32 as the minimum compatible tile size. + using Gemm = typename sm120_blockwise_fp8_config_swapab::Gemm; return cutlass_gemm_caller_blockwise( out, a, b, a_scales, b_scales); } - // M > 256: use default 128x128x128 config with Cooperative (Auto) schedule - using Gemm = typename sm120_blockwise_fp8_config_default::Gemm; - return cutlass_gemm_caller_blockwise( - out, a, b, a_scales, b_scales); } } // namespace vllm