Files
DeepGEMM/deep_gemm/include/deep_gemm/layout/mega_moe.cuh
Zhean Xu 891d57b4db Add various optimizations and Mega MoE benchmarks (#316)
* Merge with private repo

* Add Mega MoE Benchmark

* Minor fix

* Update

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2026-04-24 18:41:37 +08:00

261 lines
9.1 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#pragma once
#include <cute/numeric/math.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::layout {
// Candidate BLOCK_M sizes, listed in strictly ascending order
static constexpr int kNumCandidateBlockMs = 7;
static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192};
// Largest and smallest entries of `kCandidateBlockM`
static constexpr int kMaxCandidateBlockM = 192;
static constexpr int kMinCandidateBlockM = 8;
// Least common multiple of all entries of `kCandidateBlockM`
static constexpr int kLCMCandidateBlockM = 384;

// Compile-time helper: whether `kCandidateBlockM` is strictly ascending
// (only evaluated inside the `static_assert`s below, never at runtime)
constexpr bool is_candidate_block_m_ascending() {
    for (int i = 1; i < kNumCandidateBlockMs; ++ i) {
        if (kCandidateBlockM[i - 1] >= kCandidateBlockM[i])
            return false;
    }
    return true;
}

// Compile-time helper: least common multiple of all entries of `kCandidateBlockM`
// (only evaluated inside the `static_assert`s below, never at runtime)
constexpr int compute_candidate_block_m_lcm() {
    int lcm = 1;
    for (int i = 0; i < kNumCandidateBlockMs; ++ i) {
        // Euclid's algorithm: `a` ends up as gcd(lcm, kCandidateBlockM[i])
        int a = lcm, b = kCandidateBlockM[i];
        while (b != 0) {
            const int remainder = a % b;
            a = b;
            b = remainder;
        }
        lcm = lcm / a * kCandidateBlockM[i];
    }
    return lcm;
}

// The derived constants above are written by hand; fail the build if they drift from the candidate list
static_assert(is_candidate_block_m_ascending(),
              "`kCandidateBlockM` must be strictly ascending");
static_assert(kCandidateBlockM[0] == kMinCandidateBlockM,
              "`kMinCandidateBlockM` must equal the smallest candidate");
static_assert(kCandidateBlockM[kNumCandidateBlockMs - 1] == kMaxCandidateBlockM,
              "`kMaxCandidateBlockM` must equal the largest candidate");
static_assert(compute_candidate_block_m_lcm() == kLCMCandidateBlockM,
              "`kLCMCandidateBlockM` must be the least common multiple of all candidates");
