feat: NVFP4 mega MoE kernel (scale_vec::4X, UE4M3 block scales)
- New CUDA kernel: sm100_fp8_nvfp4_mega_moe_impl
  - kGranK=16 (NVFP4 group_size=16, vs MXFP4's 32)
  - kind::mxf4nvf4.block_scale.scale_vec::4X PTX instruction
  - float_ue4m3_t scale factor type in instruction descriptor
  - SF layout: scale_vec::4X (4 TMEM sub-columns per UMMA atom)
  - UTCCP column stride: i*8 (vs MXFP4's i*4) for 4X layout
  - L1 epilogue: UE4M3 activation scales (float→cutlass::float_e4m3_t)
  - SF loading: kNumSFUint32 = kHidden/64 (4 UE4M3 per int32)
- New PTX wrappers: SM100_MMA_MXF4NVF4_2x1SM_SS, SM100_MMA_MXF4NVF4_SS
- Python API:
  - fp8_nvfp4_mega_moe() with recipe=(1,1,16)
  - transform_nvfp4_weights_for_mega_moe() for UE4M3→int32 UTCCP packing
  - _pack_nvfp4_sf_for_utccp() helper
- C++ bindings:
  - mega_nvfp4.hpp with NVFP4-specific SymmBuffer (SF stride K/16)
  - JIT kernel header with kGranK=16 TMA descriptors
  - Registered in python_api.cpp

NOTE: Both SFA and SFB must use UE4M3 (scale_format_ is 1-bit, shared). The L1 epilogue converts float→UE4M3 for activation scales.
This commit is contained in:
237
csrc/apis/mega_nvfp4.hpp
Normal file
237
csrc/apis/mega_nvfp4.hpp
Normal file
@@ -0,0 +1,237 @@
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <pybind11/functional.h>
|
||||
|
||||
#if DG_TENSORMAP_COMPATIBLE
|
||||
#include "../jit/compiler.hpp"
|
||||
#endif
|
||||
#include "../jit/device_runtime.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp"
|
||||
|
||||
namespace deep_gemm::mega::nvfp4 {
|
||||
|
||||
static int get_token_alignment_for_mega_moe() {
|
||||
return layout::kLCMCandidateBlockM;
|
||||
}
|
||||
|
||||
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
|
||||
get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
const int& num_ranks, const int& num_experts,
|
||||
const int& num_max_tokens_per_rank, const int& num_topk,
|
||||
const int& hidden, const int& intermediate_hidden,
|
||||
const bool& use_fp8_dispatch, const std::string& activation) {
|
||||
DG_HOST_ASSERT(num_experts % num_ranks == 0);
|
||||
|
||||
// Workspace bytes
|
||||
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
|
||||
|
||||
// NVFP4 layouts: group_size=16, so SF stride is K/16 (twice as many as MXFP4)
|
||||
const auto fp8_token_layout = layout::Data(hidden);
|
||||
const auto bf16_token_layout = layout::Data(hidden * 2);
|
||||
const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
|
||||
const auto nvfp4_sf_layout = layout::Data(hidden / 16);
|
||||
const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
|
||||
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
|
||||
const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
|
||||
const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
|
||||
|
||||
// Input buffers
|
||||
const auto input_token_buffer = layout::Buffer(
|
||||
fp8_token_layout, 1, num_max_tokens_per_rank,
|
||||
workspace.get_end_ptr());
|
||||
const auto input_sf_buffer = layout::Buffer(
|
||||
nvfp4_sf_layout, 1, num_max_tokens_per_rank,
|
||||
input_token_buffer.get_end_ptr());
|
||||
const auto input_topk_idx_buffer = layout::Buffer(
|
||||
input_topk_idx_layout, 1, num_max_tokens_per_rank,
|
||||
input_sf_buffer.get_end_ptr());
|
||||
const auto input_topk_weights_buffer = layout::Buffer(
|
||||
input_topk_weights_layout, 1, num_max_tokens_per_rank,
|
||||
input_topk_idx_buffer.get_end_ptr());
|
||||
|
||||
// Buffer configs
|
||||
const auto num_max_pool_tokens = static_cast<int>(workspace.num_max_pool_tokens);
|
||||
int num_max_padded_sf_pool_tokens = 0;
|
||||
for (int block_m: layout::kCandidateBlockM) {
|
||||
num_max_padded_sf_pool_tokens = std::max(
|
||||
num_max_padded_sf_pool_tokens,
|
||||
layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)
|
||||
);
|
||||
}
|
||||
|
||||
// L1 input buffer
|
||||
const auto l1_token_buffer = layout::Buffer(
|
||||
fp8_token_layout, 1, num_max_pool_tokens,
|
||||
input_topk_weights_buffer.get_end_ptr());
|
||||
const auto l1_sf_buffer = layout::Buffer(
|
||||
nvfp4_sf_layout, 1, num_max_padded_sf_pool_tokens,
|
||||
l1_token_buffer.get_end_ptr());
|
||||
const auto l1_topk_weights_buffer = layout::Buffer(
|
||||
l1_topk_weights_layout, 1, num_max_pool_tokens,
|
||||
l1_sf_buffer.get_end_ptr());
|
||||
|
||||
// L2 input buffer
|
||||
const auto l2_token_buffer = layout::Buffer(
|
||||
fp8_intermediate_token_layout, 1, num_max_pool_tokens,
|
||||
l1_topk_weights_buffer.get_end_ptr());
|
||||
const auto l2_sf_buffer = layout::Buffer(
|
||||
nvfp4_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
|
||||
l2_token_buffer.get_end_ptr());
|
||||
|
||||
// Combine input buffer: BF16 tokens for cross-rank combine
|
||||
const auto combine_token_buffer = layout::Buffer(
|
||||
bf16_token_layout, num_topk, num_max_tokens_per_rank,
|
||||
l2_sf_buffer.get_end_ptr());
|
||||
|
||||
// Check SF buffer requirements
|
||||
// NVFP4: hidden must be divisible by 64 (4 UE4M3 scales per int32, group_size=16)
|
||||
DG_HOST_ASSERT(hidden % 64 == 0 and intermediate_hidden % 64 == 0);
|
||||
DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);
|
||||
|
||||
// Slice function
|
||||
auto slice_input_buffers = [=](const torch::Tensor& buffer) {
|
||||
auto x = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
|
||||
{num_max_tokens_per_rank, hidden},
|
||||
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
|
||||
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
|
||||
auto x_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
|
||||
{num_max_tokens_per_rank, hidden / 64},
|
||||
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
||||
auto topk_idx = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)),
|
||||
{num_max_tokens_per_rank, num_topk},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(buffer.device()));
|
||||
auto topk_weights = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
|
||||
{num_max_tokens_per_rank, num_topk},
|
||||
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
|
||||
auto l1_acts = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
|
||||
{num_max_pool_tokens, hidden},
|
||||
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
|
||||
// NVFP4 L1 SF: M-major, K/64 int32
|
||||
auto l1_acts_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
|
||||
{num_max_padded_sf_pool_tokens, hidden / 64},
|
||||
{1, num_max_padded_sf_pool_tokens},
|
||||
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
||||
auto l2_acts = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
|
||||
{num_max_pool_tokens, intermediate_hidden},
|
||||
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
|
||||
// NVFP4 L2 SF: M-major, K/64 int32
|
||||
auto l2_acts_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
|
||||
{num_max_padded_sf_pool_tokens, intermediate_hidden / 64},
|
||||
{1, num_max_padded_sf_pool_tokens},
|
||||
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
|
||||
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
|
||||
};
|
||||
return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
|
||||
}
|
||||
|
||||
static void fp8_nvfp4_mega_moe(
|
||||
const torch::Tensor& y,
|
||||
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple,
|
||||
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_tuple,
|
||||
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
|
||||
const torch::Tensor& sym_buffer,
|
||||
const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
|
||||
const int& num_max_tokens_per_rank,
|
||||
const int& num_experts, const int& num_topk,
|
||||
const std::tuple<int, int, int>& recipe,
|
||||
const std::string& activation,
|
||||
const std::optional<float>& activation_clamp_opt,
|
||||
const bool& fast_math
|
||||
) {
|
||||
const auto [l1_weights, l1_weights_sf] = l1_weights_tuple;
|
||||
const auto [l2_weights, l2_weights_sf] = l2_weights_tuple;
|
||||
|
||||
// Config checks
|
||||
const auto num_tokens = static_cast<int>(y.size(0));
|
||||
const auto [rm, rn, rk] = recipe;
|
||||
DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 16); // NVFP4: group_size=16
|
||||
DG_HOST_ASSERT(activation == "swiglu");
|
||||
|
||||
// Activation checks
|
||||
const auto activation_clamp =
|
||||
activation_clamp_opt.value_or(std::numeric_limits<float>::infinity());
|
||||
DG_HOST_ASSERT(activation_clamp >= 0);
|
||||
|
||||
// Tensor checks
|
||||
DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K);
|
||||
const auto arch_major = device_runtime->get_arch_major();
|
||||
const auto [num_experts_per_rank, intermediate_hidden_2, hidden] =
|
||||
check_grouped_ab_fp8_fp4(l1_weights, cute::UMMA::Major::K, arch_major);
|
||||
const auto [num_experts_per_rank_, hidden_, intermediate_hidden] =
|
||||
check_grouped_ab_fp8_fp4(l2_weights, cute::UMMA::Major::K, arch_major);
|
||||
DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_);
|
||||
DG_HOST_ASSERT(hidden == hidden_);
|
||||
DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
|
||||
DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());
|
||||
|
||||
// Check weight SF layout for UE4M3 packing, MN-major, and TMA alignment
|
||||
// NVFP4: kGranK=16, SF packed as int32 (4 UE4M3 bytes per int32)
|
||||
constexpr int kGranMN = 1, kGranK = 16;
|
||||
check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
|
||||
num_experts_per_rank, true, false, torch::kInt);
|
||||
check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
|
||||
num_experts_per_rank, true, false, torch::kInt);
|
||||
|
||||
// Check stats counter
|
||||
if (cumulative_local_expert_recv_stats.has_value()) {
|
||||
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank);
|
||||
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous());
|
||||
}
|
||||
|
||||
// Check buffer bytes
|
||||
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
|
||||
const auto num_experts_ = num_experts_per_rank * num_ranks;
|
||||
const auto [num_required_bytes, slice] = get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
num_ranks, num_experts,
|
||||
num_max_tokens_per_rank, num_topk,
|
||||
hidden, intermediate_hidden,
|
||||
true, activation);
|
||||
DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes));
|
||||
DG_HOST_ASSERT(num_experts == num_experts_);
|
||||
|
||||
// Already registered tensors
|
||||
const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);
|
||||
|
||||
// Dispatch into different architectures
|
||||
if (arch_major == 10) {
|
||||
sm100_fp8_nvfp4_mega_moe(y,
|
||||
l1_acts, l1_acts_sf,
|
||||
l2_acts, l2_acts_sf,
|
||||
l1_weights, l2_weights,
|
||||
l1_weights_sf, l2_weights_sf,
|
||||
cumulative_local_expert_recv_stats,
|
||||
sym_buffer_ptrs,
|
||||
rank_idx, num_max_tokens_per_rank,
|
||||
num_experts_per_rank,
|
||||
num_tokens, num_topk,
|
||||
hidden, intermediate_hidden,
|
||||
activation_clamp, fast_math);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
|
||||
// Zero the entire symmetric buffer for debug mode
|
||||
if (get_env<int>("DG_COMM_KERNEL_DEBUG"))
|
||||
sym_buffer.zero_();
|
||||
}
|
||||
|
||||
/// Registers the NVFP4 mega MoE bindings on the Python module `m`.
///
/// Nothing is registered unless the build defines `DG_TENSORMAP_COMPATIBLE` to a non-zero value.
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
    // `module_::def` returns the module itself, so the registrations can be chained
    m.def("get_token_alignment_for_nvfp4_mega_moe", &get_token_alignment_for_mega_moe)
     .def("get_symm_buffer_size_for_nvfp4_mega_moe", &get_symm_buffer_size_for_nvfp4_mega_moe)
     .def("fp8_nvfp4_mega_moe", &fp8_nvfp4_mega_moe);
