feat: NVFP4 mega MoE kernel (scale_vec::4X, UE4M3 block scales)

- New CUDA kernel: sm100_fp8_nvfp4_mega_moe_impl
  - kGranK=16 (NVFP4 group_size=16, vs MXFP4's 32)
  - kind::mxf4nvf4.block_scale.scale_vec::4X PTX instruction
  - float_ue4m3_t scale factor type in instruction descriptor
  - SF layout: scale_vec::4X (4 TMEM sub-columns per UMMA atom)
  - UTCCP column stride: i*8 (vs MXFP4's i*4) for 4X layout
  - L1 epilogue: UE4M3 activation scales (float→cutlass::float_e4m3_t)
  - SF loading: kNumSFUint32 = kHidden/64 (4 UE4M3 per int32)

- New PTX wrappers: SM100_MMA_MXF4NVF4_2x1SM_SS, SM100_MMA_MXF4NVF4_SS

- Python API:
  - fp8_nvfp4_mega_moe() with recipe=(1,1,16)
  - transform_nvfp4_weights_for_mega_moe() for UE4M3→int32 UTCCP packing
  - _pack_nvfp4_sf_for_utccp() helper

- C++ bindings:
  - mega_nvfp4.hpp with NVFP4-specific SymmBuffer (SF stride K/16)
  - JIT kernel header with kGranK=16 TMA descriptors
  - Registered in python_api.cpp

NOTE: Both SFA and SFB must use UE4M3 (scale_format_ is 1-bit, shared).
The L1 epilogue converts float→UE4M3 for activation scales.
This commit is contained in:
2026-05-11 05:41:08 +00:00
parent 891d57b4db
commit 36b439ee26
6 changed files with 1981 additions and 0 deletions

237
csrc/apis/mega_nvfp4.hpp Normal file
View File

@@ -0,0 +1,237 @@
#pragma once
#include <functional>
#include <pybind11/functional.h>
#if DG_TENSORMAP_COMPATIBLE
#include "../jit/compiler.hpp"
#endif
#include "../jit/device_runtime.hpp"
#include "../jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp"
namespace deep_gemm::mega::nvfp4 {
// Reports the token alignment used by the NVFP4 mega MoE path, which is the
// layout-wide constant `layout::kLCMCandidateBlockM`.
// NOTE(review): judging by its name the constant is the LCM of the candidate
// block-M sizes — confirm against the layout header.
static int get_token_alignment_for_mega_moe() {
    const int token_alignment = layout::kLCMCandidateBlockM;
    return token_alignment;
}
/// Computes the symmetric-buffer footprint of the NVFP4 mega MoE kernel and
/// builds a slicer that exposes the individual regions of such a buffer.
///
/// Regions are placed back to back, in this order: workspace, input tokens,
/// input SF, input top-k indices, input top-k weights, L1 tokens, L1 SF,
/// L1 top-k weights, L2 tokens, L2 SF, combine tokens. The workspace is based
/// at `nullptr`, so every region "pointer" doubles as a byte offset from the
/// start of the buffer.
///
/// @param num_ranks               number of ranks; must be positive and divide `num_experts`
/// @param num_experts             global number of experts
/// @param num_max_tokens_per_rank row count of the per-rank input regions
/// @param num_topk                number of top-k entries per token
/// @param hidden                  hidden size; must be a multiple of 64
/// @param intermediate_hidden     intermediate size; must be a multiple of 64
/// @param use_fp8_dispatch        currently not read by this function
/// @param activation              currently not read by this function
/// @return `{required_bytes, slicer}`; `slicer(buffer)` returns views
///         `(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)`
///         created with `torch::from_blob` on top of `buffer`'s memory.
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
get_symm_buffer_size_for_nvfp4_mega_moe(
    const int& num_ranks, const int& num_experts,
    const int& num_max_tokens_per_rank, const int& num_topk,
    const int& hidden, const int& intermediate_hidden,
    const bool& use_fp8_dispatch, const std::string& activation) {
    // Validate before any arithmetic: a zero rank count would otherwise turn
    // the modulo below into an integer division by zero
    DG_HOST_ASSERT(num_ranks > 0);
    DG_HOST_ASSERT(num_experts % num_ranks == 0);

    // NVFP4: hidden must be divisible by 64 (4 UE4M3 scales per int32, group_size=16)
    DG_HOST_ASSERT(hidden % 64 == 0 and intermediate_hidden % 64 == 0);

    // Workspace bytes
    const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);

    // NVFP4 layouts: group_size=16, so SF stride is K/16 (twice as many as MXFP4)
    const auto fp8_token_layout = layout::Data(hidden);
    const auto bf16_token_layout = layout::Data(hidden * 2);
    const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
    const auto nvfp4_sf_layout = layout::Data(hidden / 16);
    const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
    const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
    const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
    const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);

    // Input buffers
    const auto input_token_buffer = layout::Buffer(
        fp8_token_layout, 1, num_max_tokens_per_rank,
        workspace.get_end_ptr());
    const auto input_sf_buffer = layout::Buffer(
        nvfp4_sf_layout, 1, num_max_tokens_per_rank,
        input_token_buffer.get_end_ptr());
    const auto input_topk_idx_buffer = layout::Buffer(
        input_topk_idx_layout, 1, num_max_tokens_per_rank,
        input_sf_buffer.get_end_ptr());
    const auto input_topk_weights_buffer = layout::Buffer(
        input_topk_weights_layout, 1, num_max_tokens_per_rank,
        input_topk_idx_buffer.get_end_ptr());

    // Buffer configs: the SF pool must fit the padding of every candidate block M
    const auto num_max_pool_tokens = static_cast<int>(workspace.num_max_pool_tokens);
    int num_max_padded_sf_pool_tokens = 0;
    for (int block_m: layout::kCandidateBlockM) {
        num_max_padded_sf_pool_tokens = std::max(
            num_max_padded_sf_pool_tokens,
            layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)
        );
    }
    DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);

    // L1 input buffer
    const auto l1_token_buffer = layout::Buffer(
        fp8_token_layout, 1, num_max_pool_tokens,
        input_topk_weights_buffer.get_end_ptr());
    const auto l1_sf_buffer = layout::Buffer(
        nvfp4_sf_layout, 1, num_max_padded_sf_pool_tokens,
        l1_token_buffer.get_end_ptr());
    const auto l1_topk_weights_buffer = layout::Buffer(
        l1_topk_weights_layout, 1, num_max_pool_tokens,
        l1_sf_buffer.get_end_ptr());

    // L2 input buffer
    const auto l2_token_buffer = layout::Buffer(
        fp8_intermediate_token_layout, 1, num_max_pool_tokens,
        l1_topk_weights_buffer.get_end_ptr());
    const auto l2_sf_buffer = layout::Buffer(
        nvfp4_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
        l2_token_buffer.get_end_ptr());

    // Combine input buffer: BF16 tokens for cross-rank combine
    const auto combine_token_buffer = layout::Buffer(
        bf16_token_layout, num_topk, num_max_tokens_per_rank,
        l2_sf_buffer.get_end_ptr());

    // Slice function: captures everything by value so that it stays valid
    // after this function has returned
    auto slice_input_buffers = [=](const torch::Tensor& buffer) {
        auto x = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
            {num_max_tokens_per_rank, hidden},
            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
        // NVFP4 SF: K/16 bytes per token, packed as K/64 int32
        auto x_sf = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
            {num_max_tokens_per_rank, hidden / 64},
            torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
        auto topk_idx = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)),
            {num_max_tokens_per_rank, num_topk},
            torch::TensorOptions().dtype(torch::kInt64).device(buffer.device()));
        auto topk_weights = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
            {num_max_tokens_per_rank, num_topk},
            torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
        auto l1_acts = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
            {num_max_pool_tokens, hidden},
            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
        // NVFP4 L1 SF: M-major (token stride 1), K/64 int32
        auto l1_acts_sf = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
            {num_max_padded_sf_pool_tokens, hidden / 64},
            {1, num_max_padded_sf_pool_tokens},
            torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
        auto l2_acts = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
            {num_max_pool_tokens, intermediate_hidden},
            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
        // NVFP4 L2 SF: M-major (token stride 1), K/64 int32
        auto l2_acts_sf = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
            {num_max_padded_sf_pool_tokens, intermediate_hidden / 64},
            {1, num_max_padded_sf_pool_tokens},
            torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
        return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
    };

    // The end of the last region, read as an offset from the null base, is the total size
    return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
}
/// Host entry point of the FP8 x NVFP4 mega MoE kernel: validates the
/// arguments, slices the symmetric buffer and dispatches to the
/// architecture-specific implementation.
///
/// @param y                        output tensor; `y.size(0)` is taken as the token count
/// @param l1_weights_tuple         `(weights, packed scale factors)` of the first layer
/// @param l2_weights_tuple         `(weights, packed scale factors)` of the second layer
/// @param cumulative_local_expert_recv_stats optional contiguous int32 counter tensor
///                                 with one element per local expert
/// @param sym_buffer               symmetric buffer, at least as large as reported by
///                                 `get_symm_buffer_size_for_nvfp4_mega_moe`
/// @param sym_buffer_ptrs          one buffer address per rank; its length defines the rank count
/// @param rank_idx                 index of the calling rank, in `[0, num_ranks)`
/// @param num_max_tokens_per_rank  row count of the per-rank input regions
/// @param num_experts              global expert count; must equal local experts * ranks
/// @param num_topk                 number of top-k entries per token
/// @param recipe                   must be `(1, 1, 16)`
/// @param activation               must be `"swiglu"`
/// @param activation_clamp_opt     optional non-negative clamp; defaults to +infinity
/// @param fast_math                forwarded to the kernel
static void fp8_nvfp4_mega_moe(
    const torch::Tensor& y,
    const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple,
    const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_tuple,
    const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
    const torch::Tensor& sym_buffer,
    const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
    const int& num_max_tokens_per_rank,
    const int& num_experts, const int& num_topk,
    const std::tuple<int, int, int>& recipe,
    const std::string& activation,
    const std::optional<float>& activation_clamp_opt,
    const bool& fast_math
) {
    // Bind by reference: copying the tuples would only add reference-count traffic
    const auto& [l1_weights, l1_weights_sf] = l1_weights_tuple;
    const auto& [l2_weights, l2_weights_sf] = l2_weights_tuple;

    // Config checks
    const auto num_tokens = static_cast<int>(y.size(0));
    const auto [rm, rn, rk] = recipe;
    DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 16);   // NVFP4: group_size=16
    DG_HOST_ASSERT(activation == "swiglu");
    // The sliced input views only hold `num_max_tokens_per_rank` rows
    DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank);

    // Activation checks
    const auto activation_clamp =
        activation_clamp_opt.value_or(std::numeric_limits<float>::infinity());
    DG_HOST_ASSERT(activation_clamp >= 0);

    // Tensor checks
    DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K);
    DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K);
    const auto arch_major = device_runtime->get_arch_major();
    const auto [num_experts_per_rank, intermediate_hidden_2, hidden] =
        check_grouped_ab_fp8_fp4(l1_weights, cute::UMMA::Major::K, arch_major);
    const auto [num_experts_per_rank_, hidden_, intermediate_hidden] =
        check_grouped_ab_fp8_fp4(l2_weights, cute::UMMA::Major::K, arch_major);
    DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_);
    DG_HOST_ASSERT(hidden == hidden_);
    DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
    DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());

    // Check weight SF layout for UE4M3 packing, MN-major, and TMA alignment
    // NVFP4: kGranK=16, SF packed as int32 (4 UE4M3 bytes per int32)
    constexpr int kGranMN = 1, kGranK = 16;
    check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
                    num_experts_per_rank, true, false, torch::kInt);
    check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
                    num_experts_per_rank, true, false, torch::kInt);

    // Check stats counter
    if (cumulative_local_expert_recv_stats.has_value()) {
        DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
        DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank);
        DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous());
    }

    // Check the rank topology before it is used for any sizing arithmetic:
    // an empty pointer list would otherwise reach a modulo by zero
    const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
    DG_HOST_ASSERT(num_ranks > 0);
    DG_HOST_ASSERT(0 <= rank_idx and rank_idx < num_ranks);
    DG_HOST_ASSERT(num_experts == num_experts_per_rank * num_ranks);

    // Check buffer bytes
    const auto [num_required_bytes, slice] = get_symm_buffer_size_for_nvfp4_mega_moe(
        num_ranks, num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden,
        true, activation);
    DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes));

    // Already registered tensors (only the L1/L2 views are handed to the kernel launcher)
    const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);

    // Dispatch into different architectures
    if (arch_major == 10) {
        sm100_fp8_nvfp4_mega_moe(y,
                                 l1_acts, l1_acts_sf,
                                 l2_acts, l2_acts_sf,
                                 l1_weights, l2_weights,
                                 l1_weights_sf, l2_weights_sf,
                                 cumulative_local_expert_recv_stats,
                                 sym_buffer_ptrs,
                                 rank_idx, num_max_tokens_per_rank,
                                 num_experts_per_rank,
                                 num_tokens, num_topk,
                                 hidden, intermediate_hidden,
                                 activation_clamp, fast_math);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }

    // Zero the entire symmetric buffer for debug mode
    if (get_env<int>("DG_COMM_KERNEL_DEBUG"))
        sym_buffer.zero_();
}
// Publishes the NVFP4 mega MoE host functions on the given pybind11 module.
// Nothing is registered unless DG_TENSORMAP_COMPATIBLE is enabled.
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
    // `module_::def` returns the module itself, so the registrations chain
    m.def("get_token_alignment_for_nvfp4_mega_moe", &get_token_alignment_for_mega_moe)
     .def("get_symm_buffer_size_for_nvfp4_mega_moe", &get_symm_buffer_size_for_nvfp4_mega_moe)
     .def("fp8_nvfp4_mega_moe", &fp8_nvfp4_mega_moe);
