* Merge with private repo * Add Mega MoE Benchmark * Minor fix * Update --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
241 lines
11 KiB
C++
241 lines
11 KiB
C++
#pragma once
|
|
|
|
#include <algorithm>
|
|
#include <unordered_set>
|
|
|
|
#include <deep_gemm/layout/mega_moe.cuh>
|
|
|
|
#include "../../utils/exception.hpp"
|
|
#include "../../utils/math.hpp"
|
|
#include "../../utils/system.hpp"
|
|
#include "sm100.hpp"
|
|
|
|
namespace deep_gemm {
|
|
|
|
// Launch configuration for the Mega MoE kernel, as produced by `get_mega_moe_config`.
//
// NOTES: `get_mega_moe_config` fills this struct by aggregate initialization in
// declaration order, so fields must not be reordered. Every field defaults to 0 so a
// default-constructed config never carries indeterminate values.
struct MegaMoEConfig {
    // Block tiling
    int block_m = 0, block_n = 0, block_k = 0;
    int load_block_m = 0, load_block_n = 0;
    int store_block_m = 0;

    // SF block sizes (UTCCP 128-aligned)
    int sf_block_m = 0, sf_block_n = 0;

    // Pool capacity and SF-padded token count
    int num_max_pool_tokens = 0;
    int num_padded_sf_pool_tokens = 0;

    // Swizzle modes for TMA descriptors
    int swizzle_acts_mode = 0, swizzle_weights_mode = 0;

    // Number of experts to process per wave
    int num_experts_per_wave = 0;

    // Pipeline stages and shared memory
    int num_stages = 0, smem_size = 0;

    // Thread layout
    int num_dispatch_threads = 0, num_non_epilogue_threads = 0, num_epilogue_threads = 0;