#endif
}
|
||||
|
||||
} // namespace deep_gemm::mega::nvfp4
|
||||
220
csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp
Normal file
220
csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp
Normal file
@@ -0,0 +1,220 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
#include <deep_gemm/layout/mega_moe.cuh>
|
||||
#include <deep_gemm/layout/sym_buffer.cuh>
|
||||
|
||||
#include "../heuristics/mega_moe.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
/// JIT runtime of the SM100 FP8 x NVFP4 mega MoE kernel.
///
/// `generate_impl` emits the translation unit that instantiates the kernel template with the
/// compile-time configuration, `launch_impl` forwards the runtime arguments to the built kernel.
class SM100FP8NVFP4MegaMoERuntime final : public LaunchRuntime<SM100FP8NVFP4MegaMoERuntime> {
public:
    /// Everything needed to generate and launch one kernel instance.
    /// NOTE: the member order matters, callers fill this struct with designated initializers.
    struct Args {
        // Templated arguments (baked into the generated code)
        int num_max_tokens_per_rank;
        int hidden, intermediate_hidden;
        int num_experts, num_topk;
        int num_ranks;
        float activation_clamp;
        bool fast_math;
        MegaMoEConfig config;

        // Runtime arguments (passed at launch)
        void* y;
        int* cumulative_local_expert_recv_stats;
        int num_tokens;
        layout::SymBuffer<> sym_buffer_ptrs;