#endif
}
} // namespace deep_gemm::mega::nvfp4

View File

@@ -0,0 +1,220 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "runtime_utils.hpp"
#include <deep_gemm/layout/mega_moe.cuh>
#include <deep_gemm/layout/sym_buffer.cuh>
#include "../heuristics/mega_moe.hpp"
namespace deep_gemm {
// JIT launch runtime for `sm100_fp8_nvfp4_mega_moe_impl`: `generate_impl` renders
// the source that instantiates the kernel template for one configuration, and
// `launch_impl` forwards the runtime arguments to the built kernel.
class SM100FP8NVFP4MegaMoERuntime final : public LaunchRuntime<SM100FP8NVFP4MegaMoERuntime> {
public:
// Everything needed to generate and launch one kernel instance.
struct Args {
// Templated arguments
// (rendered into the kernel's template argument list by `generate_impl`)
int num_max_tokens_per_rank;
int hidden, intermediate_hidden;
int num_experts, num_topk;
int num_ranks;
float activation_clamp;
bool fast_math;
MegaMoEConfig config;
// Runtime arguments
// (passed to the kernel at launch time by `launch_impl`)
void* y;
// May be null: the host wrapper leaves it null when no stats tensor is given
int* cumulative_local_expert_recv_stats;
int num_tokens;
layout::SymBuffer<> sym_buffer_ptrs;
// Tensormap
CUtensorMap tensor_map_l1_acts;
CUtensorMap tensor_map_l1_acts_sf;
CUtensorMap tensor_map_l1_weights;
CUtensorMap tensor_map_l1_weights_sf;
CUtensorMap tensor_map_l1_output;
CUtensorMap tensor_map_l2_acts;
CUtensorMap tensor_map_l2_acts_sf;
CUtensorMap tensor_map_l2_weights;
CUtensorMap tensor_map_l2_weights_sf;
// Launch configs
LaunchArgs launch_args;
};
// Renders the translation unit that takes the address of one
// `sm100_fp8_nvfp4_mega_moe_impl<...>` instantiation. The 22 `{}` placeholders
// are filled positionally by the arguments listed after the raw string, so the
// two lists must stay in the same order. `launch_args.grid_dim.first` is passed
// as a template argument as well, and `fast_math` is rendered as the literal
// `true` or `false`.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_nvfp4_mega_moe_impl<
{},
{}, {},
{}, {},
{},
{}, {}, {},
{},
{}, {},
{},
{},
{},
{}, {}, {},
{}, {},
{},
{}
>);
}};
)", args.num_max_tokens_per_rank,
args.hidden, args.intermediate_hidden,
args.num_experts, args.num_topk,
args.config.num_experts_per_wave,
args.config.block_m, args.config.block_n, args.config.block_k,
args.config.store_block_m,
args.config.sf_block_m, args.config.sf_block_n,
args.config.num_max_pool_tokens,
args.config.num_padded_sf_pool_tokens,
args.config.num_stages,
args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads,
args.launch_args.grid_dim.first, args.num_ranks,
to_string(args.activation_clamp),
args.fast_math ? "true" : "false");
}
// Launches the built kernel with the runtime arguments followed by the nine
// tensor maps; the result of `launch_kernel` is checked by DG_CUDA_UNIFIED_CHECK.
// NOTE(review): the argument order presumably has to match the kernel's
// parameter list in the .cuh implementation — confirm when changing either side.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.y,
args.cumulative_local_expert_recv_stats,
args.num_tokens,
args.sym_buffer_ptrs,
args.tensor_map_l1_acts,
args.tensor_map_l1_acts_sf,
args.tensor_map_l1_weights,
args.tensor_map_l1_weights_sf,
args.tensor_map_l1_output,
args.tensor_map_l2_acts,
args.tensor_map_l2_acts_sf,
args.tensor_map_l2_weights,
args.tensor_map_l2_weights_sf
));
}
};
/// Builds the TMA descriptors for the SM100 FP8 x NVFP4 mega MoE kernel,
/// JIT-compiles the matching kernel instance and launches it.
///
/// @param y                        output tensor; only its data pointer is forwarded
/// @param l1_acts, l1_acts_sf      first-layer activations and their scale factors;
///                                 `l1_acts_sf.size(0)` is used as the padded SF pool size
/// @param l2_acts, l2_acts_sf      second-layer activations and their scale factors;
///                                 `l2_acts` is also the target of the L1 output descriptor
/// @param l1_weights, l2_weights   expert weights of both layers
/// @param l1_weights_sf, l2_weights_sf  weight scale factors of both layers
/// @param cumulative_local_expert_recv_stats optional int32 counters; a null pointer is
///                                 passed to the kernel when absent
/// @param sym_buffer_ptrs          one buffer address per rank; its length is the rank count
/// @param rank_idx                 index of the calling rank
/// @param num_max_tokens_per_rank, num_experts_per_rank, num_tokens, num_topk,
///        hidden, intermediate_hidden  problem sizes forwarded to heuristics and kernel
/// @param activation_clamp         clamp value baked into the generated kernel
/// @param fast_math                fast-math switch baked into the generated kernel
static void sm100_fp8_nvfp4_mega_moe(
    const torch::Tensor& y,
    const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf,
    const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf,
    const torch::Tensor& l1_weights, const torch::Tensor& l2_weights,
    const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf,
    // Taken by reference: a by-value optional would copy the tensor handle on every call
    const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
    const std::vector<int64_t>& sym_buffer_ptrs,
    const int& rank_idx, const int& num_max_tokens_per_rank,
    const int& num_experts_per_rank,
    const int& num_tokens, const int& num_topk,
    const int& hidden, const int& intermediate_hidden,
    const float& activation_clamp,
    const bool& fast_math
) {
    const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
    const auto num_experts = num_experts_per_rank * num_ranks;
    const auto num_padded_sf_pool_tokens = static_cast<int>(l1_acts_sf.size(0));

    // Heuristics
    const auto config = get_mega_moe_config(
        num_ranks, num_experts, num_experts_per_rank,
        num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden, num_padded_sf_pool_tokens);

    // NVFP4: kGranK=16 for group_size=16
    constexpr int kGranK = 16;

    // Make tensormap — weight/activation TMA descriptors are the same as MXFP4
    // (E2M1 packed uint8 is the same format regardless of scale type)
    const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
                                                     hidden, config.num_max_pool_tokens,
                                                     config.block_k, config.load_block_m,
                                                     static_cast<int>(l1_acts.stride(-2)),
                                                     config.swizzle_acts_mode);
    // NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32
    const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
                                                        config.num_padded_sf_pool_tokens, hidden,
                                                        config.sf_block_m, kGranK,
                                                        1, 0);
    const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights,
                                                        hidden, num_experts_per_rank * intermediate_hidden * 2,
                                                        config.block_k, config.load_block_n,
                                                        static_cast<int>(l1_weights.stride(-2)),
                                                        config.swizzle_weights_mode);
    const auto tensor_map_l1_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_weights_sf,
                                                           intermediate_hidden * 2, hidden,
                                                           config.block_n, kGranK,
                                                           num_experts_per_rank, 0);
    // The L1 output descriptor targets `l2_acts`, using half of block_n and
    // half of the activation swizzle mode
    const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
                                                       intermediate_hidden, config.num_max_pool_tokens,
                                                       config.block_n / 2, config.store_block_m,
                                                       static_cast<int>(l2_acts.stride(-2)),
                                                       config.swizzle_acts_mode / 2);
    const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
                                                     intermediate_hidden, config.num_max_pool_tokens,
                                                     config.block_k, config.load_block_m,
                                                     static_cast<int>(l2_acts.stride(-2)),
                                                     config.swizzle_acts_mode);
    const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
                                                        config.num_padded_sf_pool_tokens, intermediate_hidden,
                                                        config.sf_block_m, kGranK,
                                                        1, 0);
    const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights,
                                                        intermediate_hidden, num_experts_per_rank * hidden,
                                                        config.block_k, config.load_block_n,
                                                        static_cast<int>(l2_weights.stride(-2)),
                                                        config.swizzle_weights_mode);
    const auto tensor_map_l2_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_weights_sf,
                                                           hidden, intermediate_hidden,
                                                           config.block_n, kGranK,
                                                           num_experts_per_rank, 0);

    // Stats can be optional
    int* cumulative_local_expert_recv_stats_ptr = nullptr;
    if (cumulative_local_expert_recv_stats.has_value())
        cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr<int>();

    // Launch: one block per SM; the thread count is the sum of the dispatch,
    // non-epilogue and epilogue thread groups
    // NOTE(review): the trailing `2` passed to LaunchArgs is presumably the cluster size — confirm
    const auto num_sms = device_runtime->get_num_sms();
    const SM100FP8NVFP4MegaMoERuntime::Args args = {
        .num_max_tokens_per_rank = num_max_tokens_per_rank,
        .hidden = hidden, .intermediate_hidden = intermediate_hidden,
        .num_experts = num_experts, .num_topk = num_topk,
        .num_ranks = num_ranks,
        .activation_clamp = activation_clamp,
        .fast_math = fast_math,
        .config = config,
        .y = y.data_ptr(),
        .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr,
        .num_tokens = num_tokens,
        .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx),
        .tensor_map_l1_acts = tensor_map_l1_acts,
        .tensor_map_l1_acts_sf = tensor_map_l1_acts_sf,
        .tensor_map_l1_weights = tensor_map_l1_weights,
        .tensor_map_l1_weights_sf = tensor_map_l1_weights_sf,
        .tensor_map_l1_output = tensor_map_l1_output,
        .tensor_map_l2_acts = tensor_map_l2_acts,
        .tensor_map_l2_acts_sf = tensor_map_l2_acts_sf,
        .tensor_map_l2_weights = tensor_map_l2_weights,
        .tensor_map_l2_weights_sf = tensor_map_l2_weights_sf,
        .launch_args = LaunchArgs(num_sms,
                                  config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads,
                                  config.smem_size, 2)
    };
    const auto code = SM100FP8NVFP4MegaMoERuntime::generate(args);
    const auto runtime = compiler->build("sm100_fp8_nvfp4_mega_moe", code);
    SM100FP8NVFP4MegaMoERuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -7,6 +7,7 @@
