* Merge with private repo * Add Mega MoE Benchmark * Minor fix * Update --------- Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
261 lines
9.1 KiB
Plaintext
261 lines
9.1 KiB
Plaintext
#pragma once
|
||
|
||
#include <cute/numeric/math.hpp>
|
||
|
||
#include <deep_gemm/common/math.cuh>
|
||
#include <deep_gemm/common/exception.cuh>
|
||
|
||
namespace deep_gemm::layout {
|
||
|
||
static constexpr int kNumCandidateBlockMs = 7;
|
||
static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192};
|
||
static constexpr int kMaxCandidateBlockM = 192;
|
||
static constexpr int kMinCandidateBlockM = 8;
|
||
static constexpr int kLCMCandidateBlockM = 384;
|
||
|
||
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M
|
||
template <typename T>
|
||
CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
|
||
T num_experts_per_rank) {
|
||
const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank;
|
||
const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
|
||
return math::constexpr_align(
|
||
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1),
|
||
static_cast<T>(kLCMCandidateBlockM));
|
||
}
|
||
|
||
// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M
|
||
template <typename T>
|
||
CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) {
|
||
return (num_max_pool_tokens / block_m) * math::constexpr_align(block_m, static_cast<T>(128));
|
||
}
|
||
|
||
// Per-token source metadata for combine write-back
|
||
struct TokenSrcMetadata {
|
||
uint32_t rank_idx;
|
||
uint32_t token_idx;
|
||
uint32_t topk_idx;
|
||
};
|
||
|
||
struct Workspace {
|
||
void* base;
|
||
uint32_t num_ranks, num_experts;
|
||
uint32_t num_experts_per_rank;
|
||
uint32_t num_max_tokens_per_rank;
|
||
uint32_t num_max_recv_tokens_per_expert;
|
||
|
||
// Pool capacity: all local experts share a contiguous token pool
|
||
uint32_t num_max_pool_tokens;
|
||
uint32_t num_max_pool_blocks;
|
||
|
||
// For both grid barrier and NVLink barrier
|
||
static constexpr uint64_t kNumBarrierSignalBytes = 32;
|
||
|
||
CUTLASS_HOST_DEVICE
|
||
Workspace(void* base,
|
||
const uint32_t& num_ranks,
|
||
const uint32_t& num_experts,
|
||
const uint32_t& num_max_tokens_per_rank,
|
||
const uint32_t& num_topk):
|
||
base(base),
|
||
num_ranks(num_ranks), num_experts(num_experts),
|
||
num_max_tokens_per_rank(num_max_tokens_per_rank) {
|
||
num_experts_per_rank = num_experts / num_ranks;
|
||
num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
|
||
num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
|
||
num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM;
|
||
}
|
||
|
||
CUTLASS_HOST_DEVICE
|
||
uint64_t get_num_bytes() const {
|
||
uint64_t num_bytes = 0;
|
||
|
||
// Barrier
|
||
num_bytes += kNumBarrierSignalBytes;
|
||
|
||
// Expert send/recv count
|
||
num_bytes += num_experts * sizeof(uint64_t) * 2;
|
||
|
||
// Expert recv count sum
|
||
num_bytes += num_experts_per_rank * sizeof(uint64_t);
|
||
|
||
// L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask)
|
||
num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t);
|
||
|
||
// L2 block arrival mask
|
||
num_bytes += num_max_pool_blocks * sizeof(uint64_t);
|
||
|
||
// Dispatch pulling source token-topk
|
||
num_bytes += num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert * sizeof(int);
|
||
|
||
// Combine push source indices
|
||
num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata);
|
||
|
||
// Align to TMA descriptor requirements
|
||
num_bytes = math::align<uint64_t>(num_bytes, 16);
|
||
return num_bytes;
|
||
}
|
||
|
||
CUTLASS_HOST_DEVICE
|
||
void* get_end_ptr() const {
|
||
return math::advance_ptr(base, get_num_bytes());
|
||
}
|
||
|
||
// Grid sync counters: `kNumBarrierSignalBytes` layout
|
||
// [ 0..15]: 4 x `uint32_t` grid sync counters
|
||
// [16..20]: `uint32_t` NVLink barrier counter
|
||
// [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1)
|
||
static constexpr uint32_t kNumMaxGridSyncCounters = 4;
|
||
|
||
template <uint32_t kIndex = 0>
|
||
CUTLASS_DEVICE
|
||
uint32_t* get_grid_sync_count_ptr() const {
|
||
DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds");
|
||
return static_cast<uint32_t*>(base) + kIndex;
|
||
}
|
||
|
||
CUTLASS_DEVICE
|
||
uint32_t* get_nvl_barrier_counter_ptr() const {
|
||
return static_cast<uint32_t*>(base) + kNumMaxGridSyncCounters;
|
||
}
|
||
|
||
CUTLASS_DEVICE
|
||
int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const {
|
||
// NOTES: the signal is signed, as we may minus
|
||
return math::advance_ptr<int>(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int));
|
||
}
|
||
|
||
CUTLASS_DEVICE
|
||
uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const {
|
||
return math::advance_ptr<uint64_t>(base, kNumBarrierSignalBytes) + expert_idx;
|
||
}
|
||
|
||
CUTLASS_DEVICE
|
||
uint64_t* get_expert_recv_count_ptr(
|
||
const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const {
|
||
return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx;
|
||
}
|
||
|
||
CUTLASS_DEVICE
|
||
uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const {
|
||
return get_expert_send_count_ptr(num_experts * 2) + expert_idx;
|
||
}
|
||
|
||
CUTLASS_DEVICE
|
||
uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const {
|
||
const auto base = get_expert_recv_count_sum_ptr(num_experts_per_rank);
|
||
return reinterpret_cast<uint32_t*>(base) + pool_block_idx;
|
||
}
|
||
|
||
CUTLASS_DEVICE
|
||
uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const {
|
||
// Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned
|
||
const auto base = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u));
|
||
return reinterpret_cast<uint64_t*>(base) + pool_block_idx;
|
||
}
|
||
|
||
// For dispatch pulling
|
||
CUTLASS_DEVICE
|
||
uint32_t* get_src_token_topk_idx_ptr(
|
||
const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const {
|
||
const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks);
|
||
return reinterpret_cast<uint32_t*>(base) +
|
||
expert_idx * (num_ranks * num_max_recv_tokens_per_expert) +
|
||
rank_idx * num_max_recv_tokens_per_expert + token_idx;
|
||
}
|
||
|
||
// For combine usages
|
||
CUTLASS_DEVICE
|
||
TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const {
|
||
const auto base = reinterpret_cast<TokenSrcMetadata*>(get_src_token_topk_idx_ptr(num_experts_per_rank));
|
||
return base + pool_token_idx;
|
||
}
|
||
};
|
||
|
||
struct Data {
|
||
uint32_t num_bytes;
|
||
bool require_tma_alignment;
|
||
void* base;
|
||
|
||
CUTLASS_HOST_DEVICE
|
||
constexpr explicit Data(
|
||
const uint32_t& num_bytes,
|
||
const bool& require_tma_alignment = true,
|
||
void* base = nullptr) :
|
||
num_bytes(num_bytes), require_tma_alignment(require_tma_alignment), base(base) {
|
||
DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment);
|
||
}
|
||
|
||
template <typename dtype_t = uint32_t>
|
||
CUTLASS_HOST_DEVICE constexpr dtype_t get_num_bytes() const {
|
||
return static_cast<dtype_t>(num_bytes);
|
||
}
|
||
|
||
template <typename dtype_t = void>
|
||
CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
|
||
return static_cast<dtype_t*>(base);
|
||
}
|
||
|
||
CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) {
|
||
base = ptr;
|
||
}
|
||
};
|
||
|
||
struct Buffer {
|
||
Data data_layout;
|
||
uint32_t num_ranks;
|
||
uint32_t num_max_tokens_per_rank;
|
||
|
||
void* base;
|
||
|
||
CUTLASS_HOST_DEVICE
|
||
Buffer(const Data& data_layout,
|
||
const uint32_t& num_ranks,
|
||
const uint32_t& max_num_tokens_per_rank,
|
||
void* base = nullptr) :
|
||
data_layout(data_layout),
|
||
num_ranks(num_ranks), num_max_tokens_per_rank(max_num_tokens_per_rank),
|
||
base(base) {}
|
||
|
||
CUTLASS_HOST_DEVICE
|
||
uint64_t get_num_bytes_per_rank() const {
|
||
return num_max_tokens_per_rank * data_layout.get_num_bytes<uint64_t>();
|
||
}
|
||
|
||
CUTLASS_HOST_DEVICE
|
||
uint64_t get_num_bytes() const {
|
||
return get_num_bytes_per_rank() * num_ranks;
|
||
}
|
||
|
||
template <typename dtype_t = void>
|
||
CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
|
||
return static_cast<dtype_t*>(base);
|
||
}
|
||
|
||
CUTLASS_HOST_DEVICE
|
||
void* get_end_ptr() const {
|
||
return math::advance_ptr(base, get_num_bytes());
|
||
}
|
||
|
||
CUTLASS_HOST_DEVICE
|
||
Buffer get_rank_buffer(const uint32_t& rank_idx) const {
|
||
return {
|
||
data_layout,
|
||
1, num_max_tokens_per_rank,
|
||
math::advance_ptr(base, get_num_bytes_per_rank() * rank_idx)
|
||
};
|
||
}
|
||
|
||
CUTLASS_HOST_DEVICE
|
||
Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const {
|
||
DG_DEVICE_ASSERT(num_ranks == 1 or global);
|
||
return Data(
|
||
data_layout.num_bytes,
|
||
data_layout.require_tma_alignment,
|
||
math::advance_ptr(base, data_layout.get_num_bytes<uint64_t>() * token_idx)
|
||
);
|
||
}
|
||
};
|
||
|
||
} // namespace deep_gemm::layout
|