fix: L1 output uses unpacked E2M1 (1 byte/element) like FP8

- float_e2m1_unpacksmem_t: sizeof=1, SMEM is 1 byte/element (not packed)
- TMA load unpacks 2 E2M1/global-byte → 2 SMEM bytes
- UMMA reads unpacked SMEM, packs internally for mxf4nvf4
- L1→L2 handoff: unpacked format (same byte count as FP8)
- Epilogue: 4 E2M1 bytes per uint32 STSM atom, same as FP8
- Dispatch TMA: kHidden bytes (unpacked), not kHidden/2
- Added static_assert on sizeof(a_dtype_t) and sizeof(b_dtype_t)
- Note: no bandwidth savings at L1→L2 boundary for v1
This commit is contained in:
2026-05-11 21:27:35 +00:00
parent 091b974736
commit 0ac73a82f9
3 changed files with 52 additions and 46 deletions

View File

@@ -28,9 +28,9 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4 layouts: E2M1 unpacked (1 byte per element), so token layout is K bytes
// group_size=16, so SF stride is K/16 (twice as many as MXFP4)
const auto fp4_token_layout = layout::Data(hidden / 2);
const auto fp4_token_layout = layout::Data(hidden);
const auto bf16_token_layout = layout::Data(hidden * 2);
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2);
const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden);
const auto nvfp4_sf_layout = layout::Data(hidden / 16);
const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16);
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
@@ -95,7 +95,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4: E2M1 unpacked activations (1 byte per element), K bytes per token
auto x = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
{num_max_tokens_per_rank, hidden / 2},
{num_max_tokens_per_rank, hidden},
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
auto x_sf = torch::from_blob(
@@ -113,7 +113,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4: L1 acts are E2M1 unpacked (1 byte per element), hidden bytes per token
auto l1_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
{num_max_pool_tokens, hidden / 2},
{num_max_pool_tokens, hidden},
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 L1 SF: M-major, K/64 int32
auto l1_acts_sf = torch::from_blob(
@@ -124,7 +124,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
// NVFP4: L2 acts are E2M1 unpacked (1 byte per element), intermediate_hidden bytes per token
auto l2_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
{num_max_pool_tokens, intermediate_hidden / 2},
{num_max_pool_tokens, intermediate_hidden},
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
// NVFP4 L2 SF: M-major, K/64 int32
auto l2_acts_sf = torch::from_blob(

View File

@@ -136,12 +136,12 @@ static void sm100_fp8_nvfp4_mega_moe(
// Make tensormap — weight/activation TMA descriptors are the same as MXFP4
// (E2M1 packed uint8 is the same format regardless of scale type)
// NVFP4: activations are E2M1 packed uint8, so K-dim is hidden/2 bytes
// L1 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
hidden / 2, config.num_max_pool_tokens,
config.block_k / 2, config.load_block_m,
hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
static_cast<int>(l1_acts.stride(-2)),
config.swizzle_acts_mode / 2);
config.swizzle_acts_mode);
// NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
config.num_padded_sf_pool_tokens, hidden,
@@ -156,18 +156,18 @@ static void sm100_fp8_nvfp4_mega_moe(
intermediate_hidden * 2, hidden,
config.block_n, kGranK,
num_experts_per_rank, 0);
// NVFP4: L1 output is E2M1 packed, intermediate_hidden/2 bytes per row
// L1 output: unpacked E2M1 (1 byte/element), same dimensions as FP8
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_n / 4, config.store_block_m,
intermediate_hidden, config.num_max_pool_tokens,
config.block_n / 2, config.store_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode / 4);
// NVFP4: L2 activations are E2M1 packed, intermediate_hidden/2 bytes per row
config.swizzle_acts_mode / 2);
// L2 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8
const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
intermediate_hidden / 2, config.num_max_pool_tokens,
config.block_k / 2, config.load_block_m,
intermediate_hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode / 2);
config.swizzle_acts_mode);
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
config.num_padded_sf_pool_tokens, intermediate_hidden,
config.sf_block_m, kGranK,

View File