        // Tensormap
        CUtensorMap tensor_map_l1_acts;
        CUtensorMap tensor_map_l1_acts_sf;
        CUtensorMap tensor_map_l1_weights;
        CUtensorMap tensor_map_l1_weights_sf;
        CUtensorMap tensor_map_l1_output;
        CUtensorMap tensor_map_l2_acts;
        CUtensorMap tensor_map_l2_acts_sf;
        CUtensorMap tensor_map_l2_weights;
        CUtensorMap tensor_map_l2_weights_sf;

        // Launch configs
        LaunchArgs launch_args;
    };

    /// Generates the source that takes the address of `sm100_fp8_nvfp4_mega_moe_impl<...>`.
    /// The 22 placeholders are filled, in order, with the values listed after the raw string.
    static std::string generate_impl(const Args& args) {
        const char* fast_math_literal = args.fast_math ? "true" : "false";
        return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm100_fp8_nvfp4_mega_moe_impl<
        {},
        {}, {},
        {}, {},
        {},
        {}, {}, {},
        {},
        {}, {},
        {},
        {},
        {},
        {}, {}, {},
        {}, {},
        {},
        {}
    >);
}};
)", args.num_max_tokens_per_rank,
    args.hidden, args.intermediate_hidden,
    args.num_experts, args.num_topk,
    args.config.num_experts_per_wave,
    args.config.block_m, args.config.block_n, args.config.block_k,
    args.config.store_block_m,
    args.config.sf_block_m, args.config.sf_block_n,
    args.config.num_max_pool_tokens,
    args.config.num_padded_sf_pool_tokens,
    args.config.num_stages,
    args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads,
    args.launch_args.grid_dim.first, args.num_ranks,
    to_string(args.activation_clamp),
    fast_math_literal);
    }

    /// Launches `kernel` with the runtime arguments stored in `args`; the launch status is
    /// checked by `DG_CUDA_UNIFIED_CHECK`.
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
            // Output, stats and token count
            args.y,
            args.cumulative_local_expert_recv_stats,
            args.num_tokens,
            // Symmetric buffer pointers of all ranks
            args.sym_buffer_ptrs,
            // Layer-1 tensormaps
            args.tensor_map_l1_acts,
            args.tensor_map_l1_acts_sf,
            args.tensor_map_l1_weights,
            args.tensor_map_l1_weights_sf,
            args.tensor_map_l1_output,
            // Layer-2 tensormaps
            args.tensor_map_l2_acts,
            args.tensor_map_l2_acts_sf,
            args.tensor_map_l2_weights,
            args.tensor_map_l2_weights_sf
        ));
    }
};
|
||||
|
||||
static void sm100_fp8_nvfp4_mega_moe(
|
||||
const torch::Tensor& y,
|
||||
const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf,
|
||||
const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf,
|
||||
const torch::Tensor& l1_weights, const torch::Tensor& l2_weights,
|
||||
const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf,
|
||||
const std::optional<torch::Tensor> cumulative_local_expert_recv_stats,
|
||||
const std::vector<int64_t>& sym_buffer_ptrs,
|
||||
const int& rank_idx, const int& num_max_tokens_per_rank,
|
||||
const int& num_experts_per_rank,
|
||||
const int& num_tokens, const int& num_topk,
|
||||
const int& hidden, const int& intermediate_hidden,
|
||||
const float& activation_clamp,
|
||||
const bool& fast_math
|
||||
) {
|
||||
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
|
||||
const auto num_experts = num_experts_per_rank * num_ranks;
|
||||
const auto num_padded_sf_pool_tokens = static_cast<int>(l1_acts_sf.size(0));
|
||||
|
||||
// Heuristics
|
||||
const auto config = get_mega_moe_config(
|
||||
num_ranks, num_experts, num_experts_per_rank,
|
||||
num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden, num_padded_sf_pool_tokens);
|
||||
|
||||
// NVFP4: kGranK=16 for group_size=16
|
||||
constexpr int kGranK = 16;
|
||||
|
||||
// Make tensormap — weight/activation TMA descriptors are the same as MXFP4
|
||||
// (E2M1 packed uint8 is the same format regardless of scale type)
|
||||
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
|
||||
hidden, config.num_max_pool_tokens,
|
||||
config.block_k, config.load_block_m,
|
||||
static_cast<int>(l1_acts.stride(-2)),
|
||||
config.swizzle_acts_mode);
|
||||
// NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32
|
||||
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
|
||||
config.num_padded_sf_pool_tokens, hidden,
|
||||
config.sf_block_m, kGranK,
|
||||
1, 0);
|
||||
const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights,
|
||||
hidden, num_experts_per_rank * intermediate_hidden * 2,
|
||||
config.block_k, config.load_block_n,
|
||||
static_cast<int>(l1_weights.stride(-2)),
|
||||
config.swizzle_weights_mode);
|
||||
const auto tensor_map_l1_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_weights_sf,
|
||||
intermediate_hidden * 2, hidden,
|
||||
config.block_n, kGranK,
|
||||
num_experts_per_rank, 0);
|
||||
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
|
||||
intermediate_hidden, config.num_max_pool_tokens,
|
||||
config.block_n / 2, config.store_block_m,
|
||||
static_cast<int>(l2_acts.stride(-2)),
|
||||
config.swizzle_acts_mode / 2);
|
||||
const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
|
||||
intermediate_hidden, config.num_max_pool_tokens,
|
||||
config.block_k, config.load_block_m,
|
||||
static_cast<int>(l2_acts.stride(-2)),
|
||||
config.swizzle_acts_mode);
|
||||
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
|
||||
config.num_padded_sf_pool_tokens, intermediate_hidden,
|
||||
config.sf_block_m, kGranK,
|
||||
1, 0);
|
||||
const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights,
|
||||
intermediate_hidden, num_experts_per_rank * hidden,
|
||||
config.block_k, config.load_block_n,
|
||||
static_cast<int>(l2_weights.stride(-2)),
|
||||
config.swizzle_weights_mode);
|
||||
const auto tensor_map_l2_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_weights_sf,
|
||||
hidden, intermediate_hidden,
|
||||
config.block_n, kGranK,
|
||||
num_experts_per_rank, 0);
|
||||
|
||||
// Stats can be optional
|
||||
int* cumulative_local_expert_recv_stats_ptr = nullptr;
|
||||
if (cumulative_local_expert_recv_stats.has_value())
|
||||
cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr<int>();
|
||||
|
||||
// Launch
|
||||
const auto num_sms = device_runtime->get_num_sms();
|
||||
const SM100FP8NVFP4MegaMoERuntime::Args args = {
|
||||
.num_max_tokens_per_rank = num_max_tokens_per_rank,
|
||||
.hidden = hidden, .intermediate_hidden = intermediate_hidden,
|
||||
.num_experts = num_experts, .num_topk = num_topk,
|
||||
.num_ranks = num_ranks,
|
||||
.activation_clamp = activation_clamp,
|
||||
.fast_math = fast_math,
|
||||
.config = config,
|
||||
.y = y.data_ptr(),
|
||||
.cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr,
|
||||
.num_tokens = num_tokens,
|
||||
.sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx),
|
||||
.tensor_map_l1_acts = tensor_map_l1_acts,
|
||||
.tensor_map_l1_acts_sf = tensor_map_l1_acts_sf,
|
||||
.tensor_map_l1_weights = tensor_map_l1_weights,
|
||||
.tensor_map_l1_weights_sf = tensor_map_l1_weights_sf,
|
||||
.tensor_map_l1_output = tensor_map_l1_output,
|
||||
.tensor_map_l2_acts = tensor_map_l2_acts,
|
||||
.tensor_map_l2_acts_sf = tensor_map_l2_acts_sf,
|
||||
.tensor_map_l2_weights = tensor_map_l2_weights,
|
||||
.tensor_map_l2_weights_sf = tensor_map_l2_weights_sf,
|
||||
.launch_args = LaunchArgs(num_sms,
|
||||
config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads,
|
||||
config.smem_size, 2)
|
||||
};
|
||||
|
||||
const auto code = SM100FP8NVFP4MegaMoERuntime::generate(args);
|
||||
const auto runtime = compiler->build("sm100_fp8_nvfp4_mega_moe", code);
|
||||
SM100FP8NVFP4MegaMoERuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "apis/gemm.hpp"
|
||||
#include "apis/layout.hpp"
|
||||
#include "apis/mega.hpp"
|
||||
#include "apis/mega_nvfp4.hpp"
|
||||
#include "apis/runtime.hpp"
|
||||
|
||||
#ifndef TORCH_EXTENSION_NAME
|
||||
@@ -24,5 +25,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
deep_gemm::gemm::register_apis(m);
|
||||
deep_gemm::layout::register_apis(m);
|
||||
deep_gemm::mega::register_apis(m);
|
||||
deep_gemm::mega::nvfp4::register_apis(m);
|
||||
deep_gemm::runtime::register_apis(m);
|
||||
}
|
||||
|
||||
1398
deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh
Normal file
1398
deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh
Normal file
File diff suppressed because it is too large
Load Diff
@@ -139,6 +139,50 @@ struct SM100_MMA_MXF4_SS {
|
||||
}
|
||||
};
|
||||
|
||||
/// NVFP4 block-scaled MMA (scale_vec::4X = block16, UE4M3 scales)
|
||||
struct SM100_MMA_MXF4NVF4_2x1SM_SS {
|
||||
CUTLASS_DEVICE static void
|
||||
fma(uint64_t const& desc_a,
|
||||
uint64_t const& desc_b,
|
||||
uint32_t const& tmem_c,
|
||||
uint32_t const& scale_c,
|
||||
uint64_t const& desc,
|
||||
uint32_t const& tmem_sfa,
|
||||
uint32_t const& tmem_sfb) {
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
|
||||
"}\n"
|
||||
:
|
||||
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
|
||||
"r"(tmem_sfa), "r"(tmem_sfb));
|
||||
}
|
||||
};
|
||||
|
||||
/// NVFP4 block-scaled MMA (single CTA, scale_vec::4X = block16, UE4M3 scales)
|
||||
struct SM100_MMA_MXF4NVF4_SS {
|
||||
CUTLASS_DEVICE static void
|
||||
fma(uint64_t const& desc_a,
|
||||
uint64_t const& desc_b,
|
||||
uint32_t const& tmem_c,
|
||||
uint32_t const& scale_c,
|
||||
uint64_t const& desc,
|
||||
uint32_t const& tmem_sfa,
|
||||
uint32_t const& tmem_sfb) {
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .pred p;\n\t"
|
||||
"setp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
|
||||
"}\n"
|
||||
:
|
||||
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
|
||||
"r"(tmem_sfa), "r"(tmem_sfb));
|
||||
}
|
||||
};
|
||||
|
||||
struct SM100_MMA_F16BF16_WS_SS {
|
||||
CUTLASS_DEVICE static void
|
||||
fma(uint64_t const& desc_a,
|
||||
|
||||
@@ -105,6 +105,57 @@ def transform_weights_for_mega_moe(
|
||||
return l1_weights, l2_weights
|
||||
|
||||
|
||||
def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
|
||||
"""Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into int32 UTCCP layout.
|
||||
|
||||
NVFP4 uses UE4M3 scales with group_size=16 (scale_vec::4X).
|
||||
The UTCCP layout packs 4 consecutive scale bytes into each int32,
|
||||
then applies the 4x32 transpose for TMA consumption.
|
||||
|
||||
Input: (num_experts, mn, K//16) float8_e4m3fn scales
|
||||
Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales
|
||||
"""
|
||||
num_groups, mn, sf_k = sf.shape
|
||||
assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}"
|
||||
assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}"
|
||||
|
||||
# View as uint8 and pack 4 consecutive bytes into int32
|
||||
sf_uint8 = sf.view(torch.uint8) # (num_groups, mn, sf_k)
|
||||
# Pack: every 4 uint8 → 1 int32
|
||||
packed = (sf_uint8[..., 0::4].to(torch.int32) |
|
||||
(sf_uint8[..., 1::4].to(torch.int32) << 8) |
|
||||
(sf_uint8[..., 2::4].to(torch.int32) << 16) |
|
||||
(sf_uint8[..., 3::4].to(torch.int32) << 24)) # (num_groups, mn, sf_k//4)
|
||||
|
||||
# Apply UTCCP 4x32 transpose (same as MXFP4 — the transpose is determined
|
||||
# by the 128-element alignment, not the scale vector size)
|
||||
packed_sf_k = sf_k // 4
|
||||
result = (packed.reshape(num_groups, -1, 4, 32, packed_sf_k)
|
||||
.transpose(2, 3)
|
||||
.reshape(num_groups, mn, packed_sf_k))
|
||||
return torch.empty_like(packed).copy_(result)
|
||||
|
||||
|
||||
def transform_nvfp4_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Bring NVFP4 expert weights into the form consumed by the mega_moe kernel.

