diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh index c31f96bf7..37846a87b 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh @@ -12,6 +12,68 @@ namespace vllm { using c3x::cutlass_gemm_caller; +// Custom wrapper to allow specifying EpilogueTile for small M +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule, typename EpilogueTile> +struct cutlass_3x_gemm_sm120_custom { + using ElementAB = ElementAB_; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; + + using ElementD = ElementD_; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int AlignmentD = AlignmentC; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + using Epilogue = Epilogue_; + + // MMA type + using ElementAccumulator = float; + + // Epilogue types + using ElementBias = cutlass::half_t; + using ElementCompute = float; + using ElementAux = ElementD; + using LayoutAux = LayoutD; + using ElementAmax = float; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, EpilogueTile, // Use custom EpilogueTile + ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule, void>::CollectiveOp; + + using GemmKernel = enable_sm120_only, CollectiveMainloop, CollectiveEpilogue, void>>; +}; + template typename Epilogue> struct sm120_fp8_config_default { @@ -25,6 +87,54 @@ struct sm120_fp8_config_default { KernelSchedule, EpilogueSchedule>; }; +template typename Epilogue> +struct sm120_fp8_config_M64 { + static_assert(std::is_same()); + // SM120 Cooperative kernel requires Tile M >= 128. + // For M=64 tile, we use Pingpong schedule which is more flexible with small + // tiles. + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_64, _64, _128>; + // CUTLASS 3.x on SM120 currently restricts programmatic multicast (Cluster > + // 1) for certain schedules/types. Reverting to 1x1x1 to ensure compilation. + using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm120; +}; + +template typename Epilogue> +struct sm120_fp8_config_M32 { + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_32, _64, _128>; + using ClusterShape = Shape<_1, _1, _1>; + // Use custom gemm to specify EpilogueTile M=32 + using Cutlass3xGemm = + cutlass_3x_gemm_sm120_custom>; +}; + +template typename Epilogue> +struct sm120_fp8_config_M16 { + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_16, _64, _128>; + using ClusterShape = Shape<_1, _1, _1>; + // Use custom gemm to specify EpilogueTile M=16 + using Cutlass3xGemm = + cutlass_3x_gemm_sm120_custom>; +}; + template typename Epilogue, typename... EpilogueArgs> @@ -36,6 +146,28 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + int M = a.size(0); + + if (M <= 16) { + using Cutlass3xGemmM16 = + typename sm120_fp8_config_M16::Cutlass3xGemm; + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } + if (M <= 32) { + using Cutlass3xGemmM32 = + typename sm120_fp8_config_M32::Cutlass3xGemm; + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } + + if (M <= 256) { + using Cutlass3xGemmM64 = + typename sm120_fp8_config_M64::Cutlass3xGemm; + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } + using Cutlass3xGemmDefault = typename sm120_fp8_config_default::Cutlass3xGemm; @@ -64,4 +196,4 @@ void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out, } } -} // namespace vllm \ No newline at end of file +} // namespace vllm