Files
DeepGEMM/csrc/apis/mega.hpp
Zhean Xu 891d57b4db Add various optimizations and Mega MoE benchmarks (#316)
* Merge with private repo

* Add Mega MoE Benchmark

* Minor fix

* Update

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2026-04-24 18:41:37 +08:00

236 lines
11 KiB
C++

#pragma once
#include <functional>
#include <pybind11/functional.h>
#if DG_TENSORMAP_COMPATIBLE
#include "../jit/compiler.hpp"
#endif
#include "../jit/device_runtime.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp"
namespace deep_gemm::mega {
static int get_token_alignment_for_mega_moe() {
return layout::kLCMCandidateBlockM;
}
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
get_symm_buffer_size_for_mega_moe(
const int& num_ranks, const int& num_experts,
const int& num_max_tokens_per_rank, const int& num_topk,
const int& hidden, const int& intermediate_hidden,
const bool& use_fp8_dispatch, const std::string& activation) {
DG_HOST_ASSERT(num_experts % num_ranks == 0);
// Workspace bytes
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
// Layouts
const auto fp8_token_layout = layout::Data(hidden);
const auto bf16_token_layout = layout::Data(hidden * 2);
const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
const auto fp8_sf_layout = layout::Data(hidden / 32);
const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32);
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
// Input buffers
const auto input_token_buffer = layout::Buffer(
fp8_token_layout, 1, num_max_tokens_per_rank,
workspace.get_end_ptr());
const auto input_sf_buffer = layout::Buffer(
fp8_sf_layout, 1, num_max_tokens_per_rank,
input_token_buffer.get_end_ptr());
const auto input_topk_idx_buffer = layout::Buffer(
input_topk_idx_layout, 1, num_max_tokens_per_rank,
input_sf_buffer.get_end_ptr());
const auto input_topk_weights_buffer = layout::Buffer(
input_topk_weights_layout, 1, num_max_tokens_per_rank,
input_topk_idx_buffer.get_end_ptr());
// Buffer configs
const auto num_max_pool_tokens = static_cast<int>(workspace.num_max_pool_tokens);
int num_max_padded_sf_pool_tokens = 0;
for (int block_m: layout::kCandidateBlockM) {
num_max_padded_sf_pool_tokens = std::max(
num_max_padded_sf_pool_tokens,
layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)
);
}
// L1 input buffer
const auto l1_token_buffer = layout::Buffer(
fp8_token_layout, 1, num_max_pool_tokens,
input_topk_weights_buffer.get_end_ptr());
const auto l1_sf_buffer = layout::Buffer(
fp8_sf_layout, 1, num_max_padded_sf_pool_tokens,
l1_token_buffer.get_end_ptr());
const auto l1_topk_weights_buffer = layout::Buffer(
l1_topk_weights_layout, 1, num_max_pool_tokens,
l1_sf_buffer.get_end_ptr());
// L2 input buffer
const auto l2_token_buffer = layout::Buffer(
fp8_intermediate_token_layout, 1, num_max_pool_tokens,
l1_topk_weights_buffer.get_end_ptr());
const auto l2_sf_buffer = layout::Buffer(
fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
l2_token_buffer.get_end_ptr());
// Combine input buffer: BF16 tokens for cross-rank combine
const auto combine_token_buffer = layout::Buffer(
bf16_token_layout, num_topk, num_max_tokens_per_rank,
l2_sf_buffer.get_end_ptr());
// Check SF buffer requirements
DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);
// Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
// NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
auto slice_input_buffers = [=](const torch::Tensor& buffer) {
auto x = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
{num_max_tokens_per_rank, hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
auto x_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
{num_max_tokens_per_rank, hidden / 128},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
auto topk_idx = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)),
{num_max_tokens_per_rank, num_topk},
torch::TensorOptions().dtype(torch::kInt64).device(buffer.device()));
auto topk_weights = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
{num_max_tokens_per_rank, num_topk},
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
auto l1_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
{num_max_pool_tokens, hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
auto l1_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
{num_max_padded_sf_pool_tokens, hidden / 128},
{1, num_max_padded_sf_pool_tokens},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
auto l2_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
{num_max_pool_tokens, intermediate_hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
auto l2_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
{num_max_padded_sf_pool_tokens, intermediate_hidden / 128},
{1, num_max_padded_sf_pool_tokens},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
};
return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
}
static void fp8_fp4_mega_moe(
const torch::Tensor& y,
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple,
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_tuple,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const torch::Tensor& sym_buffer,
const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
const int& num_max_tokens_per_rank,
const int& num_experts, const int& num_topk,
const std::tuple<int, int, int>& recipe,
const std::string& activation,
const std::optional<float>& activation_clamp_opt,
const bool& fast_math
) {
const auto [l1_weights, l1_weights_sf] = l1_weights_tuple;
const auto [l2_weights, l2_weights_sf] = l2_weights_tuple;
// Config checks
const auto num_tokens = static_cast<int>(y.size(0));
const auto [rm, rn, rk] = recipe;
DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 32);
DG_HOST_ASSERT(activation == "swiglu");
// Activation checks
const auto activation_clamp =
activation_clamp_opt.value_or(std::numeric_limits<float>::infinity());
DG_HOST_ASSERT(activation_clamp >= 0);
// Tensor checks
DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K);
DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K);
const auto arch_major = device_runtime->get_arch_major();
const auto [num_experts_per_rank, intermediate_hidden_2, hidden] =
check_grouped_ab_fp8_fp4(l1_weights, cute::UMMA::Major::K, arch_major);
const auto [num_experts_per_rank_, hidden_, intermediate_hidden] =
check_grouped_ab_fp8_fp4(l2_weights, cute::UMMA::Major::K, arch_major);
DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank);
DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_);
DG_HOST_ASSERT(hidden == hidden_);
DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());
// Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment
constexpr int kGranMN = 1, kGranK = 32;
check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
num_experts_per_rank, true, false, torch::kInt);
check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
num_experts_per_rank, true, false, torch::kInt);
// Check stats counter
if (cumulative_local_expert_recv_stats.has_value()) {
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank);
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous());
}
// Check buffer bytes
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
const auto num_experts_ = num_experts_per_rank * num_ranks;
const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe(
num_ranks, num_experts,
num_max_tokens_per_rank, num_topk,
hidden, intermediate_hidden,
true, activation);
DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes));
DG_HOST_ASSERT(num_experts == num_experts_);
// Already registered tensors
const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);
// Dispatch into different architectures
if (arch_major == 10) {
sm100_fp8_fp4_mega_moe(y,
l1_acts, l1_acts_sf,
l2_acts, l2_acts_sf,
l1_weights, l2_weights,
l1_weights_sf, l2_weights_sf,
cumulative_local_expert_recv_stats,
sym_buffer_ptrs,
rank_idx, num_max_tokens_per_rank,
num_experts_per_rank,
num_tokens, num_topk,
hidden, intermediate_hidden,
activation_clamp, fast_math);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
// Zero the entire symmetric buffer for debug mode
// NOTES: caller must re-copy inputs into the buffer before each kernel call
if (get_env<int>("DG_COMM_KERNEL_DEBUG"))
sym_buffer.zero_();
}
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe);
m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe);
m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe);
#endif
}
} // namespace deep_gemm::mega