#include <cudaTypedefs.h>

#include <c10/util/Exception.h>

#include <torch/all.h>

#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

using namespace cute;

// NOTE(review): this file was recovered from a copy in which every
// angle-bracket span (template parameter/argument lists and <...> include
// targets) had been stripped. The template structure below is reconstructed
// from the surviving identifiers; the concrete cute::Shape tile/cluster
// values could not be recovered and are marked TODO — confirm them against
// the upstream source before relying on performance characteristics.

namespace {

// Default FP8 grouped-GEMM configuration for SM100 (1-SM ptr-array TMA
// warp-specialized schedule). Used for shapes not matched by the
// specialized configs below.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
  // These configs are FP8-input only.
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
  // TODO(review): tile/cluster shapes lost in extraction — values below are
  // plausible reconstructions; confirm against upstream.
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_64>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            ArchTag, KernelSchedule, EpilogueSchedule>;
};

// Small-M configuration, selected when M in [1, 64]: a narrow M tile wastes
// less work on mostly-empty tiles.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
  // M in [1,64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
  // TODO(review): tile/cluster shapes lost in extraction — confirm upstream.
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_64>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            ArchTag, KernelSchedule, EpilogueSchedule>;
};

// Large-N configuration, selected when N >= 8192: uses the 2-SM
// (cluster-cooperative) kernel/epilogue schedules.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_N8192 {
  // N in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
  // TODO(review): tile/cluster shapes lost in extraction — confirm upstream.
  // The 2-SM schedule requires a cluster spanning 2 SMs in M.
  using TileShape = cute::Shape<cute::_256, cute::_128, cute::_64>;
  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            ArchTag, KernelSchedule, EpilogueSchedule>;
};

// Validates inputs, picks a kernel configuration from the problem shape
// (M <= 64 -> M64 config; N >= 8192 -> N8192 config; otherwise default) and
// launches the grouped GEMM via cutlass_group_gemm_caller.
//
// Preconditions (checked): non-empty A/B/out tensors; A and B are
// float8_e4m3fn. Scales may be per-activation-token / per-output-channel
// as indicated by the two flags.
template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm100(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
  TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
  TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
              "A tensors must be of type float8_e4m3fn.");
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
              "B tensors must be of type float8_e4m3fn.");

  using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM64 = typename sm100_fp8_config_M64<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;

  // Total token count (rows of A) and output width drive config selection.
  uint32_t const m = a_tensors.size(0);
  uint32_t const n = out_tensors.size(1);

  if (m <= 64) {
    cutlass_group_gemm_caller<Cutlass3xGemmM64>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else if (n >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}

}  // namespace

// Dispatches on the requested output dtype: bfloat16 output uses the
// bf16 instantiation, anything else falls through to fp16.
// NOTE(review): the non-bf16 branch does not validate that the dtype is
// actually half — upstream callers are expected to guarantee this.
void dispatch_moe_mm_sm100(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  if (out_tensors.dtype() == torch::kBFloat16) {
    run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::half_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}

// Public entry point for the SM100 FP8 grouped (MoE) GEMM; forwards all
// arguments to the dtype dispatcher above.
void cutlass_moe_mm_sm100(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
                        c_strides, per_act_token, per_out_ch);
}