[Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (#23045)

Signed-off-by: Shixian Cui <shixian@amazon.com>
This commit is contained in:
shixianc
2025-08-20 07:35:26 -07:00
committed by GitHub
parent 4449235843
commit b17109beea
15 changed files with 369 additions and 121 deletions

View File

@@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
}
}
namespace {
inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& atomic_buffer,
int64_t num_experts, int64_t n,
int64_t k, cudaStream_t stream,
const bool swap_ab) {
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
if (swap_ab) {
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
static_cast<int>(k));
} else {
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
static_cast<int>(k));
}
}
} // namespace
void get_cutlass_moe_mm_problem_sizes_caller(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
// Swap-AB should be disabled for FP4 path
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
atomic_buffer, num_experts, n, k, stream,
may_swap_ab);
}
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
if (may_swap_ab) {
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
k);
} else {
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
k);
}
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
atomic_buffer, num_experts, n, k, stream,
may_swap_ab);
if (blockscale_offsets.has_value()) {
// fp4 path