[Perf] Optimize cutlass moe problem size calculation, 5.3% E2E Throughput improvement, 2.2% TTFT improvement (#31830)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -265,6 +265,11 @@ void get_cutlass_moe_mm_problem_sizes(
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
std::optional<bool> force_swap_ab = std::nullopt);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab);
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
@@ -114,22 +116,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
|
||||
const bool swap_ab) {
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
|
||||
int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
|
||||
int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
|
||||
int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
|
||||
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();
|
||||
|
||||
if (swap_ab) {
|
||||
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
} else {
|
||||
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
}
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@@ -153,6 +150,93 @@ void get_cutlass_moe_mm_problem_sizes_caller(
|
||||
may_swap_ab);
|
||||
}
|
||||
|
||||
template <bool SWAP_AB>
|
||||
__global__ void compute_problem_sizes_from_expert_offsets(
|
||||
const int64_t* __restrict__ expert_first_token_offset,
|
||||
int32_t* __restrict__ problem_sizes1, int32_t* __restrict__ problem_sizes2,
|
||||
const int num_experts, const int n, const int k) {
|
||||
int const expert_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (expert_id >= num_experts) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t const m64 = expert_first_token_offset[expert_id + 1] -
|
||||
expert_first_token_offset[expert_id];
|
||||
int32_t const m = static_cast<int32_t>(m64);
|
||||
|
||||
int32_t* ps1 = problem_sizes1 + expert_id * 3;
|
||||
int32_t* ps2 = problem_sizes2 + expert_id * 3;
|
||||
|
||||
if constexpr (!SWAP_AB) {
|
||||
// [M, 2*N, K]
|
||||
ps1[0] = m;
|
||||
ps1[1] = 2 * n;
|
||||
ps1[2] = k;
|
||||
// [M, K, N]
|
||||
ps2[0] = m;
|
||||
ps2[1] = k;
|
||||
ps2[2] = n;
|
||||
} else {
|
||||
// swap logical M/N in the problem shape
|
||||
// [2*N, M, K]
|
||||
ps1[0] = 2 * n;
|
||||
ps1[1] = m;
|
||||
ps1[2] = k;
|
||||
// [K, M, N]
|
||||
ps2[0] = k;
|
||||
ps2[1] = m;
|
||||
ps2[2] = n;
|
||||
}
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
||||
TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
||||
"expert_first_token_offset must be a CUDA tensor");
|
||||
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
|
||||
"expert_first_token_offset must be int64");
|
||||
|
||||
TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
||||
"problem_sizes must be CUDA tensors");
|
||||
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
|
||||
problem_sizes2.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
||||
"problem_sizes must be contiguous");
|
||||
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
||||
"problem_sizes must be 2D tensors");
|
||||
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
||||
"problem_sizes second dim must be 3");
|
||||
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
|
||||
"problem_sizes1 and problem_sizes2 must have same shape");
|
||||
|
||||
int64_t const num_experts64 = problem_sizes1.size(0);
|
||||
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
|
||||
"expert_first_token_offset must have num_experts + 1 elements");
|
||||
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
||||
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");
|
||||
|
||||
int const num_experts = static_cast<int>(num_experts64);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(
|
||||
expert_first_token_offset.device().index());
|
||||
|
||||
int const threads = (num_experts < 256) ? num_experts : 256;
|
||||
int const blocks = (num_experts + threads - 1) / threads;
|
||||
|
||||
auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes_from_expert_offsets<SwapAB>
|
||||
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
|
||||
num_experts, static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
});
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
|
||||
@@ -83,6 +83,11 @@ void get_cutlass_moe_mm_problem_sizes_caller(
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
std::optional<bool> force_swap_ab = std::nullopt);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab);
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
@@ -322,6 +327,25 @@ void get_cutlass_moe_mm_problem_sizes(
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
|
||||
"no cutlass_scaled_mm kernel for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
|
||||
@@ -487,6 +487,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
|
||||
&get_cutlass_moe_mm_problem_sizes);
|
||||
|
||||
// compute per-expert problem sizes from expert_first_token_offset
|
||||
// produced by vLLM's moe_permute kernel
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
|
||||
" Tensor expert_first_token_offset, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" int n, int k, bool swap_ab) -> ()");
|
||||
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets", torch::kCUDA,
|
||||
&get_cutlass_moe_mm_problem_sizes_from_expert_offsets);
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
|
||||
// as an input, and computes expert_offsets (token start indices of each
|
||||
|
||||
Reference in New Issue
Block a user