[WideEP] Remove pplx all2all backend (#33724)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Author: Tyler Michael Smith
Date: 2026-02-26 17:30:10 -05:00, committed by GitHub
Parent: 0f2f24c8b2
Commit: eb19955c37
39 changed files with 107 additions and 2069 deletions


@@ -263,12 +263,10 @@ void get_cutlass_moe_mm_data_caller(
 }
 
 template <bool SWAP_AB>
-__global__ void compute_pplx_data(int32_t* expert_offsets,
-                                  int32_t* problem_sizes1,
-                                  int32_t* problem_sizes2,
-                                  const int32_t* __restrict__ expert_num_tokens,
-                                  const int padded_m, const int n,
-                                  const int k) {
+__global__ void compute_batched_moe_data(
+    int32_t* expert_offsets, int32_t* problem_sizes1, int32_t* problem_sizes2,
+    const int32_t* __restrict__ expert_num_tokens, const int padded_m,
+    const int n, const int k) {
   int expert_idx = threadIdx.x;
 
   expert_offsets[expert_idx] = expert_idx * padded_m;
@@ -289,24 +287,22 @@ __global__ void compute_pplx_data(int32_t* expert_offsets,
   }
 }
 
-void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
-                                         torch::Tensor& problem_sizes1,
-                                         torch::Tensor& problem_sizes2,
-                                         const torch::Tensor& expert_num_tokens,
-                                         const int64_t num_local_experts,
-                                         const int64_t padded_m,
-                                         const int64_t n, const int64_t k) {
+void get_cutlass_batched_moe_mm_data_caller(
+    torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
+    const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
+    const int64_t k) {
   auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
 
   if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
-    compute_pplx_data<false><<<1, num_local_experts, 0, stream>>>(
+    compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
         static_cast<int32_t*>(expert_offsets.data_ptr()),
         static_cast<int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(problem_sizes2.data_ptr()),
         static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
         k);
   } else {
-    compute_pplx_data<true><<<1, num_local_experts, 0, stream>>>(
+    compute_batched_moe_data<true><<<1, num_local_experts, 0, stream>>>(
         static_cast<int32_t*>(expert_offsets.data_ptr()),
         static_cast<int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(problem_sizes2.data_ptr()),
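
For reference, a minimal host-side sketch of the layout the renamed kernel produces. Only the expert_offsets formula (expert_idx * padded_m) and the SWAP_AB dispatch are visible in the hunks above; the per-expert problem-size triples below are an assumption for illustration, not taken from this diff.

// CPU mirror of the batched-MoE data layout (shapes assumed, not from this diff).
#include <cstdint>
#include <vector>

struct BatchedMoeData {
  std::vector<int32_t> expert_offsets;  // [num_local_experts]
  std::vector<int32_t> problem_sizes1;  // [num_local_experts, 3] flattened
  std::vector<int32_t> problem_sizes2;  // [num_local_experts, 3] flattened
};

template <bool SWAP_AB>
BatchedMoeData compute_batched_moe_data_ref(
    const std::vector<int32_t>& expert_num_tokens, int32_t padded_m, int32_t n,
    int32_t k) {
  const size_t num_experts = expert_num_tokens.size();
  BatchedMoeData out{std::vector<int32_t>(num_experts),
                     std::vector<int32_t>(num_experts * 3),
                     std::vector<int32_t>(num_experts * 3)};
  for (size_t e = 0; e < num_experts; ++e) {
    // Batched layout: each expert owns a fixed slab of padded_m rows, so the
    // offsets do not depend on the actual per-expert token counts.
    out.expert_offsets[e] = static_cast<int32_t>(e) * padded_m;
    const int32_t m = expert_num_tokens[e];
    if constexpr (!SWAP_AB) {
      // Assumed GEMM shapes: (m, 2n, k) for the first grouped GEMM and
      // (m, k, n) for the second.
      const int32_t p1[3] = {m, 2 * n, k};
      const int32_t p2[3] = {m, k, n};
      for (int i = 0; i < 3; ++i) {
        out.problem_sizes1[e * 3 + i] = p1[i];
        out.problem_sizes2[e * 3 + i] = p2[i];
      }
    } else {
      // SWAP_AB exchanges the M and N dimensions; the caller above selects
      // this variant for small batches via SWAP_AB_THRESHOLD.
      const int32_t p1[3] = {2 * n, m, k};
      const int32_t p2[3] = {k, m, n};
      for (int i = 0; i < 3; ++i) {
        out.problem_sizes1[e * 3 + i] = p1[i];
        out.problem_sizes2[e * 3 + i] = p2[i];
      }
    }
  }
  return out;
}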


@@ -82,13 +82,11 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
     const int64_t n, const int64_t k, const bool swap_ab);
 
-void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
-                                         torch::Tensor& problem_sizes1,
-                                         torch::Tensor& problem_sizes2,
-                                         const torch::Tensor& expert_num_tokens,
-                                         const int64_t num_local_experts,
-                                         const int64_t padded_m,
-                                         const int64_t n, const int64_t k);
+void get_cutlass_batched_moe_mm_data_caller(
+    torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
+    const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
+    const int64_t k);
 #endif
 
 void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@@ -319,29 +317,30 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k) {
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled get_cutlass_batched_moe_mm_data: no "
"cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num,
". Required capability: 90, 100, or 120");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
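
A hypothetical call site for the renamed entry point, matching the signature shown above. The output tensor shapes (one offset per expert, one (m, n, k) triple per expert per GEMM) are an assumption for illustration and are not taken from this diff; the declaration of get_cutlass_batched_moe_mm_data is assumed to be in scope.

// Illustrative only: allocates the (assumed) output buffers and calls the
// renamed dispatch function.
#include <torch/torch.h>

void prepare_batched_moe_mm_data(const torch::Tensor& expert_num_tokens,
                                 int64_t padded_m, int64_t n, int64_t k) {
  const int64_t num_local_experts = expert_num_tokens.size(0);
  auto opts = expert_num_tokens.options().dtype(torch::kInt32);

  // One offset per local expert; one (m, n, k) triple per expert per GEMM.
  torch::Tensor expert_offsets = torch::empty({num_local_experts}, opts);
  torch::Tensor problem_sizes1 = torch::empty({num_local_experts, 3}, opts);
  torch::Tensor problem_sizes2 = torch::empty({num_local_experts, 3}, opts);

  get_cutlass_batched_moe_mm_data(expert_offsets, problem_sizes1,
                                  problem_sizes2, expert_num_tokens,
                                  num_local_experts, padded_m, n, k);
}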