- New CUDA kernel: sm100_fp8_nvfp4_mega_moe_impl
  - kGranK=16 (NVFP4 group_size=16, vs MXFP4's 32)
  - kind::mxf4nvf4.block_scale.scale_vec::4X PTX instruction
  - float_ue4m3_t scale factor type in instruction descriptor
  - SF layout: scale_vec::4X (4 TMEM sub-columns per UMMA atom)
  - UTCCP column stride: i*8 (vs MXFP4's i*4) for 4X layout
  - L1 epilogue: UE4M3 activation scales (float→cutlass::float_e4m3_t)
  - SF loading: kNumSFUint32 = kHidden/64 (4 UE4M3 per int32)
- New PTX wrappers: SM100_MMA_MXF4NVF4_2x1SM_SS, SM100_MMA_MXF4NVF4_SS
- Python API:
  - fp8_nvfp4_mega_moe() with recipe=(1,1,16)
  - transform_nvfp4_weights_for_mega_moe() for UE4M3→int32 UTCCP packing
  - _pack_nvfp4_sf_for_utccp() helper
- C++ bindings:
  - mega_nvfp4.hpp with NVFP4-specific SymmBuffer (SF stride K/16)
  - JIT kernel header with kGranK=16 TMA descriptors
  - Registered in python_api.cpp

NOTE: Both SFA and SFB must use UE4M3 (scale_format_ is 1-bit, shared). The L1 epilogue converts float→UE4M3 for activation scales.
238 lines
11 KiB
C++
238 lines
11 KiB
C++
#pragma once
|
|
|
|
#include <functional>
|
|
#include <pybind11/functional.h>
|
|
|
|
#if DG_TENSORMAP_COMPATIBLE
|
|
#include "../jit/compiler.hpp"
|
|
#endif
|
|
#include "../jit/device_runtime.hpp"
|
|
#include "../jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp"
|
|
|
|
namespace deep_gemm::mega::nvfp4 {
|
|
|
|
static int get_token_alignment_for_mega_moe() {
|
|
return layout::kLCMCandidateBlockM;
|
|
}
|
|
|
|
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
|
|
get_symm_buffer_size_for_nvfp4_mega_moe(
|
|
const int& num_ranks, const int& num_experts,
|
|
const int& num_max_tokens_per_rank, const int& num_topk,
|
|
const int& hidden, const int& intermediate_hidden,
|
|
const bool& use_fp8_dispatch, const std::string& activation) {
|
|
DG_HOST_ASSERT(num_experts % num_ranks == 0);
|
|
|
|
// Workspace bytes
|
|
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
|
|
|
|
// NVFP4 layouts: group_size=16, so SF stride is K/16 (twice as many as MXFP4)
|
|
const auto fp8_token_layout = layout::Data(hidden);
|
|
const auto bf16_token_layout = layout::Data(hidden * 2);
|
|
const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
|
|
const auto nvfp4_sf_layout = layout::Data(hidden / 16);
|
|
const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
|
|
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
|
|
const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
|
|
const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
|
|
|
|
// Input buffers
|
|
const auto input_token_buffer = layout::Buffer(
|
|
fp8_token_layout, 1, num_max_tokens_per_rank,
|
|
workspace.get_end_ptr());
|
|
const auto input_sf_buffer = layout::Buffer(
|
|
nvfp4_sf_layout, 1, num_max_tokens_per_rank,
|
|
input_token_buffer.get_end_ptr());
|
|
const auto input_topk_idx_buffer = layout::Buffer(
|
|
input_topk_idx_layout, 1, num_max_tokens_per_rank,
|
|
input_sf_buffer.get_end_ptr());
|
|
const auto input_topk_weights_buffer = layout::Buffer(
|
|
input_topk_weights_layout, 1, num_max_tokens_per_rank,
|
|
input_topk_idx_buffer.get_end_ptr());
|
|
|
|
// Buffer configs
|
|
const auto num_max_pool_tokens = static_cast<int>(workspace.num_max_pool_tokens);
|
|
int num_max_padded_sf_pool_tokens = 0;
|
|
for (int block_m: layout::kCandidateBlockM) {
|
|
num_max_padded_sf_pool_tokens = std::max(
|
|
num_max_padded_sf_pool_tokens,
|
|
layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)
|
|
);
|
|
}
|
|
|
|
// L1 input buffer
|
|
const auto l1_token_buffer = layout::Buffer(
|
|
fp8_token_layout, 1, num_max_pool_tokens,
|
|
input_topk_weights_buffer.get_end_ptr());
|
|
const auto l1_sf_buffer = layout::Buffer(
|
|
nvfp4_sf_layout, 1, num_max_padded_sf_pool_tokens,
|
|
l1_token_buffer.get_end_ptr());
|
|
const auto l1_topk_weights_buffer = layout::Buffer(
|
|
l1_topk_weights_layout, 1, num_max_pool_tokens,
|
|
l1_sf_buffer.get_end_ptr());
|
|
|
|
// L2 input buffer
|
|
const auto l2_token_buffer = layout::Buffer(
|
|
fp8_intermediate_token_layout, 1, num_max_pool_tokens,
|
|
l1_topk_weights_buffer.get_end_ptr());
|
|
const auto l2_sf_buffer = layout::Buffer(
|
|
nvfp4_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
|
|
l2_token_buffer.get_end_ptr());
|
|
|
|
// Combine input buffer: BF16 tokens for cross-rank combine
|
|
const auto combine_token_buffer = layout::Buffer(
|
|
bf16_token_layout, num_topk, num_max_tokens_per_rank,
|
|
l2_sf_buffer.get_end_ptr());
|
|
|
|
// Check SF buffer requirements
|
|
// NVFP4: hidden must be divisible by 64 (4 UE4M3 scales per int32, group_size=16)
|
|
DG_HOST_ASSERT(hidden % 64 == 0 and intermediate_hidden % 64 == 0);
|
|
DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);
|
|
|
|
// Slice function
|
|
auto slice_input_buffers = [=](const torch::Tensor& buffer) {
|
|
auto x = torch::from_blob(
|
|
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
|
|
{num_max_tokens_per_rank, hidden},
|
|
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
|
|
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
|
|
auto x_sf = torch::from_blob(
|
|
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
|
|
{num_max_tokens_per_rank, hidden / 64},
|
|
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
|
auto topk_idx = torch::from_blob(
|
|
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)),
|
|
{num_max_tokens_per_rank, num_topk},
|
|
torch::TensorOptions().dtype(torch::kInt64).device(buffer.device()));
|
|
auto topk_weights = torch::from_blob(
|
|
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
|
|
{num_max_tokens_per_rank, num_topk},
|
|
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
|
|
auto l1_acts = torch::from_blob(
|
|
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
|
|
{num_max_pool_tokens, hidden},
|
|
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
|
|
// NVFP4 L1 SF: M-major, K/64 int32
|
|
auto l1_acts_sf = torch::from_blob(
|
|
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
|
|
{num_max_padded_sf_pool_tokens, hidden / 64},
|
|
{1, num_max_padded_sf_pool_tokens},
|
|
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
|
auto l2_acts = torch::from_blob(
|
|
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
|
|
{num_max_pool_tokens, intermediate_hidden},
|
|
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
|
|
// NVFP4 L2 SF: M-major, K/64 int32
|
|
auto l2_acts_sf = torch::from_blob(
|
|
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
|
|
{num_max_padded_sf_pool_tokens, intermediate_hidden / 64},
|
|
{1, num_max_padded_sf_pool_tokens},
|
|
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
|
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
|
|
};
|
|
return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
|
|
}
|
|
|
|
static void fp8_nvfp4_mega_moe(
|
|
const torch::Tensor& y,
|
|
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple,
|
|
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_tuple,
|
|
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
|
|
const torch::Tensor& sym_buffer,
|
|
const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
|
|
const int& num_max_tokens_per_rank,
|
|
const int& num_experts, const int& num_topk,
|
|
const std::tuple<int, int, int>& recipe,
|
|
const std::string& activation,
|
|
const std::optional<float>& activation_clamp_opt,
|
|
const bool& fast_math
|
|
) {
|
|
const auto [l1_weights, l1_weights_sf] = l1_weights_tuple;
|
|
const auto [l2_weights, l2_weights_sf] = l2_weights_tuple;
|
|
|
|
// Config checks
|
|
const auto num_tokens = static_cast<int>(y.size(0));
|
|
const auto [rm, rn, rk] = recipe;
|
|
DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 16); // NVFP4: group_size=16
|
|
DG_HOST_ASSERT(activation == "swiglu");
|
|
|
|
// Activation checks
|
|
const auto activation_clamp =
|
|
activation_clamp_opt.value_or(std::numeric_limits<float>::infinity());
|
|
DG_HOST_ASSERT(activation_clamp >= 0);
|
|
|
|
// Tensor checks
|
|
DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K);
|
|
DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K);
|
|
const auto arch_major = device_runtime->get_arch_major();
|
|
const auto [num_experts_per_rank, intermediate_hidden_2, hidden] =
|
|
check_grouped_ab_fp8_fp4(l1_weights, cute::UMMA::Major::K, arch_major);
|
|
const auto [num_experts_per_rank_, hidden_, intermediate_hidden] =
|
|
check_grouped_ab_fp8_fp4(l2_weights, cute::UMMA::Major::K, arch_major);
|
|
DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_);
|
|
DG_HOST_ASSERT(hidden == hidden_);
|
|
DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
|
|
DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());
|
|
|
|
// Check weight SF layout for UE4M3 packing, MN-major, and TMA alignment
|
|
// NVFP4: kGranK=16, SF packed as int32 (4 UE4M3 bytes per int32)
|
|
constexpr int kGranMN = 1, kGranK = 16;
|
|
check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
|
|
num_experts_per_rank, true, false, torch::kInt);
|
|
check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
|
|
num_experts_per_rank, true, false, torch::kInt);
|
|
|
|
// Check stats counter
|
|
if (cumulative_local_expert_recv_stats.has_value()) {
|
|
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
|
|
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank);
|
|
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous());
|
|
}
|
|
|
|
// Check buffer bytes
|
|
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
|
|
const auto num_experts_ = num_experts_per_rank * num_ranks;
|
|
const auto [num_required_bytes, slice] = get_symm_buffer_size_for_nvfp4_mega_moe(
|
|
num_ranks, num_experts,
|
|
num_max_tokens_per_rank, num_topk,
|
|
hidden, intermediate_hidden,
|
|
true, activation);
|
|
DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes));
|
|
DG_HOST_ASSERT(num_experts == num_experts_);
|
|
|
|
// Already registered tensors
|
|
const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);
|
|
|
|
// Dispatch into different architectures
|
|
if (arch_major == 10) {
|
|
sm100_fp8_nvfp4_mega_moe(y,
|
|
l1_acts, l1_acts_sf,
|
|
l2_acts, l2_acts_sf,
|
|
l1_weights, l2_weights,
|
|
l1_weights_sf, l2_weights_sf,
|
|
cumulative_local_expert_recv_stats,
|
|
sym_buffer_ptrs,
|
|
rank_idx, num_max_tokens_per_rank,
|
|
num_experts_per_rank,
|
|
num_tokens, num_topk,
|
|
hidden, intermediate_hidden,
|
|
activation_clamp, fast_math);
|
|
} else {
|
|
DG_HOST_UNREACHABLE("Unsupported architecture");
|
|
}
|
|
|
|
// Zero the entire symmetric buffer for debug mode
|
|
if (get_env<int>("DG_COMM_KERNEL_DEBUG"))
|
|
sym_buffer.zero_();
|
|
}
|
|
|
|
// Registers the NVFP4 mega-MoE entry points on the pybind11 module `m`.
// Nothing is registered unless `DG_TENSORMAP_COMPATIBLE` is true at preprocessing time.
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
    // `def` returns the module itself, so the three registrations are chained in sequence
    m.def("get_token_alignment_for_nvfp4_mega_moe", &get_token_alignment_for_mega_moe)
     .def("get_symm_buffer_size_for_nvfp4_mega_moe", &get_symm_buffer_size_for_nvfp4_mega_moe)
     .def("fp8_nvfp4_mega_moe", &fp8_nvfp4_mega_moe);
#endif
}
|
|
|
|
} // namespace deep_gemm::mega::nvfp4
|