// Capacity (in tokens) of the token pool shared by all local experts.
// It is the sum of two worst-case terms:
//   - received tokens: `num_ranks * num_max_tokens_per_rank` tokens, each counted once per expert it can
//     reach on this rank, i.e. `min(num_topk, num_experts_per_rank)` times
//   - BLOCK_M alignment padding: `kMaxCandidateBlockM - 1` tokens for every local expert
// The sum is then passed through `math::constexpr_align` with `kLCMCandidateBlockM`
// (presumably rounding up, so the capacity is divisible by every candidate BLOCK_M — confirm in `math.cuh`).
// NOTES: all arithmetic is done in `T`, no overflow check is performed here
template <typename T>
CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
                                                        T num_experts_per_rank) {
    // Worst-case number of token copies landing on this rank
    const auto num_max_copies_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
    const auto num_max_routed_tokens = (num_ranks * num_max_tokens_per_rank) * num_max_copies_per_token;

    // Worst-case padding when every expert's token range is aligned to the largest BLOCK_M
    const auto num_max_padding_tokens = num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1);

    return math::constexpr_align(num_max_routed_tokens + num_max_padding_tokens,
                                 static_cast<T>(kLCMCandidateBlockM));
}
// SF pool capacity: all experts share one contiguous SF region whose size is
// (number of pool blocks) x (per-block SF token count), where
//   - number of pool blocks   = `num_max_pool_tokens / block_m` (integer division)
//   - per-block SF token count = `math::constexpr_align(block_m, 128)`
// NOTES: assumes `block_m` is non-zero and divides `num_max_pool_tokens`; a trailing partial block
// would be dropped by the integer division — TODO confirm callers only pass candidate BLOCK_M values
template <typename T>
CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) {
    const auto num_pool_blocks = num_max_pool_tokens / block_m;
    const auto num_sf_tokens_per_block = math::constexpr_align(block_m, static_cast<T>(128));
    return num_pool_blocks * num_sf_tokens_per_block;
}
// Source metadata recorded per pool token, used by the combine write-back.
// Three consecutive `uint32_t` fields; judging by the names they identify the originating rank,
// the token index and the top-k slot of the token (exact semantics are defined by the kernels using it)
struct TokenSrcMetadata {
    uint32_t rank_idx, token_idx, topk_idx;
};
struct Workspace {
void* base;
uint32_t num_ranks, num_experts;
uint32_t num_experts_per_rank;
uint32_t num_max_tokens_per_rank;
uint32_t num_max_recv_tokens_per_expert;
// Pool capacity: all local experts share a contiguous token pool
uint32_t num_max_pool_tokens;
uint32_t num_max_pool_blocks;
// For both grid barrier and NVLink barrier
static constexpr uint64_t kNumBarrierSignalBytes = 32;
CUTLASS_HOST_DEVICE
Workspace(void* base,
const uint32_t& num_ranks,
const uint32_t& num_experts,
const uint32_t& num_max_tokens_per_rank,
const uint32_t& num_topk):
base(base),
num_ranks(num_ranks), num_experts(num_experts),
num_max_tokens_per_rank(num_max_tokens_per_rank) {
num_experts_per_rank = num_experts / num_ranks;
num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM;
}
CUTLASS_HOST_DEVICE
uint64_t get_num_bytes() const {
uint64_t num_bytes = 0;
// Barrier
num_bytes += kNumBarrierSignalBytes;
// Expert send/recv count
num_bytes += num_experts * sizeof(uint64_t) * 2;
// Expert recv count sum
num_bytes += num_experts_per_rank * sizeof(uint64_t);
// L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask)
num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t);
// L2 block arrival mask
num_bytes += num_max_pool_blocks * sizeof(uint64_t);
// Dispatch pulling source token-topk
num_bytes += num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert * sizeof(int);
// Combine push source indices
num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata);
// Align to TMA descriptor requirements
num_bytes = math::align<uint64_t>(num_bytes, 16);
return num_bytes;
}
CUTLASS_HOST_DEVICE
void* get_end_ptr() const {
return math::advance_ptr(base, get_num_bytes());
}
// Grid sync counters: `kNumBarrierSignalBytes` layout
// [ 0..15]: 4 x `uint32_t` grid sync counters
// [16..20]: `uint32_t` NVLink barrier counter
// [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1)
static constexpr uint32_t kNumMaxGridSyncCounters = 4;
template <uint32_t kIndex = 0>
CUTLASS_DEVICE
uint32_t* get_grid_sync_count_ptr() const {
DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds");
return static_cast<uint32_t*>(base) + kIndex;
}
CUTLASS_DEVICE
uint32_t* get_nvl_barrier_counter_ptr() const {
return static_cast<uint32_t*>(base) + kNumMaxGridSyncCounters;
}
CUTLASS_DEVICE
int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const {
// NOTES: the signal is signed, as we may minus
return math::advance_ptr<int>(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int));
}
CUTLASS_DEVICE
uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const {
return math::advance_ptr<uint64_t>(base, kNumBarrierSignalBytes) + expert_idx;
}
CUTLASS_DEVICE
uint64_t* get_expert_recv_count_ptr(
const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const {
return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx;
}
CUTLASS_DEVICE
uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const {
return get_expert_send_count_ptr(num_experts * 2) + expert_idx;
}
CUTLASS_DEVICE
uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const {
const auto base = get_expert_recv_count_sum_ptr(num_experts_per_rank);
return reinterpret_cast<uint32_t*>(base) + pool_block_idx;
}
CUTLASS_DEVICE
uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const {
// Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned
const auto base = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u));
return reinterpret_cast<uint64_t*>(base) + pool_block_idx;
}
// For dispatch pulling
CUTLASS_DEVICE
uint32_t* get_src_token_topk_idx_ptr(
const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const {
const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks);
return reinterpret_cast<uint32_t*>(base) +
expert_idx * (num_ranks * num_max_recv_tokens_per_expert) +
rank_idx * num_max_recv_tokens_per_expert + token_idx;
}
// For combine usages
CUTLASS_DEVICE
TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const {
const auto base = reinterpret_cast<TokenSrcMetadata*>(get_src_token_topk_idx_ptr(num_experts_per_rank));
return base + pool_token_idx;
}
};
// Descriptor of a single data slot: its size in bytes, whether that size is subject to
// the 16-byte check in the constructor, and an optional base address
struct Data {
    uint32_t num_bytes;
    bool require_tma_alignment;
    void* base;

