feat: full FP4 activations for mxf4nvf4 - E2M1 packed A side + UE4M3 scales

mxf4nvf4 requires BOTH A and B to be FP4 (E2M1 packed).
Changes:
- a_dtype_t: float_e4m3_t → float_e2m1_unpacksmem_t
- UMMA_K: 32 → 64 (FP4 MMA atom)
- L1 epilogue: FP8 quant → E2M1 FP4 quantization with nearest-neighbor
- L1 output SMEM: packed E2M1 (2 per byte), TMA store uint8
- TMA descriptors: adjusted for FP4 packing (K/2 bytes per row)
- SymmBuffer: uint8 activations, shape (M, K//2)
- Staging kernel: BF16 → E2M1 packed + UE4M3 block16 scales
This commit is contained in:
2026-05-11 20:29:08 +00:00
parent 2cd86ff5e7
commit b3d1aae038
4 changed files with 106 additions and 61 deletions

View File

@@ -26,10 +26,11 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// Workspace bytes
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
// NVFP4 layouts: group_size=16, so SF stride is K/16 (twice as many as MXFP4)
const auto fp8_token_layout = layout::Data(hidden);
// NVFP4 layouts: E2M1 packed (2 per byte), so token layout is K/2 bytes
// group_size=16, so SF stride is K/16 (twice as many as MXFP4)
const auto fp4_token_layout = layout::Data(hidden / 2);
const auto bf16_token_layout = layout::Data(hidden * 2);
const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2);
const auto nvfp4_sf_layout = layout::Data(hidden / 16);
const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
@@ -38,7 +39,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// Input buffers
const auto input_token_buffer = layout::Buffer(
fp8_token_layout, 1, num_max_tokens_per_rank,
fp4_token_layout, 1, num_max_tokens_per_rank,
workspace.get_end_ptr());
const auto input_sf_buffer = layout::Buffer(
nvfp4_sf_layout, 1, num_max_tokens_per_rank,
@@ -62,7 +63,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// L1 input buffer
const auto l1_token_buffer = layout::Buffer(
fp8_token_layout, 1, num_max_pool_tokens,
fp4_token_layout, 1, num_max_pool_tokens,
input_topk_weights_buffer.get_end_ptr());
const auto l1_sf_buffer = layout::Buffer(
nvfp4_sf_layout, 1, num_max_padded_sf_pool_tokens,
@@ -73,7 +74,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// L2 input buffer
const auto l2_token_buffer = layout::Buffer(
fp8_intermediate_token_layout, 1, num_max_pool_tokens,
fp4_intermediate_token_layout, 1, num_max_pool_tokens,
l1_topk_weights_buffer.get_end_ptr());
const auto l2_sf_buffer = layout::Buffer(
nvfp4_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
@@ -91,10 +92,11 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// Slice function
auto slice_input_buffers = [=](const torch::Tensor& buffer) {
// NVFP4: E2M1 packed activations (2 per byte), K/2 bytes per token
auto x = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
{num_max_tokens_per_rank, hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
{num_max_tokens_per_rank, hidden / 2},
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
auto x_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
@@ -108,20 +110,22 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
{num_max_tokens_per_rank, num_topk},
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
// NVFP4: L1 output acts are E2M1 packed, intermediate_hidden/2 bytes
auto l1_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
{num_max_pool_tokens, hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
{num_max_pool_tokens, hidden / 2},
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 L1 SF: M-major, K/64 int32
auto l1_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
{num_max_padded_sf_pool_tokens, hidden / 64},
{1, num_max_padded_sf_pool_tokens},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
// NVFP4: L2 acts are E2M1 packed, intermediate_hidden/2 bytes
auto l2_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
{num_max_pool_tokens, intermediate_hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
{num_max_pool_tokens, intermediate_hidden / 2},
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 L2 SF: M-major, K/64 int32
auto l2_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),

View File

@@ -136,11 +136,12 @@ static void sm100_fp8_nvfp4_mega_moe(
// Make tensormap — weight/activation TMA descriptors are the same as MXFP4
// (E2M1 packed uint8 is the same format regardless of scale type)
// NVFP4: activations are E2M1 packed uint8, so K-dim is hidden/2 bytes
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
hidden / 2, config.num_max_pool_tokens,
config.block_k / 2, config.load_block_m,
static_cast<int>(l1_acts.stride(-2)),
config.swizzle_acts_mode);
config.swizzle_acts_mode / 2);
// NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
config.num_padded_sf_pool_tokens, hidden,
@@ -155,16 +156,18 @@ static void sm100_fp8_nvfp4_mega_moe(
intermediate_hidden * 2, hidden,
config.block_n, kGranK,
num_experts_per_rank, 0);
// NVFP4: L1 output is E2M1 packed, intermediate_hidden/2 bytes per row
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
intermediate_hidden, config.num_max_pool_tokens,
config.block_n / 2, config.store_block_m,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_n / 4, config.store_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode / 2);
config.swizzle_acts_mode / 4);
// NVFP4: L2 activations are E2M1 packed, intermediate_hidden/2 bytes per row
const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
intermediate_hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_k / 2, config.load_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode);
config.swizzle_acts_mode / 2);
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
config.num_padded_sf_pool_tokens, intermediate_hidden,
config.sf_block_m, kGranK,

View File

@@ -95,12 +95,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk);
// Token and buffer layouts
constexpr auto fp8_token_layout = layout::Data(kHidden);
// NVFP4: activations are E2M1 packed (2 per byte), so K/2 bytes per token
constexpr auto fp4_token_layout = layout::Data(kHidden / 2);
constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16));
constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden);
// NVFP4 intermediate: same E2M1 packing
constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden / 2);
// NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4)
constexpr auto fp8_sf_layout = layout::Data(kHidden / 16);
constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16);
constexpr auto nvfp4_sf_layout = layout::Data(kHidden / 16);
constexpr auto nvfp4_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16);
constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false);
constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false);
constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
@@ -164,8 +166,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
);
// Data types
// NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1)
using a_dtype_t = cutlass::float_e4m3_t;
// NVFP4: mxf4nvf4 requires BOTH A and B to be FP4 (E2M1 packed)
using a_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
// MMA configs
@@ -173,7 +175,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2;
constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB
constexpr uint32_t UMMA_K = 32;
constexpr uint32_t UMMA_K = 64; // FP4: 64 elements per MMA atom (was 32 for FP8)
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N;
DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M");
@@ -181,8 +183,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
// Swizzle configs
constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t);
constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t);
// NVFP4: float_e2m1_unpacksmem_t uses 1 byte per element in SMEM (unpacked)
// Same byte stride as FP8 — the TMA hardware unpacks 2 E2M1 values per global byte
constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); // 128
constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); // 128
constexpr uint32_t kSwizzleCDMode = 128;
DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N");
@@ -200,13 +204,16 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
constexpr uint32_t SMEM_EXPERT_COUNT_SIZE =
math::constexpr_align<uint32_t>(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment);
constexpr uint32_t SMEM_SEND_BUFFER_SIZE =
math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment);
math::constexpr_align(fp4_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment);
// NVFP4: float_e2m1_unpacksmem_t uses 1 byte per element in SMEM (unpacked)
// TMA unpacks 2 E2M1 per global byte into 1 byte per element in SMEM
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
// L1 output: E2M1 packed FP4 (2 per byte)
constexpr uint32_t SMEM_CD_L1_SIZE =
kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages;
kNumEpilogueWarpgroups * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) * sizeof(uint8_t) * kNumTMAStoreStages;
constexpr uint32_t SMEM_CD_L2_SIZE =
kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16);
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE;
@@ -1051,7 +1058,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
ptx::tma_store_wait<kNumTMAStoreStages - 1>();
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
// Cast to FP8 E4M3 and store into shared memory
// Quantize to E2M1 FP4 and store into shared memory
// NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1 packed
// Scale for FP4: scale = amax / 6.0 (E2M1 max value)
// UE4M3 scale already computed below (same as FP8 case but using /6)
#pragma unroll
for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) {
// Reduce amax
@@ -1060,47 +1070,73 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x);
amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y);
// Calculate SF
// Calculate SF for E2M1: scale = amax / 6.0
// UE4M3 format (same computation as FP8 but different scale base)
float2 sf, sf_inv;
math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv);
// Use amax/6.0 as the scale (E2M1 max = 6)
sf.x = fmaxf(amax_values[i].x / 6.0f, 1e-8f);
sf.y = fmaxf(amax_values[i].y / 6.0f, 1e-8f);
sf_inv.x = 1.0f / sf.x;
sf_inv.y = 1.0f / sf.y;
// Cast
const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y));
// E2M1 FP4 quantization: find nearest from [0, 0.5, 1, 1.5, 2, 3, 4, 6]
// Process 4 BF16 values -> 4 E2M1 4-bit values -> pack into 2 bytes
auto quant_e2m1 = [](float v, float scale_inv) -> uint8_t {
float q = v * scale_inv;
q = fmaxf(-6.0f, fminf(6.0f, q));
uint8_t sign = (q < 0.0f) ? 1 : 0;
float aq = fabsf(q);
// Nearest E2M1 index
uint8_t idx;
if (aq < 0.25f) idx = 0;
else if (aq < 0.75f) idx = 1;
else if (aq < 1.25f) idx = 2;
else if (aq < 1.75f) idx = 3;
else if (aq < 2.5f) idx = 4;
else if (aq < 3.5f) idx = 5;
else if (aq < 5.0f) idx = 6;
else idx = 7;
return (sign << 3) | idx; // 4-bit packed
};
// STSM
uint32_t row = lane_idx;
// Quantize 4 BF16 values -> 4 E2M1 nibbles
float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
uint8_t e0 = quant_e2m1(upper.x, 1.0f);
uint8_t e1 = quant_e2m1(upper.y, 1.0f);
uint8_t e2 = quant_e2m1(lower.x, 1.0f);
uint8_t e3 = quant_e2m1(lower.y, 1.0f);
// Pack 2 nibbles per byte: (e1<<4)|e0, (e3<<4)|e2
uint8_t b0 = (e1 << 4) | e0;
uint8_t b1 = (e3 << 4) | e2;
// Store packed FP4 bytes to SMEM (row-major, L1_OUT_BLOCK_N/2 bytes per row)
uint32_t row = lane_idx; // lane maps to row within ATOM_M
uint32_t col = warp_idx_in_wg;
const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N
+ i * ATOM_M * L1_OUT_BLOCK_N
+ row * L1_OUT_BLOCK_N
+ (col ^ (row / 2)) * kNumBankGroupBytes;
ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr);
// SMEM layout: simple row-major, L1_OUT_BLOCK_N/2 bytes per row
// Each STSM atom wrote 4 FP8 values = 4 bytes. Now we write 2 bytes (4 FP4 values packed).
// Column offset: each warp handles 4 bytes worth of N (was 4 FP8 = 4 bytes, now 4 FP4 = 2 bytes)
const auto smem_base = smem_cd[tma_stage_idx]
+ epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2)
+ i * ATOM_M * (L1_OUT_BLOCK_N / 2);
// Bank-conflict-free addressing: interleave with 4-byte offset per warp
uint32_t byte_col = col * 2; // 2 bytes per warp per row
auto smem_ptr = smem_base + row * (L1_OUT_BLOCK_N / 2) + byte_col;
// Write 2 packed FP4 bytes
smem_ptr[0] = b0;
smem_ptr[1] = b1;
// Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout)
// Only one warp per pair writes (both hold the same SF after cross-warp reduce)
// Each lane < 4 holds SF for 2 rows (sf.x and sf.y)
// Store SF to `l2_sf_buffer` as UE4M3 (MN-major layout)
if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) {
const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2;
const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4;
const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t);
const auto sf_base_ptr = l2_sf_buffer.get_base_ptr<uint8_t>();
// NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4
// NOTES: originally there was:
// - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2
// - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)`
// We find out that
// 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside
// 2. `lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside
// This reduce the number of computation instructions.
const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M;
__builtin_assume(token_base_idx < BLOCK_M);
const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M
+ m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4;
const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast<uint32_t>(sizeof(uint32_t)) + byte_idx;
// NVFP4: convert float scale to UE4M3 format
// UE4M3: sign=0 + 4 exp + 3 mantissa, max=448
auto to_ue4m3 = [](float v) -> uint8_t {
v = fmaxf(0.0f, fminf(v, 448.0f));
cutlass::float_e4m3_t e4m3_val = cutlass::float_e4m3_t(v);
@@ -1115,11 +1151,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Issue TMA store after all atoms in this store block
if (warp_idx_in_wg == 0 and cute::elect_one_sync()) {
uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N;
uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N / 2; // FP4: byte offset = element offset / 2
cute::tma_store_fence();
cute::SM90_TMA_STORE_2D::copy(
&tensor_map_l1_output,
smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N,
smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2),
out_n_idx,
m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M);
cute::tma_store_arrive();

View File

@@ -312,7 +312,9 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor,
"""NVFP4 mega MoE: uses kind::mxf4nvf4.block_scale.scale_vec::4X
with UE4M3 block scales (group_size=16).
Both activations AND weights are E2M1 packed (FP4×FP4).
Weight format: (uint8 E2M1 packed, int32 packed UTCCP UE4M3 scales)
Activation format: E2M1 packed uint8 + UE4M3 scales (computed by staging kernel)
Recipe: (1, 1, 16) — kGranK=16 for NVFP4 group_size=16.
"""
_C.fp8_nvfp4_mega_moe(