    Both arguments are ``(weight, scale)`` pairs where:
    - weight: uint8 E2M1 packed, shape (num_experts, N, K//2)
    - scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16)

    Returns:
        ``(l1, l2)``, each a ``(weight, packed_sf)`` pair whose ``packed_sf`` is
        the int32 UTCCP layout produced by ``_pack_nvfp4_sf_for_utccp``. The L1
        pair is additionally passed through ``_interleave_l1_weights`` first.
    """
    # L1: interleave gate/up first, then pack + transpose the interleaved SF
    interleaved = _interleave_l1_weights(l1_weights)
    l1_packed_sf = _pack_nvfp4_sf_for_utccp(interleaved[1])

    # L2: the weight is passed through untouched, only the SF is packed + transposed
    l2_packed_sf = _pack_nvfp4_sf_for_utccp(l2_weights[1])

    return (interleaved[0], l1_packed_sf), (l2_weights[0], l2_packed_sf)
|
||||
|
||||
|
||||
def fp8_fp4_mega_moe(y: torch.Tensor,
|
||||
l1_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
l2_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
@@ -126,3 +177,32 @@ def fp8_fp4_mega_moe(y: torch.Tensor,
|
||||
activation, activation_clamp,
|
||||
fast_math
|
||||
)
|
||||
|
||||
|
||||
def fp8_nvfp4_mega_moe(y: torch.Tensor,
                       l1_weights: Tuple[torch.Tensor, torch.Tensor],
                       l2_weights: Tuple[torch.Tensor, torch.Tensor],
                       sym_buffer: SymmBuffer,
                       cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                       recipe: Tuple[int, int, int] = (1, 1, 16),
                       activation: str = 'swiglu',
                       activation_clamp: Optional[float] = None,
                       fast_math: bool = True):
    """Run the NVFP4 mega MoE kernel (kind::mxf4nvf4.block_scale.scale_vec::4X,
    UE4M3 block scales with group_size=16).

    Args:
        y: output tensor, handed to the native binding as-is.
        l1_weights, l2_weights: ``(uint8 E2M1 packed weight, int32 packed UTCCP
            UE4M3 scales)`` pairs of the two layers.
        sym_buffer: symmetric buffer; its tensor, per-rank pointers, rank index
            and size parameters are forwarded to the binding.
        cumulative_local_expert_recv_stats: optional statistics tensor.
        recipe: scaling recipe, ``(1, 1, 16)`` for NVFP4 (kGranK=16).
        activation: activation name.
        activation_clamp: optional activation clamp value.
        fast_math: forwarded to the binding.
    """
    # Collect the symmetric-buffer state in the positional order of the binding
    comm_args = (
        sym_buffer.buffer,
        sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
        sym_buffer.num_max_tokens_per_rank,
        sym_buffer.num_experts, sym_buffer.num_topk,
    )
    _C.fp8_nvfp4_mega_moe(
        y,
        l1_weights, l2_weights,
        cumulative_local_expert_recv_stats,
        *comm_args,
        recipe,
        activation, activation_clamp,
        fast_math
    )
|
||||
|
||||
Reference in New Issue
Block a user