    // Human-readable dump of every field, used for the `DG_PRINT_CONFIGS` / `DG_JIT_DEBUG` output
    friend std::ostream& operator << (std::ostream& os, const MegaMoEConfig& config) {
        os << "MegaMoEConfig("
           << "block_m=" << config.block_m << ", block_n=" << config.block_n << ", block_k=" << config.block_k
           << ", load_block_m=" << config.load_block_m << ", load_block_n=" << config.load_block_n
           << ", store_block_m=" << config.store_block_m
           << ", sf_block_m=" << config.sf_block_m << ", sf_block_n=" << config.sf_block_n
           << ", num_max_pool_tokens=" << config.num_max_pool_tokens
           << ", num_padded_sf_pool_tokens=" << config.num_padded_sf_pool_tokens
           << ", swizzle_acts_mode=" << config.swizzle_acts_mode << ", swizzle_weights_mode=" << config.swizzle_weights_mode
           << ", num_experts_per_wave=" << config.num_experts_per_wave
           << ", num_stages=" << config.num_stages << ", smem_size=" << config.smem_size
           << ", num_dispatch_threads=" << config.num_dispatch_threads
           << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads
           << ", num_epilogue_threads=" << config.num_epilogue_threads << ")";
        return os;
    }
};
|
|
|
|
static std::tuple<int, int, int, int> get_block_config_for_mega_moe(
|
|
const int& num_ranks, const int& num_experts,
|
|
const int& num_max_tokens_per_rank, const int& num_topk,
|
|
const int& num_tokens) {
|
|
const auto& [cluster_size, block_m, store_block_m, num_epilogue_warpgroups] = [&]() -> std::tuple<int, int, int, int> {
|
|
float num_expected_tokens_per_expert = static_cast<float>(num_tokens) * num_ranks * num_topk / num_experts;
|
|
if (num_expected_tokens_per_expert <= 8.5) {
|
|
// Really small token-per-expert (e.g. RL long-tail rollout), use the smallest block_m
|
|
return {2, 16, 8, 2};
|
|
} else if (num_expected_tokens_per_expert <= 16.5) {
|
|
// Small batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 128
|
|
return {2, 32, 16, 2};
|
|
} else if (num_expected_tokens_per_expert <= 32.5) {
|
|
// Medium batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 256
|
|
return {2, 64, 32, 1};
|
|
} else if (num_expected_tokens_per_expert <= 64.5) {
|
|
// Large batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 512
|
|
return {2, 96, 16, 2};
|
|
} else if (num_expected_tokens_per_expert <= 96.5) {
|
|
// Medium batch size, Medium EP, decoding, e.g. 6/384 experts, EP16, bsz 256, or EP32, bsz128
|
|
return {2, 128, 32, 2};
|
|
} else {
|
|
// Prefill, or large EP decoding
|
|
return {2, 192, 32, 2};
|
|
}
|
|
}();
|
|
|
|
// Check whether our `block_m` lies in `kCandidateBlockM`
|
|
DG_HOST_ASSERT(std::any_of(
|
|
layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs,
|
|
[=](const auto& candidate) { return candidate == block_m; })
|
|
);
|
|
|
|
// Return configs
|
|
return {cluster_size, block_m, store_block_m, num_epilogue_warpgroups * 128};
|
|
}
|
|
|
|
static int get_num_experts_per_wave_for_mega_moe(
|
|
const int& num_experts_per_rank, const int& num_tokens, const int& num_topk,
|
|
const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) {
|
|
|
|
float expected_tokens_per_expert = static_cast<float>(num_tokens) * num_topk / num_experts_per_rank;
|
|
if (expected_tokens_per_expert < 1) {
|
|
// Most experts don't have tokens, calculate all experts at once
|
|
return num_experts_per_rank;
|
|
}
|
|
|
|
// Reduce per-expert block count by this factor since uneven routing leaves some experts with fewer tokens
|
|
constexpr int kImbalanceFactor = 2;
|
|
|
|
// Count L1 blocks per expert assuming tokens are evenly spread across experts
|
|
const int num_m_blocks = ceil_div(static_cast<int>(std::ceil(expected_tokens_per_expert)), block_m);
|
|
const int num_n_blocks = (2 * intermediate_hidden) / block_n;
|
|
const int num_l1_blocks_per_expert = num_m_blocks * num_n_blocks;
|
|
|
|
// Pick the smallest value whose total blocks (after imbalance reduction) can keep all SMs busy
|
|
int num_experts_per_wave = num_l1_blocks_per_expert > 0
|
|
? ceil_div(kImbalanceFactor * num_sms, num_l1_blocks_per_expert) : 1;
|
|
num_experts_per_wave = std::min(num_experts_per_wave, num_experts_per_rank);
|
|
|
|
// Round up to the nearest divisor of num_experts_per_rank so every wave processes the same count
|
|
while (num_experts_per_wave < num_experts_per_rank and num_experts_per_rank % num_experts_per_wave != 0)
|
|
++ num_experts_per_wave;
|
|
|
|
return num_experts_per_wave;
|
|
}
|
|
|
|
static std::pair<int, int> get_pipeline_config_for_mega_moe(
|
|
const int& smem_capacity,
|
|
const int& num_experts, const int& hidden,
|
|
const int& block_m, const int& block_n, const int& block_k, const int& store_block_m,
|
|
const int& sf_block_m, const int& sf_block_n,
|
|
const int& num_dispatch_warps, const int& num_epilogue_warps) {
|
|
constexpr int kSmemAlignment = 1024;
|
|
constexpr int kNumEpilogueStages = 2;
|
|
constexpr int kNumTMAStoreStages = 2;
|
|
|
|
// Always multicast on A
|
|
const int load_block_m = block_m / 2;
|
|
|
|
// Dispatch region
|
|
const int smem_expert_count_size = align(
|
|
num_experts * static_cast<int>(sizeof(uint32_t)), kSmemAlignment);
|
|
const int smem_send_buffers_size = align(
|
|
static_cast<int>(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()),
|
|
kSmemAlignment);
|
|
const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size;
|
|
|
|
// C/D output region: max of L1 FP8 (2 TMA stages, BLOCK_N/2 post-SwiGLU) and L2 BF16 (1 stage)
|
|
const auto num_epilogue_warpgroups = num_epilogue_warps / 4;
|
|
const int smem_cd_l1 = num_epilogue_warpgroups * store_block_m * (block_n / 2) * kNumTMAStoreStages;
|
|
const int smem_cd_l2 = num_epilogue_warpgroups * store_block_m * block_n * static_cast<int>(sizeof(nv_bfloat16));
|
|
const int smem_cd = std::max(smem_cd_l1, smem_cd_l2);
|
|
|
|
// Barriers (stage-independent): dispatch + tensor memory full/empty + combine (2 per epilogue warp)
|
|
const int smem_barriers = (num_dispatch_warps + kNumEpilogueStages * 2 + num_epilogue_warps * 2) * 8;
|
|
|
|
// Amax reduction
|
|
const int smem_amax_reduction = store_block_m * num_epilogue_warps * static_cast<int>(sizeof(float));
|
|
|
|
// Tensor memory pointer
|
|
const int smem_tmem_ptr = 4;
|
|
|
|
// SF is aligned to UTCCP 128-element granularity
|
|
const int smem_sfa_per_stage = sf_block_m * 4;
|
|
const int smem_sfb_per_stage = sf_block_n * 4;
|
|
|
|
// Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty barriers
|
|
const int smem_per_stage = load_block_m * block_k + block_n * block_k + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8;
|
|
|
|
// Fixed total
|
|
const int smem_fixed = smem_dispatch_size + smem_cd + smem_amax_reduction + smem_barriers + smem_tmem_ptr;
|
|
|
|
// Select maximum num_stages
|
|
const int num_stages = (smem_capacity - smem_fixed) / smem_per_stage;
|
|
DG_HOST_ASSERT(num_stages >= 2);
|
|
|
|
return {num_stages, smem_fixed + num_stages * smem_per_stage};
|
|
}
|
|
|
|
// Build the complete launch configuration for the Mega MoE kernel.
//
// Combines:
// - the block tiling preset (`get_block_config_for_mega_moe`),
// - the SF block sizes,
// - the token pool capacity,
// - the number of experts per wave (`get_num_experts_per_wave_for_mega_moe`),
// - the shared-memory pipeline depth (`get_pipeline_config_for_mega_moe`).
// When `DG_JIT_DEBUG` or `DG_PRINT_CONFIGS` is set, each distinct problem shape's config
// is printed once to stdout.
static MegaMoEConfig get_mega_moe_config(
        const int& num_ranks, const int& num_experts, const int& num_experts_per_rank,
        const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk,
        const int& hidden, const int& intermediate_hidden,
        const int& num_padded_sf_pool_tokens) {
    // Block config
    const auto [cluster_size, block_m, store_block_m, num_epilogue_threads] =
        get_block_config_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens);
    const int block_n = 128;
    const int block_k = 128;

