From 0ac73a82f9d336bbce14262de82b12422e0fa4ae Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 21:27:35 +0000 Subject: [PATCH] fix: L1 output uses unpacked E2M1 (1 byte/element) like FP8 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - float_e2m1_unpacksmem_t: sizeof=1, SMEM is 1 byte/element (not packed) - TMA load unpacks 2 E2M1/global-byte → 2 SMEM bytes - UMMA reads unpacked SMEM, packs internally for mxf4nvf4 - L1→L2 handoff: unpacked format (same byte count as FP8) - Epilogue: 4 E2M1 bytes per uint32 STSM atom, same as FP8 - Dispatch TMA: kHidden bytes (unpacked), not kHidden/2 - Added static_assert on sizeof(a_dtype_t) and sizeof(b_dtype_t) - Note: no bandwidth savings at L1→L2 boundary for v1 --- csrc/apis/mega_nvfp4.hpp | 10 +-- .../impls/sm100_fp8_nvfp4_mega_moe.hpp | 24 +++---- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 64 ++++++++++--------- 3 files changed, 52 insertions(+), 46 deletions(-) diff --git a/csrc/apis/mega_nvfp4.hpp b/csrc/apis/mega_nvfp4.hpp index fd9fff0..e8812f6 100644 --- a/csrc/apis/mega_nvfp4.hpp +++ b/csrc/apis/mega_nvfp4.hpp @@ -28,9 +28,9 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4 layouts: E2M1 packed (2 per byte), so token layout is K/2 bytes // group_size=16, so SF stride is K/16 (twice as many as MXFP4) - const auto fp4_token_layout = layout::Data(hidden / 2); + const auto fp4_token_layout = layout::Data(hidden); const auto bf16_token_layout = layout::Data(hidden * 2); - const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2); + const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden); const auto nvfp4_sf_layout = layout::Data(hidden / 16); const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16); const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); @@ -95,7 +95,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4: E2M1 packed activations (2 per byte), K/2 bytes per token auto x = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)), - {num_max_tokens_per_rank, hidden / 2}, + {num_max_tokens_per_rank, hidden}, torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); // NVFP4 SF: K/16 bytes per token, packed as K/64 int32 auto x_sf = torch::from_blob( @@ -113,7 +113,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4: L1 output acts are E2M1 packed, intermediate_hidden/2 bytes auto l1_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_token_buffer.base)), - {num_max_pool_tokens, hidden / 2}, + {num_max_pool_tokens, hidden}, torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); // NVFP4 L1 SF: M-major, K/64 int32 auto l1_acts_sf = torch::from_blob( @@ -124,7 +124,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4: L2 acts are E2M1 packed, intermediate_hidden/2 bytes auto l2_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), - {num_max_pool_tokens, intermediate_hidden / 2}, + {num_max_pool_tokens, intermediate_hidden}, torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); // NVFP4 L2 SF: M-major, K/64 int32 auto l2_acts_sf = torch::from_blob( diff --git a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp index ecdbabf..649a186 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp @@ -136,12 +136,12 @@ static void sm100_fp8_nvfp4_mega_moe( // Make tensormap — weight/activation TMA descriptors are the same as MXFP4 // (E2M1 packed uint8 is the same format regardless of scale type) - // NVFP4: activations are E2M1 packed uint8, so K-dim is hidden/2 bytes + // L1 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8 const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts, - hidden / 2, config.num_max_pool_tokens, - config.block_k / 2, config.load_block_m, + hidden, config.num_max_pool_tokens, + config.block_k, config.load_block_m, static_cast(l1_acts.stride(-2)), - config.swizzle_acts_mode / 2); + config.swizzle_acts_mode); // NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32 const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf, config.num_padded_sf_pool_tokens, hidden, @@ -156,18 +156,18 @@ static void sm100_fp8_nvfp4_mega_moe( intermediate_hidden * 2, hidden, config.block_n, kGranK, num_experts_per_rank, 0); - // NVFP4: L1 output is E2M1 packed, intermediate_hidden/2 bytes per row + // L1 output: unpacked E2M1 (1 byte/element), same dimensions as FP8 const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, - intermediate_hidden / 2, config.num_max_pool_tokens, - config.block_n / 4, config.store_block_m, + intermediate_hidden, config.num_max_pool_tokens, + config.block_n / 2, config.store_block_m, static_cast(l2_acts.stride(-2)), - config.swizzle_acts_mode / 4); - // NVFP4: L2 activations are E2M1 packed, intermediate_hidden/2 bytes per row + config.swizzle_acts_mode / 2); + // L2 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8 const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts, - intermediate_hidden / 2, config.num_max_pool_tokens, - config.block_k / 2, config.load_block_m, + intermediate_hidden, config.num_max_pool_tokens, + config.block_k, config.load_block_m, static_cast(l2_acts.stride(-2)), - config.swizzle_acts_mode / 2); + config.swizzle_acts_mode); const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf, config.num_padded_sf_pool_tokens, intermediate_hidden, config.sf_block_m, kGranK, diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 66d44b7..9703a1c 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -96,10 +96,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Token and buffer layouts // NVFP4: activations are E2M1 packed (2 per byte), so K/2 bytes per token - constexpr auto fp4_token_layout = layout::Data(kHidden / 2); + constexpr auto fp4_token_layout = layout::Data(kHidden); constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); // NVFP4 intermediate: same E2M1 packing - constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden / 2); + constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden); // NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4) constexpr auto nvfp4_sf_layout = layout::Data(kHidden / 16); constexpr auto nvfp4_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); @@ -169,6 +169,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // NVFP4: mxf4nvf4 requires BOTH A and B to be FP4 (E2M1 packed) using a_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; + // Verify SMEM element sizes: unpacksmem = 1 byte/element (not 0.5) + // The TMA load unpacks 2 E2M1/global-byte → 2 SMEM bytes. + // UMMA reads unpacked SMEM and packs internally for mxf4nvf4. + static_assert(sizeof(a_dtype_t) == 1, "a_dtype_t (E2M1 unpacked SMEM) must be 1 byte/element"); + static_assert(sizeof(b_dtype_t) == 1, "b_dtype_t (E2M1 unpacked SMEM) must be 1 byte/element"); // MMA configs // NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major @@ -211,9 +216,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); - // L1 output: E2M1 packed FP4 (2 per byte) + // L1 output: unpacked E2M1 in SMEM (1 byte/element, same as FP8 layout) + // TMA store writes to global as unpacked; L2 TMA load reads unpacked into SMEM + // This avoids the packed/unpacked mismatch at the L1→L2 boundary constexpr uint32_t SMEM_CD_L1_SIZE = - kNumEpilogueWarpgroups * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) * sizeof(uint8_t) * kNumTMAStoreStages; + kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(uint8_t) * kNumTMAStoreStages; constexpr uint32_t SMEM_CD_L2_SIZE = kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; @@ -566,7 +573,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, pull_buffer.get_base_ptr(), sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), current_rank_in_expert_idx), - pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes + pull_mbarrier, kHidden); // NVFP4: unpacked E2M1, same byte count as FP8 } __syncwarp(); @@ -598,7 +605,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; // Wait for TMA token load to complete - ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden / 2); // NVFP4: E2M1 packed, half the bytes + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); // NVFP4: unpacked E2M1, same byte count as FP8 ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); // Store token to local L1 buffer via TMA @@ -1059,10 +1066,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, ptx::tma_store_wait(); ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); - // Quantize to E2M1 FP4 and STSM into shared memory - // NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1 packed - // Pack 8 E2M1 nibbles into uint32 for one STSM write (4 bytes = 8 FP4 values) - // Keep XOR swizzle for bank-conflict-free SMEM layout that TMA expects + // Quantize to E2M1 FP4 and STSM into shared memory (unpacked, 1 byte/element) + // NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1 + // SMEM layout: unpacked (1 byte per E2M1 element), same byte count as FP8 + // 4 BF16 → 4 E2M1 bytes → 1 STSM uint32 write, with XOR swizzle #pragma unroll for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { // Reduce amax @@ -1078,8 +1085,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, sf_inv.x = 1.0f / sf.x; sf_inv.y = 1.0f / sf.y; - // E2M1 FP4 quantization helper - auto quant_e2m1 = [](float v) -> uint8_t { + // E2M1 FP4 quantization: find nearest from [0, 0.5, 1, 1.5, 2, 3, 4, 6] + // Store as 4-bit in low nibble of each byte (unpacked SMEM: 1 byte/element) + auto quant_e2m1_byte = [](float v) -> uint8_t { v = fmaxf(-6.0f, fminf(6.0f, v)); uint8_t sign = (v < 0.0f) ? 1 : 0; float aq = fabsf(v); @@ -1092,29 +1100,27 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, else if (aq < 3.5f) idx = 5; else if (aq < 5.0f) idx = 6; else idx = 7; - return (sign << 3) | idx; + return (sign << 3) | idx; // 4-bit E2M1 in low nibble }; - // Quantize 4 BF16 values -> 4 E2M1 nibbles -> pack into 2 bytes + // Quantize 4 BF16 values → 4 E2M1 bytes → pack into uint32 for STSM const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); - uint8_t e0 = quant_e2m1(upper.x); - uint8_t e1 = quant_e2m1(upper.y); - uint8_t e2 = quant_e2m1(lower.x); - uint8_t e3 = quant_e2m1(lower.y); - // Pack 4 nibbles into uint32 (2 bytes): (e1<<4)|e0, (e3<<4)|e2 - // Note: only fills 2 of the 4 STSM bytes. The other 2 bytes will be - // written by the adjacent warp (same XOR swizzle column, different row range). - uint32_t packed = (uint32_t)((e1 << 4) | e0) - | ((uint32_t)((e3 << 4) | e2) << 8); + uint8_t e0 = quant_e2m1_byte(upper.x); + uint8_t e1 = quant_e2m1_byte(upper.y); + uint8_t e2 = quant_e2m1_byte(lower.x); + uint8_t e3 = quant_e2m1_byte(lower.y); + // Pack 4 bytes into uint32 for one STSM atom + uint32_t packed = (uint32_t)e0 | ((uint32_t)e1 << 8) + | ((uint32_t)e2 << 16) | ((uint32_t)e3 << 24); - // STSM with XOR swizzle (same pattern as FP8, but halved byte stride) + // STSM with XOR swizzle (same layout as FP8, 1 byte/element) uint32_t row = lane_idx; uint32_t col = warp_idx_in_wg; const auto smem_ptr = smem_cd[tma_stage_idx] - + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) - + i * ATOM_M * (L1_OUT_BLOCK_N / 2) - + row * (L1_OUT_BLOCK_N / 2) + + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N + + i * ATOM_M * L1_OUT_BLOCK_N + + row * L1_OUT_BLOCK_N + (col ^ (row / 2)) * kNumBankGroupBytes; ptx::SM100_U32x1_STSM_T::copy(packed, smem_ptr); @@ -1143,11 +1149,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Issue TMA store after all atoms in this store block if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { - uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N / 2; // FP4: byte offset = element offset / 2 + uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; // unpacked: 1 byte/element cute::tma_store_fence(); cute::SM90_TMA_STORE_2D::copy( &tensor_map_l1_output, - smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2), + smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N), out_n_idx, m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M); cute::tma_store_arrive();