#include "apis/gemm.hpp"
#include "apis/layout.hpp"
#include "apis/mega.hpp"
#include "apis/mega_nvfp4.hpp"
#include "apis/runtime.hpp"
#ifndef TORCH_EXTENSION_NAME
@@ -24,5 +25,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
deep_gemm::gemm::register_apis(m);
deep_gemm::layout::register_apis(m);
deep_gemm::mega::register_apis(m);
deep_gemm::mega::nvfp4::register_apis(m);
deep_gemm::runtime::register_apis(m);
}

File diff suppressed because it is too large Load Diff

View File

@@ -139,6 +139,50 @@ struct SM100_MMA_MXF4_SS {
}
};
/// NVFP4 block-scaled MMA (scale_vec::4X = block16, UE4M3 scales)
/// Two-CTA variant: emits `tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X`.
struct SM100_MMA_MXF4NVF4_2x1SM_SS {
// Issues one MMA instruction. Operand mapping of the inline assembly:
//   tmem_c              -> %0, the bracketed first operand
//   desc_a / desc_b     -> %1 / %2, 64-bit descriptors
//   desc                -> %3, only the upper 32 bits (`desc >> 32`) are passed
//   scale_c             -> %4, compared against zero to set predicate `p`,
//                          which is the instruction's final operand
//   tmem_sfa / tmem_sfb -> %5 / %6, bracketed scale-factor operands
// NOTE(review): per the PTX ISA the trailing predicate presumably enables
// accumulation into the existing destination — confirm against the
// tcgen05.mma documentation.
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
/// NVFP4 block-scaled MMA (single CTA, scale_vec::4X = block16, UE4M3 scales)
/// Emits `tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X`.
struct SM100_MMA_MXF4NVF4_SS {
// Issues one MMA instruction. Operand mapping of the inline assembly:
//   tmem_c              -> %0, the bracketed first operand
//   desc_a / desc_b     -> %1 / %2, 64-bit descriptors
//   desc                -> %3, only the upper 32 bits (`desc >> 32`) are passed
//   scale_c             -> %4, compared against zero to set predicate `p`,
//                          which is the instruction's final operand
//   tmem_sfa / tmem_sfb -> %5 / %6, bracketed scale-factor operands
// NOTE(review): per the PTX ISA the trailing predicate presumably enables
// accumulation into the existing destination — confirm against the
// tcgen05.mma documentation.
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
struct SM100_MMA_F16BF16_WS_SS {
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,

View File

@@ -105,6 +105,57 @@ def transform_weights_for_mega_moe(
return l1_weights, l2_weights
def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
"""Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into int32 UTCCP layout.
NVFP4 uses UE4M3 scales with group_size=16 (scale_vec::4X).
The UTCCP layout packs 4 consecutive scale bytes into each int32,
then applies the 4x32 transpose for TMA consumption.
Input: (num_experts, mn, K//16) float8_e4m3fn scales
Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales
"""
num_groups, mn, sf_k = sf.shape
assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}"
assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}"
# View as uint8 and pack 4 consecutive bytes into int32
sf_uint8 = sf.view(torch.uint8) # (num_groups, mn, sf_k)
# Pack: every 4 uint8 → 1 int32
packed = (sf_uint8[..., 0::4].to(torch.int32) |
(sf_uint8[..., 1::4].to(torch.int32) << 8) |
(sf_uint8[..., 2::4].to(torch.int32) << 16) |
(sf_uint8[..., 3::4].to(torch.int32) << 24)) # (num_groups, mn, sf_k//4)
# Apply UTCCP 4x32 transpose (same as MXFP4 — the transpose is determined
# by the 128-element alignment, not the scale vector size)
packed_sf_k = sf_k // 4
result = (packed.reshape(num_groups, -1, 4, 32, packed_sf_k)
.transpose(2, 3)
.reshape(num_groups, mn, packed_sf_k))
return torch.empty_like(packed).copy_(result)
def transform_nvfp4_weights_for_mega_moe(
    l1_weights: Tuple[torch.Tensor, torch.Tensor],
    l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
    """Prepare NVFP4 expert weights for the mega_moe kernel.

    Each argument is a ``(weight, scale)`` pair:
    - weight: uint8 E2M1 packed, shape (num_experts, N, K//2)
    - scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16)

    The first layer is passed through ``_interleave_l1_weights`` and its
    scales are then packed with ``_pack_nvfp4_sf_for_utccp``; the second layer
    keeps its weight tensor as is and only has its scales packed.

    Returns:
        ``((l1_weight, l1_packed_sf), (l2_weight, l2_packed_sf))`` where the
        packed scale factors are int32 in UTCCP layout.
    """
    # L1: gate/up interleave first, then pack + transpose the SF for UTCCP
    interleaved = _interleave_l1_weights(l1_weights)
    packed_l1 = (interleaved[0], _pack_nvfp4_sf_for_utccp(interleaved[1]))

    # L2: the weights pass through untouched, only the SF is packed
    packed_l2 = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_weights[1]))
    return packed_l1, packed_l2
def fp8_fp4_mega_moe(y: torch.Tensor,
l1_weights: Tuple[torch.Tensor, torch.Tensor],
l2_weights: Tuple[torch.Tensor, torch.Tensor],
@@ -126,3 +177,32 @@ def fp8_fp4_mega_moe(y: torch.Tensor,
activation, activation_clamp,
fast_math
)
def fp8_nvfp4_mega_moe(y: torch.Tensor,
                       l1_weights: Tuple[torch.Tensor, torch.Tensor],
                       l2_weights: Tuple[torch.Tensor, torch.Tensor],
                       sym_buffer: SymmBuffer,
                       cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                       recipe: Tuple[int, int, int] = (1, 1, 16),
                       activation: str = 'swiglu',
                       activation_clamp: Optional[float] = None,
                       fast_math: bool = True):
    """Run the NVFP4 mega MoE kernel (kind::mxf4nvf4.block_scale.scale_vec::4X,
    UE4M3 block scales with group_size=16).

    Thin wrapper over ``_C.fp8_nvfp4_mega_moe``: the buffer tensor, the
    per-rank buffer pointers, the local rank and the size parameters are read
    from ``sym_buffer``; every other argument is forwarded unchanged.

    Weight format: (uint8 E2M1 packed, int32 packed UTCCP UE4M3 scales)
    Recipe: (1, 1, 16) — kGranK=16 for NVFP4 group_size=16.
    """
    kernel = _C.fp8_nvfp4_mega_moe

    # Pull the symmetric-buffer pieces out first, in the order the binding takes them
    buffer = sym_buffer.buffer
    buffer_ptrs = sym_buffer.handle.buffer_ptrs
    rank_idx = sym_buffer.group.rank()
    num_max_tokens_per_rank = sym_buffer.num_max_tokens_per_rank
    num_experts = sym_buffer.num_experts
    num_topk = sym_buffer.num_topk

    kernel(y,
           l1_weights, l2_weights,
           cumulative_local_expert_recv_stats,
           buffer,
           buffer_ptrs, rank_idx,
           num_max_tokens_per_rank,
           num_experts, num_topk,
           recipe,
           activation, activation_clamp,
           fast_math)