[Perf] Optimize cutlass moe problem size calculation, 5.3% E2E Throughput improvement, 2.2% TTFT improvement (#31830)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
@@ -114,22 +116,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
|
||||
const bool swap_ab) {
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
|
||||
int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
|
||||
int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
|
||||
int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
|
||||
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();
|
||||
|
||||
if (swap_ab) {
|
||||
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
} else {
|
||||
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
}
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@@ -153,6 +150,93 @@ void get_cutlass_moe_mm_problem_sizes_caller(
|
||||
may_swap_ab);
|
||||
}
|
||||
|
||||
template <bool SWAP_AB>
|
||||
__global__ void compute_problem_sizes_from_expert_offsets(
|
||||
const int64_t* __restrict__ expert_first_token_offset,
|
||||
int32_t* __restrict__ problem_sizes1, int32_t* __restrict__ problem_sizes2,
|
||||
const int num_experts, const int n, const int k) {
|
||||
int const expert_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (expert_id >= num_experts) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t const m64 = expert_first_token_offset[expert_id + 1] -
|
||||
expert_first_token_offset[expert_id];
|
||||
int32_t const m = static_cast<int32_t>(m64);
|
||||
|
||||
int32_t* ps1 = problem_sizes1 + expert_id * 3;
|
||||
int32_t* ps2 = problem_sizes2 + expert_id * 3;
|
||||
|
||||
if constexpr (!SWAP_AB) {
|
||||
// [M, 2*N, K]
|
||||
ps1[0] = m;
|
||||
ps1[1] = 2 * n;
|
||||
ps1[2] = k;
|
||||
// [M, K, N]
|
||||
ps2[0] = m;
|
||||
ps2[1] = k;
|
||||
ps2[2] = n;
|
||||
} else {
|
||||
// swap logical M/N in the problem shape
|
||||
// [2*N, M, K]
|
||||
ps1[0] = 2 * n;
|
||||
ps1[1] = m;
|
||||
ps1[2] = k;
|
||||
// [K, M, N]
|
||||
ps2[0] = k;
|
||||
ps2[1] = m;
|
||||
ps2[2] = n;
|
||||
}
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
||||
TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
||||
"expert_first_token_offset must be a CUDA tensor");
|
||||
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
|
||||
"expert_first_token_offset must be int64");
|
||||
|
||||
TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
||||
"problem_sizes must be CUDA tensors");
|
||||
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
|
||||
problem_sizes2.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
||||
"problem_sizes must be contiguous");
|
||||
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
||||
"problem_sizes must be 2D tensors");
|
||||
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
||||
"problem_sizes second dim must be 3");
|
||||
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
|
||||
"problem_sizes1 and problem_sizes2 must have same shape");
|
||||
|
||||
int64_t const num_experts64 = problem_sizes1.size(0);
|
||||
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
|
||||
"expert_first_token_offset must have num_experts + 1 elements");
|
||||
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
||||
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");
|
||||
|
||||
int const num_experts = static_cast<int>(num_experts64);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(
|
||||
expert_first_token_offset.device().index());
|
||||
|
||||
int const threads = (num_experts < 256) ? num_experts : 256;
|
||||
int const blocks = (num_experts + threads - 1) / threads;
|
||||
|
||||
auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes_from_expert_offsets<SwapAB>
|
||||
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
|
||||
num_experts, static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
});
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
|
||||
Reference in New Issue
Block a user