[Perf] Optimize moe_permute for CUTLASS FP8 (#32892)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -73,25 +73,40 @@ void moe_permute(
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
|
||||
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
|
||||
|
||||
// DeepGEMM: use getMIndices kernel to compute
|
||||
// 1) align_expert_first_token_offset (aligned prefix offsets)
|
||||
// 2) m_indices (expert id for each aligned row)
|
||||
// eg. expert0: 3, expert1: 5, expert2: 2 tokens respectively
|
||||
// expert_first_token_offset = [0, 3, 8, 10], align_block_size = 4
|
||||
// expert0: 3->4, expert1: 5->8, expert2: 2->4
|
||||
// align_expert_first_token_offset = [0, 4, 12, 16]
|
||||
// so m_indices = [0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2]
|
||||
torch::Tensor align_expert_first_token_offset;
|
||||
const int64_t* aligned_expert_first_token_offset_ptr = nullptr;
|
||||
if (align_block_size.has_value()) {
|
||||
align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
|
||||
get_ptr<int64_t>(align_expert_first_token_offset),
|
||||
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
|
||||
stream);
|
||||
aligned_expert_first_token_offset_ptr =
|
||||
get_ptr<int64_t>(align_expert_first_token_offset);
|
||||
}
|
||||
|
||||
// dispatch expandInputRowsKernelLauncher
|
||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||
expandInputRowsKernelLauncher<scalar_t>(
|
||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
|
||||
n_hidden, topk, n_local_expert, align_block_size_value, stream);
|
||||
get_ptr<int64_t>(expert_first_token_offset),
|
||||
aligned_expert_first_token_offset_ptr, n_token, valid_num_ptr, n_hidden,
|
||||
topk, n_local_expert, align_block_size_value, stream);
|
||||
});
|
||||
|
||||
// get m_indices and update expert_first_token_offset with align block
|
||||
// this is only required for DeepGemm and not required for CUTLASS group gemm
|
||||
if (align_block_size.has_value()) {
|
||||
auto align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
|
||||
get_ptr<int64_t>(align_expert_first_token_offset),
|
||||
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
|
||||
stream);
|
||||
expert_first_token_offset.copy_(align_expert_first_token_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,7 +60,8 @@ void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* expert_first_token_offset,
|
||||
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream);
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ __global__ void expandInputRowsKernel(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* expert_first_token_offset,
|
||||
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
|
||||
int num_local_experts, int align_block_size) {
|
||||
// Reverse permutation map.
|
||||
@@ -18,35 +19,22 @@ __global__ void expandInputRowsKernel(
|
||||
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
||||
int expert_id = sorted_experts[expanded_dest_row];
|
||||
|
||||
extern __shared__ int64_t smem_expert_first_token_offset[];
|
||||
if constexpr (ALIGN_BLOCK_SIZE) {
|
||||
// load g2s
|
||||
for (int idx = threadIdx.x; idx < num_local_experts + 1;
|
||||
idx += blockDim.x) {
|
||||
smem_expert_first_token_offset[idx] =
|
||||
__ldg(expert_first_token_offset + idx);
|
||||
// convert (unaligned) expanded_dest_row -> aligned expanded_dest_row.
|
||||
// aligned_expert_first_token_offset[e] provides the aligned prefix start
|
||||
// for expert e. For non-local experts we map to the end (total aligned M).
|
||||
int64_t aligned_base = 0;
|
||||
int64_t token_offset_in_expert = 0;
|
||||
if (expert_id >= num_local_experts) {
|
||||
aligned_base =
|
||||
__ldg(aligned_expert_first_token_offset + num_local_experts);
|
||||
token_offset_in_expert = 0;
|
||||
} else {
|
||||
aligned_base = __ldg(aligned_expert_first_token_offset + expert_id);
|
||||
token_offset_in_expert =
|
||||
expanded_dest_row - __ldg(expert_first_token_offset + expert_id);
|
||||
}
|
||||
__syncthreads();
|
||||
int lane_idx = threadIdx.x & 31;
|
||||
|
||||
if (lane_idx == 0) {
|
||||
// set token_offset_in_expert = 0 if this expert is not local expert
|
||||
int token_offset_in_expert =
|
||||
expert_id >= num_local_experts
|
||||
? 0
|
||||
: expanded_dest_row - smem_expert_first_token_offset[expert_id];
|
||||
int64_t accumulate_align_offset = 0;
|
||||
#pragma unroll 1
|
||||
for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
|
||||
auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
|
||||
smem_expert_first_token_offset[eidx - 1];
|
||||
accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
}
|
||||
expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
|
||||
}
|
||||
// lane0 shuffle broadcast align_expanded_dest_row
|
||||
expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
|
||||
expanded_dest_row = aligned_base + token_offset_in_expert;
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -88,7 +76,8 @@ void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* expert_first_token_offset,
|
||||
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
|
||||
int64_t const blocks = num_rows * k;
|
||||
@@ -104,14 +93,12 @@ void expandInputRowsKernelLauncher(
|
||||
bool is_align_block_size = align_block_size != -1;
|
||||
auto func = func_map[is_check_skip][is_align_block_size];
|
||||
|
||||
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
|
||||
|
||||
func<<<blocks, threads, smem_size, stream>>>(
|
||||
func<<<blocks, threads, 0, stream>>>(
|
||||
unpermuted_input, permuted_output, sorted_experts,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row, permuted_idx,
|
||||
expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
|
||||
num_local_experts, align_block_size);
|
||||
expert_first_token_offset, aligned_expert_first_token_offset, num_rows,
|
||||
num_valid_tokens_ptr, cols, k, num_local_experts, align_block_size);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
|
||||
Reference in New Issue
Block a user