[WideEP] Remove pplx all2all backend (#33724)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
0f2f24c8b2
commit
eb19955c37
@@ -263,12 +263,10 @@ void get_cutlass_moe_mm_data_caller(
|
||||
}
|
||||
|
||||
template <bool SWAP_AB>
|
||||
__global__ void compute_pplx_data(int32_t* expert_offsets,
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
const int32_t* __restrict__ expert_num_tokens,
|
||||
const int padded_m, const int n,
|
||||
const int k) {
|
||||
__global__ void compute_batched_moe_data(
|
||||
int32_t* expert_offsets, int32_t* problem_sizes1, int32_t* problem_sizes2,
|
||||
const int32_t* __restrict__ expert_num_tokens, const int padded_m,
|
||||
const int n, const int k) {
|
||||
int expert_idx = threadIdx.x;
|
||||
expert_offsets[expert_idx] = expert_idx * padded_m;
|
||||
|
||||
@@ -289,24 +287,22 @@ __global__ void compute_pplx_data(int32_t* expert_offsets,
|
||||
}
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts,
|
||||
const int64_t padded_m,
|
||||
const int64_t n, const int64_t k) {
|
||||
void get_cutlass_batched_moe_mm_data_caller(
|
||||
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
||||
|
||||
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
|
||||
compute_pplx_data<false><<<1, num_local_experts, 0, stream>>>(
|
||||
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||
k);
|
||||
} else {
|
||||
compute_pplx_data<true><<<1, num_local_experts, 0, stream>>>(
|
||||
compute_batched_moe_data<true><<<1, num_local_experts, 0, stream>>>(
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
|
||||
Reference in New Issue
Block a user