#pragma once

// NOTE(review): in this copy of the file every angle-bracket argument list is absent
// (the two bare `#include` targets below, cast target types, and the element types of
// `std::tuple` / `std::optional` / `std::vector` / `std::numeric_limits`).
// Restore them from the upstream source before compiling.
#include
#include

#if DG_TENSORMAP_COMPATIBLE
#include "../jit/compiler.hpp"
#endif
#include "../jit/device_runtime.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp"

namespace deep_gemm::mega {

// Returns `layout::kLCMCandidateBlockM`.
// NOTES: presumably the least common multiple of the candidate block-M sizes that callers
// should align token counts to — confirm against the `layout` header.
static int get_token_alignment_for_mega_moe() {
    return layout::kLCMCandidateBlockM;
}

// Lays out the symmetric buffer used by mega MoE and returns a 2-tuple:
//   [0] the end pointer of the last sub-buffer (`combine_token_buffer`), cast to an integer;
//       the workspace is constructed with a `nullptr` base and each sub-buffer starts at the
//       previous one's `get_end_ptr()`, so this value is used as the required size
//       (see the `nbytes()` check in `fp8_fp4_mega_moe`)
//   [1] a callable that slices a raw buffer tensor into the 8 tensor views
//       `(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)`
// Sub-buffer order: workspace, input tokens, input SF, input top-k indices, input top-k weights,
// L1 tokens, L1 SF, L1 top-k weights, L2 tokens, L2 SF, combine tokens.
// Asserts: `num_experts % num_ranks == 0`, `hidden % 128 == 0`, `intermediate_hidden % 128 == 0`,
// and the padded SF pool token count is a multiple of 4.
// NOTES: `use_fp8_dispatch` and `activation` are accepted but not referenced in this body.
// NOTE(review): `num_ranks` is not checked for zero before the modulo below; a zero value is an
// integer division by zero rather than a clean assertion failure.
static std::tuple(const torch::Tensor&)>> get_symm_buffer_size_for_mega_moe(
    const int& num_ranks, const int& num_experts,
    const int& num_max_tokens_per_rank, const int& num_topk,
    const int& hidden, const int& intermediate_hidden,
    const bool& use_fp8_dispatch, const std::string& activation) {
    DG_HOST_ASSERT(num_experts % num_ranks == 0);

    // Workspace bytes
    // NOTES: the base is `nullptr`, so all pointers derived from it below act as offsets
    const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);