    // NOTES: asserts that `num_bytes` is a multiple of 16 whenever `require_tma_alignment` is set
    CUTLASS_HOST_DEVICE
    constexpr explicit Data(const uint32_t& num_bytes,
                            const bool& require_tma_alignment = true,
                            void* base = nullptr)
        : num_bytes(num_bytes),
          require_tma_alignment(require_tma_alignment),
          base(base) {
        DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment);
    }

    // Slot size in bytes, converted to `out_t`
    template <typename out_t = uint32_t>
    CUTLASS_HOST_DEVICE constexpr out_t get_num_bytes() const {
        return static_cast<out_t>(this->num_bytes);
    }

    // Base address reinterpreted as a pointer to `elem_t`
    template <typename elem_t = void>
    CUTLASS_HOST_DEVICE elem_t* get_base_ptr() const {
        return static_cast<elem_t*>(this->base);
    }

    // Rebind the slot to another address, the size is left untouched
    CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) {
        this->base = ptr;
    }
};
// A contiguous array of fixed-size token slots starting at `base`:
// `num_ranks` consecutive per-rank sections, each holding `num_max_tokens_per_rank` slots
// of `data_layout.num_bytes` bytes
struct Buffer {
    Data data_layout;
    uint32_t num_ranks;
    uint32_t num_max_tokens_per_rank;
    void* base;

    CUTLASS_HOST_DEVICE
    Buffer(const Data& data_layout,
           const uint32_t& num_ranks,
           const uint32_t& num_max_tokens_per_rank,
           void* base = nullptr) :
        data_layout(data_layout),
        num_ranks(num_ranks), num_max_tokens_per_rank(num_max_tokens_per_rank),
        base(base) {}

    // Size of one rank's section in bytes (computed in 64 bits)
    CUTLASS_HOST_DEVICE
    uint64_t get_num_bytes_per_rank() const {
        return num_max_tokens_per_rank * data_layout.get_num_bytes<uint64_t>();
    }

    // Size of the whole buffer in bytes, i.e. all `num_ranks` sections
    CUTLASS_HOST_DEVICE
    uint64_t get_num_bytes() const {
        return get_num_bytes_per_rank() * num_ranks;
    }

    // Base address reinterpreted as a pointer to `dtype_t`
    template <typename dtype_t = void>
    CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
        return static_cast<dtype_t*>(base);
    }

    // Address right after the buffer, i.e. `base` advanced by `get_num_bytes()`
    CUTLASS_HOST_DEVICE
    void* get_end_ptr() const {
        return math::advance_ptr(base, get_num_bytes());
    }

    // Single-rank view (`num_ranks == 1`) over the section of `rank_idx`
    // NOTES: `rank_idx` is not range-checked
    CUTLASS_HOST_DEVICE
    Buffer get_rank_buffer(const uint32_t& rank_idx) const {
        return {
            data_layout,
            1, num_max_tokens_per_rank,
            math::advance_ptr(base, get_num_bytes_per_rank() * rank_idx)
        };
    }

    // Slot descriptor of `token_idx`, counted in slots from `base`.
    // On a multi-rank buffer the caller must pass `global = true` to acknowledge that
    // `token_idx` is counted across rank sections; otherwise the assertion fails
    // NOTES: `token_idx` is not range-checked
    CUTLASS_HOST_DEVICE
    Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const {
        // This function is host/device, so use the same assertion macro as the
        // host/device constructor of `Data` rather than the device-only one
        DG_UNIFIED_ASSERT(num_ranks == 1 or global);
        return Data(
            data_layout.num_bytes,
            data_layout.require_tma_alignment,
            math::advance_ptr(base, data_layout.get_num_bytes<uint64_t>() * token_idx)
        );
    }
};
} // namespace deep_gemm::layout