#include #include #include "cutlass/cutlass.h" #include "cute/tensor.hpp" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/mixed_dtype_utils.hpp" // vllm includes #include #include #include #include "cutlass_extensions/torch_utils.hpp" #include "cutlass_extensions/common.hpp" #include "core/registration.h" #include "get_group_starts.cuh" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "w4a8_utils.cuh" namespace vllm::cutlass_w4a8_moe { using namespace cute; // ------------------------------------------------------------------------------------- // Static configuration shared across all instantiations // ------------------------------------------------------------------------------------- using ProblemShape = cutlass::gemm::GroupProblemShape>; // per // group using MmaType = cutlass::float_e4m3_t; using QuantType = cutlass::int4b_t; constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; static int constexpr PackFactor = 8; // 8 int4 packed into int32 // A matrix configuration using ElementA = MmaType; using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of // elements (up to 16 bytes) // B matrix configuration using ElementB = QuantType; // Element type for B matrix operand using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand constexpr int AlignmentB = 128 / cutlass::sizeof_bits< ElementB>::value; // Memory access granularity/alignment of B // matrix in units of elements (up to 16 bytes) // This example manually swaps and transposes, so keep transpose of input // layouts using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; // Need to pass a pointer type to make the 3rd dimension of Stride be _0 using StrideA = cute::remove_pointer_t>; using StrideB = cute::remove_pointer_t>; // Define the CuTe layout for reoredered quantized tensor B // LayoutAtomQuant places values that will be read by the same thread in // contiguous locations in global memory. It specifies the reordering within a // single warp's fragment using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom()); using LayoutB_Reordered = decltype(cute::tile_to_shape( LayoutAtomQuant{}, Layout>, StrideB>{})); using ElementScale = cutlass::float_e4m3_t; using LayoutScale = cutlass::layout::RowMajor; // C/D matrix configuration using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands constexpr int AlignmentC = 128 / cutlass::sizeof_bits< ElementC>::value; // Memory access granularity/alignment of C // matrix in units of elements (up to 16 bytes) // D matrix configuration using ElementD = ElementC; using LayoutD = LayoutC; constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Core kernel configurations using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that // supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based // on the tile size // per-channel and per-token scales for epilogue using ElementSChannel = float; template struct W4A8GroupedGemmKernel { using TileShape = decltype(cute::append(TileShape_MN{}, cute::Int{})); using ClusterShape = ClusterShape_MNK; // per-channel, per-token scales epilogue using ChTokScalesEpilogue = typename vllm::c3x::ScaledEpilogueArray; using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementSChannel, ElementC, typename cutlass::layout::LayoutTranspose::type*, AlignmentC, ElementD, typename cutlass::layout::LayoutTranspose::type*, AlignmentD, EpilogueSchedule, EVTCompute>::CollectiveOp; // =========================================================== MIXED INPUT // WITH SCALES // =========================================================================== // The Scale information must get paired with the operand that will be scaled. // In this example, B is scaled so we make a tuple of B's information and the // scale information. using CollectiveMainloopShuffled = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, cute::tuple>, LayoutB_Reordered*, AlignmentB, ElementA, LayoutA_Transpose*, AlignmentA, ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< ProblemShape, CollectiveMainloopShuffled, CollectiveEpilogue>; using GemmShuffled = cutlass::gemm::device::GemmUniversalAdapter; using StrideC = typename GemmKernelShuffled::InternalStrideC; using StrideD = typename GemmKernelShuffled::InternalStrideD; using StrideC_ref = cutlass::detail::TagToStrideC_t; using StrideD_ref = cutlass::detail::TagToStrideC_t; using StrideS = typename CollectiveMainloopShuffled::StrideScale; using StrideS_ref = cutlass::detail::TagToStrideB_t; // static asserts for passing in strides/layouts // pack to 2x int64 static_assert(sizeof(StrideS) == 2 * sizeof(int64_t)); // pack to 3xint32, static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0, "LayoutB_Reordered size must be divisible by 4 bytes"); static void grouped_mm( torch::Tensor& out_tensors, const torch::Tensor& a_tensors, const torch::Tensor& b_tensors, const torch::Tensor& a_scales, const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, const int64_t b_group_size, const torch::Tensor& expert_offsets, const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides, const torch::Tensor& b_strides, const torch::Tensor& c_strides, const torch::Tensor& group_scale_strides) { auto device = a_tensors.device(); auto device_id = device.index(); const at::cuda::OptionalCUDAGuard device_guard(device); auto stream = at::cuda::getCurrentCUDAStream(device_id); int num_experts = static_cast(expert_offsets.size(0)); int n = static_cast(b_tensors.size(1)); int k = static_cast(b_tensors.size(2)) * PackFactor; auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(device); torch::Tensor a_ptrs = torch::empty(num_experts, options_int); torch::Tensor b_ptrs = torch::empty(num_experts, options_int); torch::Tensor out_ptrs = torch::empty(num_experts, options_int); torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int); // get the correct offsets to pass to gemm run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, b_group_scales_ptrs, a_tensors, b_tensors, out_tensors, a_scales, b_scales, b_group_scales, b_group_size); // construct args using Args = typename GemmShuffled::Arguments; using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; Args arguments; ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes = static_cast( problem_sizes_torch.data_ptr()); ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; // SwapAB so B operands come first MainloopArguments mainloop_arguments{ static_cast(b_ptrs.data_ptr()), static_cast(b_strides.data_ptr()), static_cast(a_ptrs.data_ptr()), static_cast(a_strides.data_ptr()), static_cast**>( b_group_scales_ptrs.data_ptr()), static_cast(group_scale_strides.data_ptr()), static_cast(b_group_size)}; EpilogueArguments epilogue_arguments{ // since we are doing SwapAB the channel scales comes first, then token // scales ChTokScalesEpilogue::prepare_args( // see ScaledEpilogueArray static_cast( b_scales_ptrs.data_ptr()), // per-channel static_cast( a_scales_ptrs.data_ptr()), // per-token true, true), nullptr, // C static_cast(c_strides.data_ptr()), // C static_cast(out_ptrs.data_ptr()), // D static_cast(c_strides.data_ptr()) // D }; static const cutlass::KernelHardwareInfo hw_info{ device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count( device_id)}; arguments = Args{cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_arguments, epilogue_arguments, hw_info}; // Allocate workspace size_t workspace_size = GemmShuffled::get_workspace_size(arguments); torch::Tensor workspace = torch::empty(workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device)); // Run GEMM GemmShuffled gemm; CUTLASS_CHECK(gemm.can_implement(arguments)); CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); CUTLASS_CHECK(gemm.run(stream)); } }; // ---------------------------------------------------------------------------- // Kernel instantiations and dispatch logic // ---------------------------------------------------------------------------- using Coop = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; using CoopEpi = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Kernel_TileShape_ClusterShape_Schedule using Kernel_128x16_1x1x1_Coop = W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; using Kernel_128x16_2x1x1_Coop = W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; using Kernel_256x16_1x1x1_Coop = W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; using Kernel_256x16_2x1x1_Coop = W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; using Kernel_256x32_1x1x1_Coop = W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; using Kernel_256x32_2x1x1_Coop = W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; using Kernel_256x64_1x1x1_Coop = W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; using Kernel_256x64_2x1x1_Coop = W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; using Kernel_256x128_1x1x1_Coop = W4A8GroupedGemmKernel, Shape<_1, _1, _1>, Coop, CoopEpi>; using Kernel_256x128_2x1x1_Coop = W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; using Kernel_128x256_2x1x1_Coop = W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; void mm_dispatch( torch::Tensor& out_tensors, const torch::Tensor& a_tensors, const torch::Tensor& b_tensors, const torch::Tensor& a_scales, const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, const int64_t b_group_size, const torch::Tensor& expert_offsets, const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, const torch::Tensor& b_strides, const torch::Tensor& c_strides, const torch::Tensor& group_scale_strides, const std::string& schedule) { if (schedule == "Kernel_128x16_1x1x1_Coop") { Kernel_128x16_1x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else if (schedule == "Kernel_128x16_2x1x1_Coop") { Kernel_128x16_2x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else if (schedule == "Kernel_256x16_1x1x1_Coop") { Kernel_256x16_1x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else if (schedule == "Kernel_256x16_2x1x1_Coop") { Kernel_256x16_2x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else if (schedule == "Kernel_256x32_1x1x1_Coop") { Kernel_256x32_1x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else if (schedule == "Kernel_256x32_2x1x1_Coop") { Kernel_256x32_2x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else if (schedule == "Kernel_256x64_1x1x1_Coop") { Kernel_256x64_1x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else if (schedule == "Kernel_256x64_2x1x1_Coop") { Kernel_256x64_2x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else if (schedule == "Kernel_256x128_1x1x1_Coop") { Kernel_256x128_1x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else if (schedule == "Kernel_256x128_2x1x1_Coop") { Kernel_256x128_2x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else if (schedule == "Kernel_128x256_2x1x1_Coop") { Kernel_128x256_2x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else { TORCH_CHECK(false, "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule); } } void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors, const torch::Tensor& b_tensors, const torch::Tensor& a_scales, const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, const int64_t b_group_size, const torch::Tensor& expert_offsets, const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, const torch::Tensor& b_strides, const torch::Tensor& c_strides, const torch::Tensor& group_scale_strides, std::optional maybe_schedule) { // user has specified a schedule if (maybe_schedule) { mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides, *maybe_schedule); return; } // use heuristic int m_full = a_tensors.size(0); int n = b_tensors.size(1); int k = b_tensors.size(2) * PackFactor; // logical k int num_experts = b_tensors.size(0); // per-expert batch size assuming uniform distribution int m_expert = m_full / num_experts; std::string schedule; if (m_expert <= 16) { schedule = "Kernel_128x16_2x1x1_Coop"; } else if (m_expert <= 32) { schedule = "Kernel_256x32_1x1x1_Coop"; } else if (m_expert <= 64) { schedule = "Kernel_256x64_1x1x1_Coop"; } else if (m_expert <= 128) { schedule = "Kernel_256x128_2x1x1_Coop"; } else { // m_expert > 128 schedule = "Kernel_128x256_2x1x1_Coop"; } mm_dispatch(out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides, schedule); } std::tuple encode_and_reorder_int4b( torch::Tensor const& b_tensors) { TORCH_CHECK(b_tensors.dtype() == torch::kInt32); TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k) TORCH_CHECK(b_tensors.is_contiguous()); TORCH_CHECK(b_tensors.is_cuda()); int n = static_cast(b_tensors.size(1)); int k = static_cast(b_tensors.size(2)) * PackFactor; // logical k // CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0. // These misalignments cause silent OOB unless run under Compute Sanitizer. TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256"); TORCH_CHECK(n % 16 == 0, "n must be divisible by 16"); // we will store the layout to an int32 tensor; // this is the number of elements we need per layout constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t); torch::Tensor b_tensors_packed = torch::empty_like(b_tensors); int num_experts = static_cast(b_tensors.size(0)); auto b_ptr = static_cast(b_tensors.const_data_ptr()); auto b_packed_ptr = static_cast(b_tensors_packed.data_ptr()); // multiply by ull so result does not overflow int32 size_t num_int4_elems = 1ull * num_experts * n * k; bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr, num_int4_elems); TORCH_CHECK(ok, "unified_encode_int4b failed"); // construct the layout once; assumes each expert has the same layout using LayoutType = LayoutB_Reordered; std::vector layout_B_reordered_host(num_experts); auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, Int<1>{}}); auto shape_B = cute::make_shape(n, k, Int<1>{}); auto layout_B = make_layout(shape_B, stride_B); LayoutType layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B); // reorder weights for each expert for (int i = 0; i < num_experts; i++) { // since the storage type of int4b is 1 byte but one element is 4 bits // we need to adjust the offset int64_t offset = 1ull * i * n * k * cutlass::sizeof_bits::value / 8; cutlass::reorder_tensor(b_packed_ptr + offset, layout_B, layout_B_reordered); } // save the packed layout to torch tensor so we can re-use it auto cpu_opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); torch::Tensor layout_cpu = torch::empty({num_experts, layout_width}, cpu_opts); int32_t* layout_data = layout_cpu.data_ptr(); for (int i = 0; i < num_experts; ++i) { std::memcpy(layout_data + i * layout_width, // dst (int32*) &layout_B_reordered, // src (LayoutType*) sizeof(LayoutType)); // number of bytes } torch::Tensor packed_layout = layout_cpu.to(b_tensors.device(), /*non_blocking=*/false); return {b_tensors_packed, packed_layout}; } TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("cutlass_w4a8_moe_mm", &mm); m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b); } } // namespace vllm::cutlass_w4a8_moe /////////////////////////////////////////////////////////////////////////////////////////////////