    // Layouts
    // NOTES: the first argument is presumably the per-token size in bytes (FP8 = `hidden`,
    // BF16 = `hidden * 2`, one scale per 32 elements = `hidden / 32`); the meaning of the
    // trailing `false` flag is defined by `layout::Data` — TODO confirm
    const auto fp8_token_layout = layout::Data(hidden);
    const auto bf16_token_layout = layout::Data(hidden * 2);
    const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
    const auto fp8_sf_layout = layout::Data(hidden / 32);
    const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32);
    const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
    const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
    const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);

    // Input buffers: sized by `num_max_tokens_per_rank`, placed right after the workspace
    const auto input_token_buffer = layout::Buffer(
        fp8_token_layout, 1, num_max_tokens_per_rank, workspace.get_end_ptr());
    const auto input_sf_buffer = layout::Buffer(
        fp8_sf_layout, 1, num_max_tokens_per_rank, input_token_buffer.get_end_ptr());
    const auto input_topk_idx_buffer = layout::Buffer(
        input_topk_idx_layout, 1, num_max_tokens_per_rank, input_sf_buffer.get_end_ptr());
    const auto input_topk_weights_buffer = layout::Buffer(
        input_topk_weights_layout, 1, num_max_tokens_per_rank, input_topk_idx_buffer.get_end_ptr());

    // Buffer configs
    // NOTES: the SF pool is sized for the worst case over all candidate block-M values
    const auto num_max_pool_tokens = static_cast(workspace.num_max_pool_tokens);
    int num_max_padded_sf_pool_tokens = 0;
    for (int block_m: layout::kCandidateBlockM) {
        num_max_padded_sf_pool_tokens = std::max(
            num_max_padded_sf_pool_tokens,
            layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)
        );
    }

    // L1 input buffer: tokens and per-token top-k weights use the pool token count,
    // while SF uses the padded SF pool token count
    const auto l1_token_buffer = layout::Buffer(
        fp8_token_layout, 1, num_max_pool_tokens, input_topk_weights_buffer.get_end_ptr());
    const auto l1_sf_buffer = layout::Buffer(
        fp8_sf_layout, 1, num_max_padded_sf_pool_tokens, l1_token_buffer.get_end_ptr());
    const auto l1_topk_weights_buffer = layout::Buffer(
        l1_topk_weights_layout, 1, num_max_pool_tokens, l1_sf_buffer.get_end_ptr());

    // L2 input buffer
    const auto l2_token_buffer = layout::Buffer(
        fp8_intermediate_token_layout, 1, num_max_pool_tokens, l1_topk_weights_buffer.get_end_ptr());
    const auto l2_sf_buffer = layout::Buffer(
        fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens, l2_token_buffer.get_end_ptr());

    // Combine input buffer: BF16 tokens for cross-rank combine
    // NOTES: this is the last sub-buffer, so its end pointer is the value returned as the size
    const auto combine_token_buffer = layout::Buffer(
        bf16_token_layout, num_topk, num_max_tokens_per_rank, l2_sf_buffer.get_end_ptr());

    // Check SF buffer requirements
    // NOTES: the slice function below views SF as `kInt` with `hidden / 128` columns
    DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
    DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);

    // Slice function: creates `(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
    // NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
    // (explicit strides `{1, num_max_padded_sf_pool_tokens}`)
    // NOTES: the views are created with `torch::from_blob` on `buffer.data_ptr()` and do not own
    // the memory, so `buffer` must stay alive while they are in use
    // NOTES: `l1_topk_weights_buffer` and `combine_token_buffer` are laid out above but are not
    // exposed as views here
    // NOTES: everything is captured by copy (`[=]`), so the callable stays valid after this
    // function returns
    auto slice_input_buffers = [=](const torch::Tensor& buffer) {
        auto x = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)),
            {num_max_tokens_per_rank, hidden},
            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
        auto x_sf = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)),
            {num_max_tokens_per_rank, hidden / 128},
            torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
        auto topk_idx = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_idx_buffer.base)),
            {num_max_tokens_per_rank, num_topk},
            torch::TensorOptions().dtype(torch::kInt64).device(buffer.device()));
        auto topk_weights = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_weights_buffer.base)),
            {num_max_tokens_per_rank, num_topk},
            torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
        auto l1_acts = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_token_buffer.base)),
            {num_max_pool_tokens, hidden},
            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
        auto l1_acts_sf = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)),
            {num_max_padded_sf_pool_tokens, hidden / 128},
            {1, num_max_padded_sf_pool_tokens},
            torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
        auto l2_acts = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)),
            {num_max_pool_tokens, intermediate_hidden},
            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
        auto l2_acts_sf = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)),
            {num_max_padded_sf_pool_tokens, intermediate_hidden / 128},
            {1, num_max_padded_sf_pool_tokens},
            torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
        return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
    };
    return {reinterpret_cast(combine_token_buffer.get_end_ptr()), slice_input_buffers};
}