    // A is always multicast across the cluster, so the A load tile is `block_m / cluster_size`.
    // `get_pipeline_config_for_mega_moe` sizes shared memory with a hard-coded `block_m / 2`,
    // so reject any other cluster size instead of silently mis-sizing shared memory
    DG_HOST_ASSERT(cluster_size == 2 and block_m % cluster_size == 0);
    const int load_block_m = block_m / cluster_size;
    const int load_block_n = block_n;

    const auto [sf_block_m, sf_block_n] = SM100ArchSpec::get_sf_uttcp_aligned_block_sizes(block_m, block_n, MmaKind::MXFP8FP4);
    const int num_max_pool_tokens = layout::get_num_max_pool_tokens(
        num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);

    // NOTES: FP8 activations and FP4 weights (unpacked to 8-bit in smem) both use 128B swizzle
    const int swizzle_acts_mode = 128;
    const int swizzle_weights_mode = 128;

    // Waves
    const int num_sms = device_runtime->get_num_sms();
    const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe(
        num_experts_per_rank, num_tokens, num_topk,
        intermediate_hidden, block_m, block_n, num_sms);

    // Thread layout
    const int num_dispatch_threads = 128;
    const int num_non_epilogue_threads = 128;

    // Pipeline
    const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe(
        SM100ArchSpec::smem_capacity,
        num_experts, hidden,
        block_m, block_n, block_k, store_block_m,
        sf_block_m, sf_block_n,
        num_dispatch_threads / 32, num_epilogue_threads / 32);

    // NOTES: aggregate initialization, the order must match the field order of `MegaMoEConfig`
    const auto config = MegaMoEConfig {
        block_m, block_n, block_k,
        load_block_m, load_block_n, store_block_m,
        sf_block_m, sf_block_n,
        num_max_pool_tokens, num_padded_sf_pool_tokens,
        swizzle_acts_mode, swizzle_weights_mode,
        num_experts_per_wave,
        num_stages, smem_size,
        num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads
    };

    // Print configs for the first time
    if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
        const auto key = fmt::format(
            "MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})",
            num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk);
        // NOTES: function-local set of already-printed keys; it is not guarded by a lock
        static std::unordered_set<std::string> printed;
        if (printed.count(key) == 0) {
            std::cout << key << ": " << config << std::endl;
            printed.insert(key);
        }
    }
    return config;
}
|
|
|
|
} // namespace deep_gemm
|