[3/n] Migrate cutlass/scaled_mm_entry.cu to torch stable ABI (#37221)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
@@ -0,0 +1,89 @@
#pragma once

#include <cuda.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "libtorch_stable/torch_utils.h"

#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"

template <typename ElementAB, typename ElementC, typename ElementAccumulator>
__global__ void get_group_gemm_starts(
    int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
    ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
    ElementAB* b_base_as_int, ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k,
    bool per_act_token, bool per_out_ch) {
  int expert_id = threadIdx.x;

  int64_t expert_offset = expert_offsets[expert_id];

  a_offsets[expert_id] = a_base_as_int + expert_offset * k;
  b_offsets[expert_id] = b_base_as_int + expert_id * k * n;
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  a_scales_offsets[expert_id] =
      a_scales_base_as_int + (per_act_token ? expert_offset : 0);
  b_scales_offsets[expert_id] =
      b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id);
}
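
// __CALL_GET_STARTS_KERNEL expands to an `else if` branch on the output
// dtype. run_get_group_gemm_starts() below opens the chain with an empty
// `if (false) {}`, so each macro invocation appends one branch that launches
// get_group_gemm_starts with the matching C_TYPE, using one thread per expert.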

#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                    \
  else if (out_tensors.scalar_type() == TENSOR_C_TYPE) {                   \
    get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float>            \
        <<<1, num_experts, 0, stream>>>(                                   \
            static_cast<int64_t*>(expert_offsets.data_ptr()),              \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),       \
            static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),       \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                    \
            static_cast<float**>(a_scales_ptrs.data_ptr()),                \
            static_cast<float**>(b_scales_ptrs.data_ptr()),                \
            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),     \
            static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),     \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                  \
            static_cast<float*>(a_scales.data_ptr()),                      \
            static_cast<float*>(b_scales.data_ptr()), out_tensors.size(1), \
            a_tensors.size(1), per_act_token, per_out_ch);                 \
  }

namespace {

void run_get_group_gemm_starts(
    torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
    torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
    torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
    torch::stable::Tensor const& a_tensors,
    torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
    torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales) {
  STD_TORCH_CHECK(a_tensors.scalar_type() ==
                  torch::headeronly::ScalarType::Float8_e4m3fn);
  STD_TORCH_CHECK(b_tensors.scalar_type() ==
                  torch::headeronly::ScalarType::Float8_e4m3fn);
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  // expect int64_t to avoid overflow during offset calculations
  STD_TORCH_CHECK(expert_offsets.scalar_type() ==
                  torch::headeronly::ScalarType::Long);

  int num_experts = static_cast<int>(expert_offsets.size(0));
  bool per_act_token = a_scales.numel() != 1;
  bool per_out_ch = b_scales.numel() != num_experts;

  auto stream = get_current_cuda_stream(a_tensors.get_device_index());

  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
                           cutlass::bfloat16_t)
  __CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
  else {
    STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

}  // namespace
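
A minimal host-side sketch of what the kernel above consumes, assuming a
hypothetical `tokens_per_expert` vector (the real offsets are produced
on-device by compute_expert_offsets in moe_data.cu further down):

    #include <cstdint>
    #include <vector>

    // Illustrative only: expert_offsets[i] is the running prefix sum of
    // per-expert token counts, kept in int64 so the expert_offset * k and
    // expert_offset * n pointer arithmetic above cannot overflow.
    std::vector<int64_t> make_expert_offsets(
        const std::vector<int64_t>& tokens_per_expert) {
      std::vector<int64_t> offsets(tokens_per_expert.size() + 1, 0);
      for (size_t i = 0; i < tokens_per_expert.size(); ++i) {
        offsets[i + 1] = offsets[i] + tokens_per_expert[i];
      }
      return offsets;  // e.g. counts {3, 1, 4} -> {0, 3, 4, 8}
    }
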
@@ -0,0 +1,190 @@
#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include <torch/csrc/stable/ops.h>
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"

using namespace cute;

namespace {

using ProblemShape =
    cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;

using ElementAccumulator = float;
using OperatorClass = cutlass::arch::OpClassTensorOp;

using LayoutA = cutlass::layout::RowMajor;
using LayoutA_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutB_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using LayoutD = cutlass::layout::RowMajor;
using LayoutD_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;

template <typename ElementAB_, typename ElementC_, typename ArchTag_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule, bool swap_ab_ = false>
struct cutlass_3x_group_gemm {
  static constexpr bool swap_ab = swap_ab_;
  using ElementAB = ElementAB_;
  using ElementC = void;
  using ElementD = ElementC_;
  using ElementAccumulator = float;
  using ArchTag = ArchTag_;

  using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;

  static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
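  // With the fp8 (e4m3) inputs used below, AlignmentAB above is 16 elements
  // (128 / 8); with bf16/fp16 outputs, AlignmentC is 8 elements (128 / 16).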

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementAccumulator, ElementC,
          conditional_t<swap_ab, LayoutC_Transpose*, LayoutC*>, AlignmentC,
          ElementD, conditional_t<swap_ab, LayoutD_Transpose*, LayoutD*>,
          AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  using CollectiveMainloop = conditional_t<
      swap_ab,
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementAB, LayoutB_Transpose*, AlignmentAB,
          ElementAB, LayoutA_Transpose*, AlignmentAB, ElementAccumulator,
          TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp,
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
          LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
          Stages, KernelSchedule>::CollectiveOp>;

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;

  struct GemmKernel : public KernelType {};
};
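
// With swap_ab set, the mainloop above is built with the B and A operands
// exchanged and with the transposed layouts, so each group's logical M and N
// trade places in the problem shape (see the SWAP_AB branches in moe_data.cu).
// The SM90/SM100 configs below enable it for small M to reduce padding.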

template <typename Gemm>
void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
                               torch::stable::Tensor const& a_tensors,
                               torch::stable::Tensor const& b_tensors,
                               torch::stable::Tensor const& a_scales,
                               torch::stable::Tensor const& b_scales,
                               torch::stable::Tensor const& expert_offsets,
                               torch::stable::Tensor const& problem_sizes,
                               torch::stable::Tensor const& a_strides,
                               torch::stable::Tensor const& b_strides,
                               torch::stable::Tensor const& c_strides,
                               bool per_act_token, bool per_out_ch) {
  static constexpr bool swap_ab = Gemm::swap_ab;

  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int num_experts = static_cast<int>(expert_offsets.size(0));

  auto stream = get_current_cuda_stream(a_tensors.get_device_index());

  auto device = a_tensors.device();

  torch::stable::Tensor a_ptrs = torch::stable::empty(
      {num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
  torch::stable::Tensor b_ptrs = torch::stable::empty(
      {num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
  torch::stable::Tensor out_ptrs = torch::stable::empty(
      {num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
  torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
      {num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
  torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
      {num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);

  run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
                            a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
                            out_tensors, a_scales, b_scales);

  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = Stride<int64_t, Int<1>, Int<0>>;
  using StrideB = Stride<int64_t, Int<1>, Int<0>>;
  using StrideC = typename GemmKernel::InternalStrideC;

  ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
      static_cast<ProblemShape::UnderlyingProblemShape*>(
          problem_sizes.data_ptr());
  ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};

  typename GemmKernel::MainloopArguments mainloop_args;
  if constexpr (swap_ab) {
    mainloop_args = typename GemmKernel::MainloopArguments{
        static_cast<const ElementAB**>(b_ptrs.data_ptr()),
        static_cast<StrideB*>(b_strides.data_ptr()),
        static_cast<const ElementAB**>(a_ptrs.data_ptr()),
        static_cast<StrideA*>(a_strides.data_ptr())};
  } else {
    mainloop_args = typename GemmKernel::MainloopArguments{
        static_cast<const ElementAB**>(a_ptrs.data_ptr()),
        static_cast<StrideA*>(a_strides.data_ptr()),
        static_cast<const ElementAB**>(b_ptrs.data_ptr()),
        static_cast<StrideB*>(b_strides.data_ptr())};
  }

  // Currently, we are only able to do broadcast on either all or none a_scales
  // and on either all or none b_scales
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          swap_ab ? static_cast<const ElementAccumulator**>(
                        b_scales_ptrs.data_ptr())
                  : static_cast<const ElementAccumulator**>(
                        a_scales_ptrs.data_ptr()),
          swap_ab ? static_cast<const ElementAccumulator**>(
                        a_scales_ptrs.data_ptr())
                  : static_cast<const ElementAccumulator**>(
                        b_scales_ptrs.data_ptr()),
          swap_ab ? per_out_ch : per_act_token,
          swap_ab ? per_act_token : per_out_ch),
      nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
      static_cast<ElementD**>(out_ptrs.data_ptr()),
      static_cast<StrideC*>(c_strides.data_ptr())};

  int device_id = a_tensors.get_device_index();
  static const cutlass::KernelHardwareInfo hw_info{
      device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
                     device_id)};

  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
      epilogue_args, hw_info};

  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto workspace =
      torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
                           std::nullopt, device);

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

}  // namespace
@@ -0,0 +1,155 @@
#include <cudaTypedefs.h>

#include "libtorch_stable/torch_utils.h"
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>

#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

using namespace cute;

namespace {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
  using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
  // M in [1,64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule,
                            true>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_N8192 {
  // N in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
  using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
                              torch::stable::Tensor const& a_tensors,
                              torch::stable::Tensor const& b_tensors,
                              torch::stable::Tensor const& a_scales,
                              torch::stable::Tensor const& b_scales,
                              torch::stable::Tensor const& expert_offsets,
                              torch::stable::Tensor const& problem_sizes,
                              torch::stable::Tensor const& a_strides,
                              torch::stable::Tensor const& b_strides,
                              torch::stable::Tensor const& c_strides,
                              bool per_act_token, bool per_out_ch) {
  STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
  STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
  STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

  STD_TORCH_CHECK(
      a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
      "A tensors must be of type float8_e4m3fn.");
  STD_TORCH_CHECK(
      b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
      "B tensors must be of type float8_e4m3fn.");

  using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM64 = typename sm100_fp8_config_M64<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
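
  // Shape-based dispatch: the swap_ab M64 config for small batches (m <= 64),
  // the 2-SM N8192 config for wide outputs (n >= 8192), and the default
  // config otherwise.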

  uint32_t const m = a_tensors.size(0);
  uint32_t const n = out_tensors.size(1);

  if (m <= 64) {
    cutlass_group_gemm_caller<Cutlass3xGemmM64>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else if (n >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}
}  // namespace

void dispatch_moe_mm_sm100(torch::stable::Tensor& out_tensors,
                           torch::stable::Tensor const& a_tensors,
                           torch::stable::Tensor const& b_tensors,
                           torch::stable::Tensor const& a_scales,
                           torch::stable::Tensor const& b_scales,
                           torch::stable::Tensor const& expert_offsets,
                           torch::stable::Tensor const& problem_sizes,
                           torch::stable::Tensor const& a_strides,
                           torch::stable::Tensor const& b_strides,
                           torch::stable::Tensor const& c_strides,
                           bool per_act_token, bool per_out_ch) {
  if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
    run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::half_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}

void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
                          torch::stable::Tensor const& a_tensors,
                          torch::stable::Tensor const& b_tensors,
                          torch::stable::Tensor const& a_scales,
                          torch::stable::Tensor const& b_scales,
                          torch::stable::Tensor const& expert_offsets,
                          torch::stable::Tensor const& problem_sizes,
                          torch::stable::Tensor const& a_strides,
                          torch::stable::Tensor const& b_strides,
                          torch::stable::Tensor const& c_strides,
                          bool per_act_token, bool per_out_ch) {
  dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
                        c_strides, per_act_token, per_out_ch);
}
@@ -0,0 +1,213 @@
#include <cudaTypedefs.h>

#include "libtorch_stable/torch_utils.h"
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>

#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

using namespace cute;

namespace {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (16, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
  using ArchTag = cutlass::arch::Sm90;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M4 {
  // M in [1, 4]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm90;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule,
                            true>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in (4, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm90;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule,
                            true>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_K8192 {
  // K in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
  using ArchTag = cutlass::arch::Sm90;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_N8192 {
  // N in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
  using ArchTag = cutlass::arch::Sm90;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
                             torch::stable::Tensor const& a_tensors,
                             torch::stable::Tensor const& b_tensors,
                             torch::stable::Tensor const& a_scales,
                             torch::stable::Tensor const& b_scales,
                             torch::stable::Tensor const& expert_offsets,
                             torch::stable::Tensor const& problem_sizes,
                             torch::stable::Tensor const& a_strides,
                             torch::stable::Tensor const& b_strides,
                             torch::stable::Tensor const& c_strides,
                             bool per_act_token, bool per_out_ch) {
  STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
  STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
  STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

  STD_TORCH_CHECK(
      a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
      "A tensors must be of type float8_e4m3fn.");
  STD_TORCH_CHECK(
      b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
      "B tensors must be of type float8_e4m3fn.");

  using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM4 = typename sm90_fp8_config_M4<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM64 = typename sm90_fp8_config_M64<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
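
  // Shape-based dispatch, checked in order: the swap_ab M4/M64 configs for
  // small batches, then the large-N and large-K configs, otherwise the
  // default pingpong config.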

  uint32_t const m = a_tensors.size(0);
  uint32_t const n = out_tensors.size(1);
  uint32_t const k = a_tensors.size(1);

  // Use swap_ab for M <= 64 by default to reduce padding
  if (m <= 4) {
    cutlass_group_gemm_caller<Cutlass3xGemmM4>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else if (m <= 64) {
    cutlass_group_gemm_caller<Cutlass3xGemmM64>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else if (n >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else if (k >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}

void dispatch_moe_mm_sm90(torch::stable::Tensor& out_tensors,
                          torch::stable::Tensor const& a_tensors,
                          torch::stable::Tensor const& b_tensors,
                          torch::stable::Tensor const& a_scales,
                          torch::stable::Tensor const& b_scales,
                          torch::stable::Tensor const& expert_offsets,
                          torch::stable::Tensor const& problem_sizes,
                          torch::stable::Tensor const& a_strides,
                          torch::stable::Tensor const& b_strides,
                          torch::stable::Tensor const& c_strides,
                          bool per_act_token, bool per_out_ch) {
  if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
    run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}

}  // namespace

void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
                         torch::stable::Tensor const& a_tensors,
                         torch::stable::Tensor const& b_tensors,
                         torch::stable::Tensor const& a_scales,
                         torch::stable::Tensor const& b_scales,
                         torch::stable::Tensor const& expert_offsets,
                         torch::stable::Tensor const& problem_sizes,
                         torch::stable::Tensor const& a_strides,
                         torch::stable::Tensor const& b_strides,
                         torch::stable::Tensor const& c_strides,
                         bool per_act_token, bool per_out_ch) {
  dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                       expert_offsets, problem_sizes, a_strides, b_strides,
                       c_strides, per_act_token, per_out_ch);
}
csrc/libtorch_stable/quantization/w8a8/cutlass/moe/moe_data.cu (new file, 331 lines)
@@ -0,0 +1,331 @@
#include <cudaTypedefs.h>

#include "libtorch_stable/torch_utils.h"
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>

#include "libtorch_stable/dispatch_utils.h"

#include <iostream>

constexpr uint64_t THREADS_PER_EXPERT = 512;
// threshold must match the dispatch logic in run_cutlass_moe_mm_sm90()
constexpr int SWAP_AB_THRESHOLD = 64;
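
// Launched with one block per expert: each thread counts how many entries of
// topk_ids map to this expert (strided over topk_length), the partial counts
// are merged with atomicAdd, and thread 0 then writes the [M, N, K] triples
// for both grouped GEMMs of the expert.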

template <bool SWAP_AB>
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
                                      int32_t* problem_sizes1,
                                      int32_t* problem_sizes2,
                                      int32_t* atomic_buffer,
                                      const int topk_length, const int n,
                                      const int k, const bool is_gated) {
  int expert_id = blockIdx.x;
  // For gated activations (gate + up), first GEMM output is 2*n.
  // For non-gated activations (up only), first GEMM output is n.
  int const n1 = is_gated ? 2 * n : n;

  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    if constexpr (!SWAP_AB) {
      problem_sizes1[expert_id * 3] = final_occurrences;
      problem_sizes1[expert_id * 3 + 1] = n1;
      problem_sizes1[expert_id * 3 + 2] = k;
      problem_sizes2[expert_id * 3] = final_occurrences;
      problem_sizes2[expert_id * 3 + 1] = k;
      problem_sizes2[expert_id * 3 + 2] = n;
    } else {
      problem_sizes1[expert_id * 3] = n1;
      problem_sizes1[expert_id * 3 + 1] = final_occurrences;
      problem_sizes1[expert_id * 3 + 2] = k;
      problem_sizes2[expert_id * 3] = k;
      problem_sizes2[expert_id * 3 + 1] = final_occurrences;
      problem_sizes2[expert_id * 3 + 2] = n;
    }
  }
}

__global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
    int32_t* atomic_buffer, const int num_experts, const bool swap_ab) {
  int32_t tot_offset = 0;
  expert_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    tot_offset += swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
    expert_offsets[i + 1] = tot_offset;
  }
}
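
// Same prefix sum as above, but additionally emits blockscale_offsets with
// each expert's token count rounded up to a multiple of 128 rows, as used by
// the FP4 block-scale path below.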

__global__ void compute_expert_blockscale_offsets(
    const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
    int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
    const bool swap_ab) {
  int32_t tot_offset = 0;
  int32_t tot_offset_round = 0;
  expert_offsets[0] = 0;
  blockscale_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    int32_t cur_offset =
        swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3];
    atomic_buffer[i] = tot_offset;
    tot_offset += cur_offset;
    expert_offsets[i + 1] = tot_offset;
    tot_offset_round += (cur_offset + (128 - 1)) / 128 * 128;
    blockscale_offsets[i + 1] = tot_offset_round;
  }
}

__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
                                  const int32_t* __restrict__ expert_offsets,
                                  int32_t* input_permutation,
                                  int32_t* output_permutation,
                                  int32_t* atomic_buffer, const int topk_length,
                                  const int topk) {
  int const blk_expert_id = blockIdx.x;
  int const num_experts = gridDim.x;
  int32_t const num_tokens = expert_offsets[num_experts];

  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    int const expert_id = topk_ids[i];
    if (expert_id == -1 && blockIdx.x == 0) {
      // output_permutation is used to re-order the moe outputs. It is
      // used as c2 = c2[c_map], where c2 is a torch.tensor that is the
      // output of the cutlass kernels and c_map is the output_permutation.
      // c2 is initialized to zeros, therefore by setting the output_permutation
      // to num_tokens, we are guaranteed to fill the moe outputs to zero
      // for "invalid" topk_ids.
      output_permutation[i] = num_tokens;
    } else if (expert_id == blk_expert_id) {
      int start = atomicAdd(&atomic_buffer[expert_id], 1);
      input_permutation[start] = i / topk;
      output_permutation[i] = start;
    }
  }
}

namespace {
inline void launch_compute_problem_sizes(const torch::stable::Tensor& topk_ids,
                                         torch::stable::Tensor& problem_sizes1,
                                         torch::stable::Tensor& problem_sizes2,
                                         torch::stable::Tensor& atomic_buffer,
                                         int64_t num_experts, int64_t n,
                                         int64_t k, cudaStream_t stream,
                                         const bool swap_ab,
                                         const bool is_gated) {
  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());

  auto const* topk_ptr = topk_ids.const_data_ptr<int32_t>();
  auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
  auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
  auto* atomic_ptr = atomic_buffer.mutable_data_ptr<int32_t>();

  VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
    compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
        static_cast<int>(k), is_gated);
  });
}
}  // namespace

template <bool SWAP_AB>
__global__ void compute_problem_sizes_from_expert_offsets(
    const int64_t* __restrict__ expert_first_token_offset,
    int32_t* __restrict__ problem_sizes1, int32_t* __restrict__ problem_sizes2,
    const int num_experts, const int n, const int k) {
  int const expert_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (expert_id >= num_experts) {
    return;
  }

  int64_t const m64 = expert_first_token_offset[expert_id + 1] -
                      expert_first_token_offset[expert_id];
  int32_t const m = static_cast<int32_t>(m64);

  int32_t* ps1 = problem_sizes1 + expert_id * 3;
  int32_t* ps2 = problem_sizes2 + expert_id * 3;

  if constexpr (!SWAP_AB) {
    // [M, 2*N, K]
    ps1[0] = m;
    ps1[1] = 2 * n;
    ps1[2] = k;
    // [M, K, N]
    ps2[0] = m;
    ps2[1] = k;
    ps2[2] = n;
  } else {
    // swap logical M/N in the problem shape
    // [2*N, M, K]
    ps1[0] = 2 * n;
    ps1[1] = m;
    ps1[2] = k;
    // [K, M, N]
    ps2[0] = k;
    ps2[1] = m;
    ps2[2] = n;
  }
}

void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
    const torch::stable::Tensor& expert_first_token_offset,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
    const bool swap_ab) {
  STD_TORCH_CHECK(expert_first_token_offset.is_cuda(),
                  "expert_first_token_offset must be a CUDA tensor");
  STD_TORCH_CHECK(expert_first_token_offset.scalar_type() ==
                      torch::headeronly::ScalarType::Long,
                  "expert_first_token_offset must be int64");

  STD_TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
                  "problem_sizes must be CUDA tensors");
  STD_TORCH_CHECK(
      problem_sizes1.scalar_type() == torch::headeronly::ScalarType::Int &&
          problem_sizes2.scalar_type() == torch::headeronly::ScalarType::Int,
      "problem_sizes must be int32");
  STD_TORCH_CHECK(
      problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
      "problem_sizes must be contiguous");
  STD_TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
                  "problem_sizes must be 2D tensors");
  STD_TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
                  "problem_sizes second dim must be 3");
  STD_TORCH_CHECK(problem_sizes1.size(0) == problem_sizes2.size(0) &&
                      problem_sizes1.size(1) == problem_sizes2.size(1),
                  "problem_sizes1 and problem_sizes2 must have same shape");

  int64_t const num_experts64 = problem_sizes1.size(0);
  STD_TORCH_CHECK(
      expert_first_token_offset.numel() == num_experts64 + 1,
      "expert_first_token_offset must have num_experts + 1 elements");
  STD_TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
  STD_TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX,
                  "n and k must fit in int32");

  int const num_experts = static_cast<int>(num_experts64);
  auto stream =
      get_current_cuda_stream(expert_first_token_offset.get_device_index());

  int const threads = (num_experts < 256) ? num_experts : 256;
  int const blocks = (num_experts + threads - 1) / threads;

  auto const* offsets_ptr = expert_first_token_offset.const_data_ptr<int64_t>();
  auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
  auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();

  VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
    compute_problem_sizes_from_expert_offsets<SwapAB>
        <<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
                                         num_experts, static_cast<int>(n),
                                         static_cast<int>(k));
  });
}

void get_cutlass_moe_mm_data_caller(
    const torch::stable::Tensor& topk_ids,
    torch::stable::Tensor& expert_offsets,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2,
    torch::stable::Tensor& input_permutation,
    torch::stable::Tensor& output_permutation, const int64_t num_experts,
    const int64_t n, const int64_t k,
    const std::optional<torch::stable::Tensor>& blockscale_offsets,
    const bool is_gated) {
  auto device = topk_ids.device();
  auto stream = get_current_cuda_stream(device.index());
  torch::stable::Tensor atomic_buffer = torch::stable::new_zeros(
      topk_ids, {num_experts}, torch::headeronly::ScalarType::Int);

  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());

  // Swap-AB should be disabled for FP4 path
  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);

  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
                               atomic_buffer, num_experts, n, k, stream,
                               may_swap_ab, is_gated);

  if (blockscale_offsets.has_value()) {
    // fp4 path
    compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
        may_swap_ab);
  } else {
    compute_expert_offsets<<<1, 1, 0, stream>>>(
        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
        may_swap_ab);
  }
  compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<const int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
      static_cast<int32_t*>(output_permutation.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
      topk_ids.size(1));
}
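
// Expected shapes for get_cutlass_moe_mm_data_caller above: expert_offsets
// holds num_experts + 1 int32 entries, problem_sizes1/2 are [num_experts, 3]
// int32, and input_permutation/output_permutation each hold one int32 entry
// per topk slot (topk_ids.numel()).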

template <bool SWAP_AB>
__global__ void compute_batched_moe_data(
    int32_t* expert_offsets, int32_t* problem_sizes1, int32_t* problem_sizes2,
    const int32_t* __restrict__ expert_num_tokens, const int padded_m,
    const int n, const int k) {
  int expert_idx = threadIdx.x;
  expert_offsets[expert_idx] = expert_idx * padded_m;

  if constexpr (!SWAP_AB) {
    problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
    problem_sizes1[expert_idx * 3 + 1] = 2 * n;
    problem_sizes1[expert_idx * 3 + 2] = k;
    problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
    problem_sizes2[expert_idx * 3 + 1] = k;
    problem_sizes2[expert_idx * 3 + 2] = n;
  } else {
    problem_sizes1[expert_idx * 3] = 2 * n;
    problem_sizes1[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
    problem_sizes1[expert_idx * 3 + 2] = k;
    problem_sizes2[expert_idx * 3] = k;
    problem_sizes2[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
    problem_sizes2[expert_idx * 3 + 2] = n;
  }
}

void get_cutlass_batched_moe_mm_data_caller(
    torch::stable::Tensor& expert_offsets,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2,
    const torch::stable::Tensor& expert_num_tokens,
    const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
    const int64_t k) {
  auto stream = get_current_cuda_stream(expert_offsets.get_device_index());

  if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
    compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(problem_sizes2.data_ptr()),
        static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
        k);
  } else {
    compute_batched_moe_data<true><<<1, num_local_experts, 0, stream>>>(
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(problem_sizes2.data_ptr()),
        static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
        k);
  }
}