// Host entry point for the mega MoE kernel: validates the arguments, slices `sym_buffer`
// into its tensor views, and dispatches to the architecture-specific launcher.
//   y:                 output tensor; `y.size(0)` is taken as the number of tokens
//   l1_weights_tuple:  `(l1_weights, l1_weights_sf)`, must be K-major and contiguous
//   l2_weights_tuple:  `(l2_weights, l2_weights_sf)`, must be K-major and contiguous
//   cumulative_local_expert_recv_stats: optional contiguous `kInt` tensor with
//                      `num_experts_per_rank` elements
//   sym_buffer:        symmetric buffer, at least as large as the size reported by
//                      `get_symm_buffer_size_for_mega_moe`
//   sym_buffer_ptrs:   one entry per rank; its length defines `num_ranks`
//   recipe:            must be `(1, 1, 32)`
//   activation:        must be `"swiglu"`
//   activation_clamp_opt: optional non-negative clamp; defaults to infinity when absent
// Only `arch_major == 10` is supported; any other value hits `DG_HOST_UNREACHABLE`.
// NOTE(review): `num_ranks` comes from `sym_buffer_ptrs.size()`; an empty vector reaches the
// `num_experts % num_ranks` modulo in `get_symm_buffer_size_for_mega_moe` with a zero divisor.
// NOTE(review): `rank_idx` is forwarded without a range check in this function — confirm the
// launcher validates it.
static void fp8_fp4_mega_moe(
    const torch::Tensor& y,
    const std::tuple& l1_weights_tuple,
    const std::tuple& l2_weights_tuple,
    const std::optional& cumulative_local_expert_recv_stats,
    const torch::Tensor& sym_buffer,
    const std::vector& sym_buffer_ptrs,
    const int& rank_idx,
    const int& num_max_tokens_per_rank,
    const int& num_experts, const int& num_topk,
    const std::tuple& recipe,
    const std::string& activation,
    const std::optional& activation_clamp_opt,
    const bool& fast_math
) {
    const auto [l1_weights, l1_weights_sf] = l1_weights_tuple;
    const auto [l2_weights, l2_weights_sf] = l2_weights_tuple;

    // Config checks
    const auto num_tokens = static_cast(y.size(0));
    const auto [rm, rn, rk] = recipe;
    DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 32);
    DG_HOST_ASSERT(activation == "swiglu");

    // Activation checks
    // NOTES: no clamp provided means an infinite clamp value
    const auto activation_clamp = activation_clamp_opt.value_or(std::numeric_limits::infinity());
    DG_HOST_ASSERT(activation_clamp >= 0);

    // Tensor checks
    // NOTES: shapes are derived from the weights: L1 yields `(experts, 2 * intermediate_hidden, hidden)`
    // and L2 yields `(experts, hidden, intermediate_hidden)`; both must agree with each other
    DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K);
    DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K);
    const auto arch_major = device_runtime->get_arch_major();
    const auto [num_experts_per_rank, intermediate_hidden_2, hidden] =
        check_grouped_ab_fp8_fp4(l1_weights, cute::UMMA::Major::K, arch_major);
    const auto [num_experts_per_rank_, hidden_, intermediate_hidden] =
        check_grouped_ab_fp8_fp4(l2_weights, cute::UMMA::Major::K, arch_major);
    DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank);
    DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_);
    DG_HOST_ASSERT(hidden == hidden_);
    DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
    DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());

    // Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment
    // NOTES: granularity `(kGranMN, kGranK) = (1, 32)` matches the `recipe` assertion above
    constexpr int kGranMN = 1, kGranK = 32;
    check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
                    num_experts_per_rank, true, false, torch::kInt);
    check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
                    num_experts_per_rank, true, false, torch::kInt);

    // Check stats counter
    if (cumulative_local_expert_recv_stats.has_value()) {
        DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
        DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank);
        DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous());
    }

    // Check buffer bytes
    // NOTES: `use_fp8_dispatch` is always passed as `true` here
    // NOTE(review): `num_experts` is used to size the buffer before it is checked against
    // `num_experts_per_rank * num_ranks` two statements later
    const auto num_ranks = static_cast(sym_buffer_ptrs.size());
    const auto num_experts_ = num_experts_per_rank * num_ranks;
    const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe(
        num_ranks, num_experts, num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden, true, activation);
    DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast(num_required_bytes));
    DG_HOST_ASSERT(num_experts == num_experts_);

    // Already registered tensors
    // NOTES: `x`, `x_sf`, `topk_idx` and `topk_weights` are unpacked but not passed to the
    // launcher below; presumably the kernel reaches them through `sym_buffer_ptrs` — TODO confirm
    const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);

    // Dispatch into different architectures
    if (arch_major == 10) {
        sm100_fp8_fp4_mega_moe(y,
                               l1_acts, l1_acts_sf,
                               l2_acts, l2_acts_sf,
                               l1_weights, l2_weights,
                               l1_weights_sf, l2_weights_sf,
                               cumulative_local_expert_recv_stats,
                               sym_buffer_ptrs,
                               rank_idx, num_max_tokens_per_rank,
                               num_experts_per_rank,
                               num_tokens, num_topk,
                               hidden, intermediate_hidden,
                               activation_clamp,
                               fast_math);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }

    // Zero the entire symmetric buffer for debug mode
    // NOTES: caller must re-copy inputs into the buffer before each kernel call
    if (get_env("DG_COMM_KERNEL_DEBUG"))
        sym_buffer.zero_();
}

// Registers the three mega MoE entry points on the pybind11 module `m`.
// NOTES: nothing is registered when `DG_TENSORMAP_COMPATIBLE` is not enabled
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
    m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe);
    m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe);
    m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe);
#endif
}

} // namespace deep_gemm::mega