@@ -96,10 +96,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Token and buffer layouts
// NVFP4: activations are E2M1 unpacked (1 byte per element), so K bytes per token
constexpr auto fp4_token_layout = layout::Data(kHidden / 2);
constexpr auto fp4_token_layout = layout::Data(kHidden);
constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16));
// NVFP4 intermediate: same unpacked E2M1 layout (1 byte per element)
constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden / 2);
constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden);
// NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4)
constexpr auto nvfp4_sf_layout = layout::Data(kHidden / 16);
constexpr auto nvfp4_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16);
@@ -169,6 +169,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// NVFP4: mxf4nvf4 requires BOTH A and B to be FP4 (E2M1, held unpacked in SMEM)
using a_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t;
// Verify SMEM element sizes: unpacksmem = 1 byte/element (not 0.5)
// The TMA load unpacks 2 E2M1/global-byte → 2 SMEM bytes.
// UMMA reads unpacked SMEM and packs internally for mxf4nvf4.
static_assert(sizeof(a_dtype_t) == 1, "a_dtype_t (E2M1 unpacked SMEM) must be 1 byte/element");
static_assert(sizeof(b_dtype_t) == 1, "b_dtype_t (E2M1 unpacked SMEM) must be 1 byte/element");
// MMA configs
// NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major
@@ -211,9 +216,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
// L1 output: E2M1 packed FP4 (2 per byte)
// L1 output: unpacked E2M1 in SMEM (1 byte/element, same as FP8 layout)
// TMA store writes to global as unpacked; L2 TMA load reads unpacked into SMEM
// This avoids the packed/unpacked mismatch at the L1→L2 boundary
constexpr uint32_t SMEM_CD_L1_SIZE =
kNumEpilogueWarpgroups * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) * sizeof(uint8_t) * kNumTMAStoreStages;
kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(uint8_t) * kNumTMAStoreStages;
constexpr uint32_t SMEM_CD_L2_SIZE =
kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16);
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE;
@@ -566,7 +573,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
pull_buffer.get_base_ptr(),
sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(),
current_rank_in_expert_idx),
pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes
pull_mbarrier, kHidden); // NVFP4: unpacked E2M1, same byte count as FP8
}
__syncwarp();
@@ -598,7 +605,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
*l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr<float>() = weight;
// Wait for TMA token load to complete
ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes
ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); // NVFP4: unpacked E2M1, same byte count as FP8
ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase);
// Store token to local L1 buffer via TMA
@@ -1059,10 +1066,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
ptx::tma_store_wait<kNumTMAStoreStages - 1>();
ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx);
// Quantize to E2M1 FP4 and STSM into shared memory
// NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1 packed
// Pack 8 E2M1 nibbles into uint32 for one STSM write (4 bytes = 8 FP4 values)
// Keep XOR swizzle for bank-conflict-free SMEM layout that TMA expects
// Quantize to E2M1 FP4 and STSM into shared memory (unpacked, 1 byte/element)
// NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1
// SMEM layout: unpacked (1 byte per E2M1 element), same byte count as FP8
// 4 BF16 → 4 E2M1 bytes → 1 STSM uint32 write, with XOR swizzle
#pragma unroll
for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) {
// Reduce amax
@@ -1078,8 +1085,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
sf_inv.x = 1.0f / sf.x;
sf_inv.y = 1.0f / sf.y;
// E2M1 FP4 quantization helper
auto quant_e2m1 = [](float v) -> uint8_t {
// E2M1 FP4 quantization: find nearest from [0, 0.5, 1, 1.5, 2, 3, 4, 6]
// Store as 4-bit in low nibble of each byte (unpacked SMEM: 1 byte/element)
auto quant_e2m1_byte = [](float v) -> uint8_t {
v = fmaxf(-6.0f, fminf(6.0f, v));
uint8_t sign = (v < 0.0f) ? 1 : 0;
float aq = fabsf(v);
@@ -1092,29 +1100,27 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
else if (aq < 3.5f) idx = 5;
else if (aq < 5.0f) idx = 6;
else idx = 7;
return (sign << 3) | idx;
return (sign << 3) | idx; // 4-bit E2M1 in low nibble
};
// Quantize 4 BF16 values -> 4 E2M1 nibbles -> pack into 2 bytes
// Quantize 4 BF16 values -> 4 E2M1 bytes -> pack into uint32 for STSM
const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv);
const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv);
uint8_t e0 = quant_e2m1(upper.x);
uint8_t e1 = quant_e2m1(upper.y);
uint8_t e2 = quant_e2m1(lower.x);
uint8_t e3 = quant_e2m1(lower.y);
// Pack 4 nibbles into uint32 (2 bytes): (e1<<4)|e0, (e3<<4)|e2
// Note: only fills 2 of the 4 STSM bytes. The other 2 bytes will be
// written by the adjacent warp (same XOR swizzle column, different row range).
uint32_t packed = (uint32_t)((e1 << 4) | e0)
| ((uint32_t)((e3 << 4) | e2) << 8);
uint8_t e0 = quant_e2m1_byte(upper.x);
uint8_t e1 = quant_e2m1_byte(upper.y);
uint8_t e2 = quant_e2m1_byte(lower.x);
uint8_t e3 = quant_e2m1_byte(lower.y);
// Pack 4 bytes into uint32 for one STSM atom
uint32_t packed = (uint32_t)e0 | ((uint32_t)e1 << 8)
| ((uint32_t)e2 << 16) | ((uint32_t)e3 << 24);
// STSM with XOR swizzle (same pattern as FP8, but halved byte stride)
// STSM with XOR swizzle (same layout as FP8, 1 byte/element)
uint32_t row = lane_idx;
uint32_t col = warp_idx_in_wg;
const auto smem_ptr = smem_cd[tma_stage_idx]
+ epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2)
+ i * ATOM_M * (L1_OUT_BLOCK_N / 2)
+ row * (L1_OUT_BLOCK_N / 2)
+ epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N
+ i * ATOM_M * L1_OUT_BLOCK_N
+ row * L1_OUT_BLOCK_N
+ (col ^ (row / 2)) * kNumBankGroupBytes;
ptx::SM100_U32x1_STSM_T<uint32_t>::copy(packed, smem_ptr);
@@ -1143,11 +1149,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Issue TMA store after all atoms in this store block
if (warp_idx_in_wg == 0 and cute::elect_one_sync()) {
uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N / 2; // FP4: byte offset = element offset / 2
uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; // unpacked: 1 byte/element
cute::tma_store_fence();
cute::SM90_TMA_STORE_2D::copy(
&tensor_map_l1_output,
smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2),
smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N),
out_n_idx,
m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M);
cute::tma_store_arrive();