[Kernel] Add expert_map support to Cutlass FP8 MOE (#16861)
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com> Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
c9acbf1141
commit
7b8a2ab76f
@@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
|
||||
const int32_t* __restrict__ expert_offsets,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
int32_t* atomic_buffer, const int topk_length,
|
||||
const int topk) {
|
||||
int expert_id = blockIdx.x;
|
||||
int const blk_expert_id = blockIdx.x;
|
||||
int const num_experts = gridDim.x;
|
||||
int32_t const num_tokens = expert_offsets[num_experts];
|
||||
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
if (topk_ids[i] == expert_id) {
|
||||
int const expert_id = topk_ids[i];
|
||||
if (expert_id == -1 && blockIdx.x == 0) {
|
||||
// output_permutation is used to re-order the moe outputs. It is
|
||||
// used as c2 = c2[c_map], where c2 is a torch.tensor that is the
|
||||
// output of the cutlass kernels and c_map is the output_permutation.
|
||||
// c2 is initialized to zeros, therefore by setting the output_permutation
|
||||
// to num_tokens, we are guaranteed to fill the moe outputs to zero
|
||||
// for "invalid" topk_ids.
|
||||
output_permutation[i] = num_tokens;
|
||||
} else if (expert_id == blk_expert_id) {
|
||||
int start = atomicAdd(&atomic_buffer[expert_id], 1);
|
||||
input_permutation[start] = i / topk;
|
||||
output_permutation[i] = start;
|
||||
@@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
|
||||
|
||||
Reference in New Issue
Block a user