[Perf] Optimize cutlass moe problem size calculation, 5.3% E2E Throughput improvement, 2.2% TTFT improvement (#31830)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -265,6 +265,11 @@ void get_cutlass_moe_mm_problem_sizes(
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
std::optional<bool> force_swap_ab = std::nullopt);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab);
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
@@ -114,22 +116,17 @@ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
|
||||
const bool swap_ab) {
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
|
||||
int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
|
||||
int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
|
||||
int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
|
||||
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();
|
||||
|
||||
if (swap_ab) {
|
||||
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
} else {
|
||||
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
}
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@@ -153,6 +150,93 @@ void get_cutlass_moe_mm_problem_sizes_caller(
|
||||
may_swap_ab);
|
||||
}
|
||||
|
||||
template <bool SWAP_AB>
|
||||
__global__ void compute_problem_sizes_from_expert_offsets(
|
||||
const int64_t* __restrict__ expert_first_token_offset,
|
||||
int32_t* __restrict__ problem_sizes1, int32_t* __restrict__ problem_sizes2,
|
||||
const int num_experts, const int n, const int k) {
|
||||
int const expert_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (expert_id >= num_experts) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t const m64 = expert_first_token_offset[expert_id + 1] -
|
||||
expert_first_token_offset[expert_id];
|
||||
int32_t const m = static_cast<int32_t>(m64);
|
||||
|
||||
int32_t* ps1 = problem_sizes1 + expert_id * 3;
|
||||
int32_t* ps2 = problem_sizes2 + expert_id * 3;
|
||||
|
||||
if constexpr (!SWAP_AB) {
|
||||
// [M, 2*N, K]
|
||||
ps1[0] = m;
|
||||
ps1[1] = 2 * n;
|
||||
ps1[2] = k;
|
||||
// [M, K, N]
|
||||
ps2[0] = m;
|
||||
ps2[1] = k;
|
||||
ps2[2] = n;
|
||||
} else {
|
||||
// swap logical M/N in the problem shape
|
||||
// [2*N, M, K]
|
||||
ps1[0] = 2 * n;
|
||||
ps1[1] = m;
|
||||
ps1[2] = k;
|
||||
// [K, M, N]
|
||||
ps2[0] = k;
|
||||
ps2[1] = m;
|
||||
ps2[2] = n;
|
||||
}
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
||||
TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
||||
"expert_first_token_offset must be a CUDA tensor");
|
||||
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
|
||||
"expert_first_token_offset must be int64");
|
||||
|
||||
TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
||||
"problem_sizes must be CUDA tensors");
|
||||
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
|
||||
problem_sizes2.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
||||
"problem_sizes must be contiguous");
|
||||
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
||||
"problem_sizes must be 2D tensors");
|
||||
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
||||
"problem_sizes second dim must be 3");
|
||||
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
|
||||
"problem_sizes1 and problem_sizes2 must have same shape");
|
||||
|
||||
int64_t const num_experts64 = problem_sizes1.size(0);
|
||||
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
|
||||
"expert_first_token_offset must have num_experts + 1 elements");
|
||||
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
||||
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");
|
||||
|
||||
int const num_experts = static_cast<int>(num_experts64);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(
|
||||
expert_first_token_offset.device().index());
|
||||
|
||||
int const threads = (num_experts < 256) ? num_experts : 256;
|
||||
int const blocks = (num_experts + threads - 1) / threads;
|
||||
|
||||
auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes_from_expert_offsets<SwapAB>
|
||||
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
|
||||
num_experts, static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
});
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
|
||||
@@ -83,6 +83,11 @@ void get_cutlass_moe_mm_problem_sizes_caller(
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
std::optional<bool> force_swap_ab = std::nullopt);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab);
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
@@ -322,6 +327,25 @@ void get_cutlass_moe_mm_problem_sizes(
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
|
||||
"no cutlass_scaled_mm kernel for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
|
||||
@@ -487,6 +487,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
|
||||
&get_cutlass_moe_mm_problem_sizes);
|
||||
|
||||
// compute per-expert problem sizes from expert_first_token_offset
|
||||
// produced by vLLM's moe_permute kernel
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
|
||||
" Tensor expert_first_token_offset, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" int n, int k, bool swap_ab) -> ()");
|
||||
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets", torch::kCUDA,
|
||||
&get_cutlass_moe_mm_problem_sizes_from_expert_offsets);
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
|
||||
// as an input, and computes expert_offsets (token start indices of each
|
||||
|
||||
@@ -1075,6 +1075,25 @@ def get_cutlass_moe_mm_problem_sizes(
|
||||
)
|
||||
|
||||
|
||||
def get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
expert_first_token_offset: torch.Tensor,
|
||||
problem_sizes1: torch.Tensor,
|
||||
problem_sizes2: torch.Tensor,
|
||||
n: int,
|
||||
k: int,
|
||||
swap_ab: bool,
|
||||
):
|
||||
"""Compute per-expert (M, N, K) problem sizes from expert_first_token_offset"""
|
||||
return torch.ops._C.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
expert_first_token_offset,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
n,
|
||||
k,
|
||||
swap_ab,
|
||||
)
|
||||
|
||||
|
||||
def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
|
||||
"""
|
||||
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
|
||||
|
||||
@@ -108,15 +108,7 @@ def run_cutlass_moe_fp8(
|
||||
assert global_num_experts != -1
|
||||
assert a1q_scale is not None
|
||||
|
||||
if expert_map is not None:
|
||||
"Translate info from expert_map to topk_ids"
|
||||
local_topk_ids = torch.where(
|
||||
expert_map[topk_ids] != -1, expert_map[topk_ids], -1
|
||||
)
|
||||
else:
|
||||
local_topk_ids = topk_ids
|
||||
|
||||
topk = local_topk_ids.size(1)
|
||||
topk = topk_ids.size(1)
|
||||
local_E = w1.size(0)
|
||||
|
||||
if use_batched_format:
|
||||
@@ -164,12 +156,8 @@ def run_cutlass_moe_fp8(
|
||||
# during offset calculations
|
||||
expert_offsets = expert_offsets.to(torch.int64)
|
||||
else:
|
||||
problem_sizes1 = torch.empty(
|
||||
(global_num_experts, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
problem_sizes2 = torch.empty(
|
||||
(global_num_experts, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
|
||||
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
|
||||
# permuted a1q reuses workspace2
|
||||
@@ -182,11 +170,12 @@ def run_cutlass_moe_fp8(
|
||||
expert_map,
|
||||
permuted_hidden_states=a1q_perm,
|
||||
)
|
||||
expert_offsets = expert_first_token_offset[:-1]
|
||||
|
||||
ops.get_cutlass_moe_mm_problem_sizes(
|
||||
local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K
|
||||
# swap_ab is a CUTLASS grouped-GEMM optimization (M <= 64 reduces padding).
|
||||
swap_ab = a1q.size(0) <= 64
|
||||
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, swap_ab
|
||||
)
|
||||
expert_offsets = expert_first_token_offset[:-1]
|
||||
|
||||
if not per_act_token and (expert_map is not None or use_batched_format):
|
||||
# this is necessary to avoid imprecise scale calculation caused by
|
||||
@@ -240,9 +229,7 @@ def run_cutlass_moe_fp8(
|
||||
permuted_hidden_states=mm2_out,
|
||||
topk_weights=topk_weights,
|
||||
inv_permuted_idx=inv_perm,
|
||||
expert_first_token_offset=(
|
||||
expert_first_token_offset if expert_map is not None else None
|
||||
),
|
||||
expert_first_token_offset=expert_first_token_offset,
|
||||
)
|
||||
|
||||
|
||||
@@ -772,15 +759,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}"
|
||||
)
|
||||
|
||||
# Translate info from expert_map to topk_ids
|
||||
if expert_map is not None:
|
||||
local_topk_ids = torch.where(
|
||||
expert_map[topk_ids] != -1, expert_map[topk_ids], -1
|
||||
)
|
||||
else:
|
||||
local_topk_ids = topk_ids
|
||||
|
||||
topk = local_topk_ids.size(1)
|
||||
topk = topk_ids.size(1)
|
||||
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K))
|
||||
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
|
||||
act_out = _resize_cache(workspace2, (M * topk, N))
|
||||
@@ -790,12 +769,8 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
)
|
||||
mm2_out = _resize_cache(workspace2, (M * topk, K))
|
||||
|
||||
problem_sizes1 = torch.empty(
|
||||
(global_num_experts, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
problem_sizes2 = torch.empty(
|
||||
(global_num_experts, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
|
||||
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
|
||||
# permuted a1q reuses workspace2
|
||||
@@ -808,18 +783,11 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
expert_map,
|
||||
permuted_hidden_states=a1q_perm,
|
||||
)
|
||||
expert_offsets = expert_first_token_offset[:-1]
|
||||
|
||||
# For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
|
||||
ops.get_cutlass_moe_mm_problem_sizes(
|
||||
local_topk_ids,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
global_num_experts,
|
||||
N,
|
||||
K,
|
||||
force_swap_ab=True,
|
||||
# for RS gemm SwapAB is always enabled (swap logical M, N in the problem shape).
|
||||
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, True
|
||||
)
|
||||
expert_offsets = expert_first_token_offset[:-1]
|
||||
|
||||
ops.cutlass_w4a8_moe_mm(
|
||||
mm1_out,
|
||||
@@ -866,9 +834,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
permuted_hidden_states=mm2_out,
|
||||
topk_weights=topk_weights,
|
||||
inv_permuted_idx=inv_perm,
|
||||
expert_first_token_offset=(
|
||||
expert_first_token_offset if expert_map is not None else None
|
||||
),
|
||||
expert_first_token_offset=expert_first_token_offset,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user