Files
DeepGEMM/csrc/jit_kernels/heuristics/mega_moe.hpp
Zhean Xu 891d57b4db Add various optimizations and Mega MoE benchmarks (#316)
* Merge with private repo

* Add Mega MoE Benchmark

* Minor fix

* Update

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2026-04-24 18:41:37 +08:00

241 lines
11 KiB
C++

#pragma once
#include <algorithm>
#include <unordered_set>
#include <deep_gemm/layout/mega_moe.cuh>
#include "../../utils/exception.hpp"
#include "../../utils/math.hpp"
#include "../../utils/system.hpp"
#include "sm100.hpp"
namespace deep_gemm {
// Bundle of launch parameters chosen by the Mega MoE heuristics.
// NOTES: this is an aggregate that callers may brace-initialize positionally, so the member order must stay fixed
struct MegaMoEConfig {
    // Block tiling
    int block_m;
    int block_n;
    int block_k;
    int load_block_m;
    int load_block_n;
    int store_block_m;

    // SF block sizes (UTCCP 128-aligned)
    int sf_block_m;
    int sf_block_n;

    // Pool capacity and SF-padded token count
    int num_max_pool_tokens;
    int num_padded_sf_pool_tokens;

    // Swizzle modes for TMA descriptors
    int swizzle_acts_mode;
    int swizzle_weights_mode;

    // Number of experts to process per wave
    int num_experts_per_wave;

    // Pipeline stages and shared memory
    int num_stages;
    int smem_size;

    // Thread layout
    int num_dispatch_threads;
    int num_non_epilogue_threads;
    int num_epilogue_threads;

    // Print as `MegaMoEConfig(name=value, ...)`, listing the members in declaration order
    friend std::ostream& operator << (std::ostream& os, const MegaMoEConfig& config) {
        // Every field after the first one is preceded by ", "
        const char* separator = "";
        const auto write_field = [&](const char* name, const int& value) {
            os << separator << name << "=" << value;
            separator = ", ";
        };

        os << "MegaMoEConfig(";
        write_field("block_m", config.block_m);
        write_field("block_n", config.block_n);
        write_field("block_k", config.block_k);
        write_field("load_block_m", config.load_block_m);
        write_field("load_block_n", config.load_block_n);
        write_field("store_block_m", config.store_block_m);
        write_field("sf_block_m", config.sf_block_m);
        write_field("sf_block_n", config.sf_block_n);
        write_field("num_max_pool_tokens", config.num_max_pool_tokens);
        write_field("num_padded_sf_pool_tokens", config.num_padded_sf_pool_tokens);
        write_field("swizzle_acts_mode", config.swizzle_acts_mode);
        write_field("swizzle_weights_mode", config.swizzle_weights_mode);
        write_field("num_experts_per_wave", config.num_experts_per_wave);
        write_field("num_stages", config.num_stages);
        write_field("smem_size", config.smem_size);
        write_field("num_dispatch_threads", config.num_dispatch_threads);
        write_field("num_non_epilogue_threads", config.num_non_epilogue_threads);
        write_field("num_epilogue_threads", config.num_epilogue_threads);
        os << ")";
        return os;
    }
};
// Pick the block-level tiling from the expected number of tokens each expert receives.
// Returns `(cluster_size, block_m, store_block_m, num_epilogue_warpgroups * 128)`.
// NOTES: `num_max_tokens_per_rank` is accepted but not read by this heuristic
static std::tuple<int, int, int, int> get_block_config_for_mega_moe(
    const int& num_ranks, const int& num_experts,
    const int& num_max_tokens_per_rank, const int& num_topk,
    const int& num_tokens) {
    struct Tiling {
        int block_m, store_block_m, num_epilogue_warpgroups;
    };
    // A tier applies while the expected tokens per expert are at or below its bound
    struct Tier {
        double max_tokens_per_expert;
        Tiling tiling;
    };
    // Every tier uses the same cluster size
    constexpr int kClusterSize = 2;
    // Ordered by ascending bound, the first matching tier wins
    constexpr Tier kTiers[] = {
        // Really small token-per-expert (e.g. RL long-tail rollout), use the smallest block_m
        {8.5, {16, 8, 2}},
        // Small batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 128
        {16.5, {32, 16, 2}},
        // Medium batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 256
        {32.5, {64, 32, 1}},
        // Large batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 512
        {64.5, {96, 16, 2}},
        // Medium batch size, medium EP, decoding, e.g. 6/384 experts, EP16, bsz 256, or EP32, bsz 128
        {96.5, {128, 32, 2}},
    };

    // Expected tokens per expert, assuming the `num_tokens * num_ranks * num_topk` routed tokens spread evenly
    const float tokens_per_expert = static_cast<float>(num_tokens) * num_ranks * num_topk / num_experts;

    // Prefill, or large EP decoding, unless one of the tiers above matches
    Tiling chosen = {192, 32, 2};
    for (const auto& tier : kTiers) {
        if (tokens_per_expert <= tier.max_tokens_per_expert) {
            chosen = tier.tiling;
            break;
        }
    }
    const int block_m = chosen.block_m;

    // Check whether our `block_m` lies in `kCandidateBlockM`
    DG_HOST_ASSERT(std::any_of(
        layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs,
        [=](const auto& candidate) { return candidate == block_m; })
    );

    // The last entry scales the epilogue warpgroup count by 128
    return {kClusterSize, block_m, chosen.store_block_m, chosen.num_epilogue_warpgroups * 128};
}
// Decide how many of this rank's experts are processed together in one wave.
// The result is `num_experts_per_rank` itself or, whenever the SM-based estimate is positive,
// a divisor of it, so that every wave handles the same number of experts.
static int get_num_experts_per_wave_for_mega_moe(
    const int& num_experts_per_rank, const int& num_tokens, const int& num_topk,
    const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) {
    // Uneven routing leaves some experts with fewer tokens, so ask for this many times more blocks than SMs
    constexpr int kImbalanceFactor = 2;

    const float avg_tokens_per_expert = static_cast<float>(num_tokens) * num_topk / num_experts_per_rank;

    // Most experts don't have tokens, calculate all experts at once
    if (avg_tokens_per_expert < 1)
        return num_experts_per_rank;

    // L1 blocks of a single expert, assuming tokens are evenly spread across experts
    const int m_blocks = ceil_div(static_cast<int>(std::ceil(avg_tokens_per_expert)), block_m);
    const int n_blocks = (2 * intermediate_hidden) / block_n;
    const int blocks_per_expert = m_blocks * n_blocks;

    // Smallest expert count whose blocks (with the imbalance margin) can keep all SMs busy
    int experts_per_wave = 1;
    if (blocks_per_expert > 0)
        experts_per_wave = ceil_div(kImbalanceFactor * num_sms, blocks_per_expert);
    experts_per_wave = std::min(experts_per_wave, num_experts_per_rank);

    // Walk upwards until the count divides `num_experts_per_rank` (or reaches it)
    for (; experts_per_wave < num_experts_per_rank; ++ experts_per_wave) {
        if (num_experts_per_rank % experts_per_wave == 0)
            break;
    }
    return experts_per_wave;
}
// Fit as many pipeline stages as possible into `smem_capacity`.
// Returns `(num_stages, smem_size)`, where `smem_size` is the stage-independent part plus
// `num_stages` copies of the per-stage part. Asserts that at least 2 stages fit.
static std::pair<int, int> get_pipeline_config_for_mega_moe(
    const int& smem_capacity,
    const int& num_experts, const int& hidden,
    const int& block_m, const int& block_n, const int& block_k, const int& store_block_m,
    const int& sf_block_m, const int& sf_block_n,
    const int& num_dispatch_warps, const int& num_epilogue_warps) {
    constexpr int kSmemAlignment = 1024;
    constexpr int kNumEpilogueStages = 2;
    constexpr int kNumTMAStoreStages = 2;
    // Size used for every barrier below
    constexpr int kNumBytesPerBarrier = 8;

    // Stage-independent bytes, accumulated region by region
    int smem_fixed = 0;

    // Dispatch region: expert counters, then the send buffers, each padded to `kSmemAlignment`
    smem_fixed += align(
        num_experts * static_cast<int>(sizeof(uint32_t)), kSmemAlignment);
    smem_fixed += align(
        static_cast<int>(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()),
        kSmemAlignment);

    // C/D output region: max of L1 FP8 (2 TMA stages, BLOCK_N/2 post-SwiGLU) and L2 BF16 (1 stage)
    const auto num_epilogue_warpgroups = num_epilogue_warps / 4;
    const int l1_output_bytes = num_epilogue_warpgroups * store_block_m * (block_n / 2) * kNumTMAStoreStages;
    const int l2_output_bytes = num_epilogue_warpgroups * store_block_m * block_n * static_cast<int>(sizeof(nv_bfloat16));
    smem_fixed += std::max(l1_output_bytes, l2_output_bytes);

    // Amax reduction
    smem_fixed += store_block_m * num_epilogue_warps * static_cast<int>(sizeof(float));

    // Barriers (stage-independent): dispatch + tensor memory full/empty + combine (2 per epilogue warp)
    const int num_tmem_barriers = kNumEpilogueStages * 2;
    const int num_combine_barriers = num_epilogue_warps * 2;
    smem_fixed += (num_dispatch_warps + num_tmem_barriers + num_combine_barriers) * kNumBytesPerBarrier;

    // Tensor memory pointer
    smem_fixed += 4;

    // Always multicast on A, so each stage only holds half of the `block_m` rows
    const int load_block_m = block_m / 2;
    // Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty barriers
    // NOTES: SF is aligned to UTCCP 128-element granularity
    const int a_tile_bytes = load_block_m * block_k;
    const int b_tile_bytes = block_n * block_k;
    const int sfa_tile_bytes = sf_block_m * 4;
    const int sfb_tile_bytes = sf_block_n * 4;
    const int smem_per_stage = a_tile_bytes + b_tile_bytes + sfa_tile_bytes + sfb_tile_bytes + 2 * kNumBytesPerBarrier;

    // Select maximum num_stages
    const int num_stages = (smem_capacity - smem_fixed) / smem_per_stage;
    DG_HOST_ASSERT(num_stages >= 2);
    return {num_stages, smem_fixed + num_stages * smem_per_stage};
}
// Assemble the complete `MegaMoEConfig` for one problem shape by combining the block, wave and
// pipeline heuristics. With `DG_JIT_DEBUG` or `DG_PRINT_CONFIGS` set, each distinct problem shape
// is printed to stdout the first time it is seen.
static MegaMoEConfig get_mega_moe_config(
    const int& num_ranks, const int& num_experts, const int& num_experts_per_rank,
    const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk,
    const int& hidden, const int& intermediate_hidden,
    const int& num_padded_sf_pool_tokens) {
    MegaMoEConfig config{};
    config.num_padded_sf_pool_tokens = num_padded_sf_pool_tokens;

    // Block config (`block_n` and `block_k` are fixed to 128)
    const auto [cluster_size, block_m, store_block_m, num_epilogue_threads] =
        get_block_config_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens);
    config.block_m = block_m;
    config.block_n = 128;
    config.block_k = 128;
    // Each load covers half of the M block and the whole N block
    config.load_block_m = block_m / 2;
    config.load_block_n = config.block_n;
    config.store_block_m = store_block_m;

    // SF block sizes
    const auto [sf_block_m, sf_block_n] = SM100ArchSpec::get_sf_uttcp_aligned_block_sizes(
        block_m, config.block_n, MmaKind::MXFP8FP4);
    config.sf_block_m = sf_block_m;
    config.sf_block_n = sf_block_n;

    // Pool capacity
    config.num_max_pool_tokens = layout::get_num_max_pool_tokens(
        num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);

    // NOTES: FP8 activations and FP4 weights (unpacked to 8-bit in smem) both use 128B swizzle
    config.swizzle_acts_mode = 128;
    config.swizzle_weights_mode = 128;

    // Waves
    const int num_sms = device_runtime->get_num_sms();
    config.num_experts_per_wave = get_num_experts_per_wave_for_mega_moe(
        num_experts_per_rank, num_tokens, num_topk,
        intermediate_hidden, block_m, config.block_n, num_sms);

    // Thread layout
    config.num_dispatch_threads = 128;
    config.num_non_epilogue_threads = 128;
    config.num_epilogue_threads = num_epilogue_threads;

    // Pipeline, the helper takes warp counts, hence the division by 32
    const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe(
        SM100ArchSpec::smem_capacity,
        num_experts, hidden,
        block_m, config.block_n, config.block_k, store_block_m,
        sf_block_m, sf_block_n,
        config.num_dispatch_threads / 32, num_epilogue_threads / 32);
    config.num_stages = num_stages;
    config.smem_size = smem_size;

    // Nothing to print unless one of the debug switches is on
    const bool print_requested = get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS");
    if (not print_requested)
        return config;

    // Print configs for the first time
    const auto key = fmt::format(
        "MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})",
        num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk);
    static std::unordered_set<std::string> printed;
    if (printed.find(key) == printed.end()) {
        std::cout << key << ": " << config << std::endl;
        printed.insert(key);
    }
    return config;
}
} // namespace deep_gemm