From 30d72e7ef560322cc5f81e6447faa7a237e446cd Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 21:59:21 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20packed=20FP4=20for=20mxf4nvf4=20?= =?UTF-8?q?=E2=80=94=20correct=20SMEM=20layout,=20UMMA=20descriptors,=20L1?= =?UTF-8?q?=20epilogue?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Key changes: - a_dtype_t/b_dtype_t: float_e2m1_t (packed 4-bit) with sizeof_bits_v==4 assert - kSwizzleAMode/BMode: BLOCK_K/2 (64 bytes packed, not 128 unpacked) - SMEM sizes: LOAD_BLOCK_M * BLOCK_K / 2 (packed byte count) - Token layouts: kHidden/2, kIntermediateHidden/2 (packed bytes) - TMA loads: BLOCK_K/2 inner dim, uint8_t, byte offsets k_block_idx*(BLOCK_K/2) - UMMA descriptors: BLOCK_K/2 template param, uint8_t dtype, UMMA_K/2 advance - L1 epilogue: dropped STSM, direct st.shared.u16 with packed nibbles, no swizzle (v1) - Pybind buffer sizes: hidden/2, intermediate_hidden/2 with packed tensor shapes - Host TMA descriptors: hidden/2 K-dims, block_k/2 inner, fp4_unpacked_smem=false - L1 output TMA: block_n/4 inner, no swizzle (CU_TENSOR_MAP_SWIZZLE_NONE) --- csrc/apis/mega_nvfp4.hpp | 12 +- .../impls/sm100_fp8_nvfp4_mega_moe.hpp | 31 ++-- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 158 +++++++++--------- 3 files changed, 105 insertions(+), 96 deletions(-) diff --git a/csrc/apis/mega_nvfp4.hpp b/csrc/apis/mega_nvfp4.hpp index e8812f6..e7fa81f 100644 --- a/csrc/apis/mega_nvfp4.hpp +++ b/csrc/apis/mega_nvfp4.hpp @@ -28,9 +28,9 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4 layouts: E2M1 packed (2 per byte), so token layout is K/2 bytes // group_size=16, so SF stride is K/16 (twice as many as MXFP4) - const auto fp4_token_layout = layout::Data(hidden); + const auto fp4_token_layout = layout::Data(hidden / 2); const auto bf16_token_layout = layout::Data(hidden * 2); - const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden); + const auto fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2); const auto nvfp4_sf_layout = layout::Data(hidden / 16); const auto nvfp4_intermediate_sf_layout = layout::Data(intermediate_hidden / 16); const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); @@ -95,7 +95,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4: E2M1 packed activations (2 per byte), K/2 bytes per token auto x = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)), - {num_max_tokens_per_rank, hidden}, + {num_max_tokens_per_rank, hidden / 2}, // packed: hidden elements = hidden/2 bytes torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); // NVFP4 SF: K/16 bytes per token, packed as K/64 int32 auto x_sf = torch::from_blob( @@ -110,10 +110,10 @@ get_symm_buffer_size_for_nvfp4_mega_moe( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_weights_buffer.base)), {num_max_tokens_per_rank, num_topk}, torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device())); - // NVFP4: L1 output acts are E2M1 packed, intermediate_hidden/2 bytes + // NVFP4: L1 output acts are E2M1 packed, hidden/2 bytes auto l1_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_token_buffer.base)), - {num_max_pool_tokens, hidden}, + {num_max_pool_tokens, hidden / 2}, // packed: hidden elements = hidden/2 bytes torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); // NVFP4 L1 SF: M-major, K/64 int32 auto l1_acts_sf = torch::from_blob( @@ -124,7 +124,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( // NVFP4: L2 acts are E2M1 packed, intermediate_hidden/2 bytes auto l2_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), - {num_max_pool_tokens, intermediate_hidden}, + {num_max_pool_tokens, intermediate_hidden / 2}, // packed: elements/2 bytes torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); // NVFP4 L2 SF: M-major, K/64 int32 auto l2_acts_sf = torch::from_blob( diff --git a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp index 649a186..44038dc 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp @@ -134,14 +134,14 @@ static void sm100_fp8_nvfp4_mega_moe( // NVFP4: kGranK=16 for group_size=16 constexpr int kGranK = 16; - // Make tensormap — weight/activation TMA descriptors are the same as MXFP4 - // (E2M1 packed uint8 is the same format regardless of scale type) - // L1 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8 + // Make tensormap — NVFP4 packed E2M1 (2 per byte), so K-dim is hidden/2 bytes + // L1 activations: packed E2M1 (4 bits/element), K-dim = hidden/2 const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts, - hidden, config.num_max_pool_tokens, - config.block_k, config.load_block_m, + hidden / 2, config.num_max_pool_tokens, + config.block_k / 2, config.load_block_m, static_cast(l1_acts.stride(-2)), - config.swizzle_acts_mode); + config.swizzle_acts_mode / 2, + 0, false, false); // fp4_unpacked_smem=false // NVFP4 SF TMA: kGranK=16, so SF K-dim is hidden/16, packed as hidden/64 int32 const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf, config.num_padded_sf_pool_tokens, hidden, @@ -156,18 +156,21 @@ static void sm100_fp8_nvfp4_mega_moe( intermediate_hidden * 2, hidden, config.block_n, kGranK, num_experts_per_rank, 0); - // L1 output: unpacked E2M1 (1 byte/element), same dimensions as FP8 + // L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/4 bytes, no swizzle (v1) const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, - intermediate_hidden, config.num_max_pool_tokens, - config.block_n / 2, config.store_block_m, + intermediate_hidden / 2, config.num_max_pool_tokens, + config.block_n / 4, config.store_block_m, static_cast(l2_acts.stride(-2)), - config.swizzle_acts_mode / 2); - // L2 activations: unpacked E2M1 (1 byte/element), same dimensions as FP8 + 0, 0, // no swizzle + false, // allow_tf32 + false); // fp4_unpacked_smem=false (packed!) + // L2 activations: packed E2M1, K-dim = intermediate_hidden/2 const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts, - intermediate_hidden, config.num_max_pool_tokens, - config.block_k, config.load_block_m, + intermediate_hidden / 2, config.num_max_pool_tokens, + config.block_k / 2, config.load_block_m, static_cast(l2_acts.stride(-2)), - config.swizzle_acts_mode); + config.swizzle_acts_mode / 2, + 0, false, false); // fp4_unpacked_smem=false const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf, config.num_padded_sf_pool_tokens, intermediate_hidden, config.sf_block_m, kGranK, diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 9703a1c..38a887d 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -96,10 +96,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Token and buffer layouts // NVFP4: activations are E2M1 packed (2 per byte), so K/2 bytes per token - constexpr auto fp4_token_layout = layout::Data(kHidden); + constexpr auto fp4_token_layout = layout::Data(kHidden / 2); // packed: 2 E2M1 per byte constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); // NVFP4 intermediate: same E2M1 packing - constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden); + constexpr auto fp4_intermediate_token_layout = layout::Data(kIntermediateHidden / 2); // packed // NVFP4: group_size=16, so SF stride is K/16 (twice as many scales as MXFP4) constexpr auto nvfp4_sf_layout = layout::Data(kHidden / 16); constexpr auto nvfp4_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); @@ -167,13 +167,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Data types // NVFP4: mxf4nvf4 requires BOTH A and B to be FP4 (E2M1 packed) - using a_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; - using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; - // Verify SMEM element sizes: unpacksmem = 1 byte/element (not 0.5) - // The TMA load unpacks 2 E2M1/global-byte → 2 SMEM bytes. - // UMMA reads unpacked SMEM and packs internally for mxf4nvf4. - static_assert(sizeof(a_dtype_t) == 1, "a_dtype_t (E2M1 unpacked SMEM) must be 1 byte/element"); - static_assert(sizeof(b_dtype_t) == 1, "b_dtype_t (E2M1 unpacked SMEM) must be 1 byte/element"); + using a_dtype_t = cutlass::float_e2m1_t; + using b_dtype_t = cutlass::float_e2m1_t; + // mxf4nvf4 reads packed FP4 from SMEM (2 values per byte), NOT unpacked. + // _unpacksmem_t was for mxf8f6f4 which reads FP4 as FP8 (1 byte/element). + // For mxf4nvf4: sizeof_bits = 4, SMEM stride = BLOCK_K/2 bytes, UMMA_K = 64. + static_assert(cutlass::sizeof_bits_v == 4, + "mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM"); + static_assert(cutlass::sizeof_bits_v == 4, + "mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM"); // MMA configs // NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major @@ -190,8 +192,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Swizzle configs // NVFP4: float_e2m1_unpacksmem_t uses 1 byte per element in SMEM (unpacked) // Same byte stride as FP8 — the TMA hardware unpacks 2 E2M1 values per global byte - constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); // 128 - constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); // 128 + // Packed FP4: BLOCK_K elements = BLOCK_K/2 bytes per row (2 values per byte) + constexpr uint32_t kSwizzleAMode = BLOCK_K / 2; // 64 bytes + constexpr uint32_t kSwizzleBMode = BLOCK_K / 2; // 64 bytes constexpr uint32_t kSwizzleCDMode = 128; DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N"); @@ -212,15 +215,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, math::constexpr_align(fp4_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); // NVFP4: float_e2m1_unpacksmem_t uses 1 byte per element in SMEM (unpacked) // TMA unpacks 2 E2M1 per global byte into 1 byte per element in SMEM - constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); - constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + // Packed FP4: 4 bits/element → LOAD_BLOCK_M * BLOCK_K / 2 bytes + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K / 2; + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K / 2; constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); - // L1 output: unpacked E2M1 in SMEM (1 byte/element, same as FP8 layout) - // TMA store writes to global as unpacked; L2 TMA load reads unpacked into SMEM - // This avoids the packed/unpacked mismatch at the L1→L2 boundary + // L1 output: packed E2M1 in SMEM (2 per byte), L1_OUT_BLOCK_N/2 bytes per row constexpr uint32_t SMEM_CD_L1_SIZE = - kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(uint8_t) * kNumTMAStoreStages; + kNumEpilogueWarpgroups * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) * sizeof(uint8_t) * kNumTMAStoreStages; constexpr uint32_t SMEM_CD_L2_SIZE = kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; @@ -573,7 +575,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, pull_buffer.get_base_ptr(), sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), current_rank_in_expert_idx), - pull_mbarrier, kHidden); // NVFP4: unpacked E2M1, same byte count as FP8 + pull_mbarrier, kHidden / 2); // NVFP4: packed E2M1, half the bytes } __syncwarp(); @@ -605,7 +607,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; // Wait for TMA token load to complete - ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); // NVFP4: unpacked E2M1, same byte count as FP8 + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden / 2); // NVFP4: packed E2M1, half the bytes ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); // Store token to local L1 buffer via TMA @@ -728,7 +730,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Compute token offset from pool block index uint32_t m_idx = pool_block_idx * BLOCK_M; - uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_idx = k_block_idx * (BLOCK_K / 2); // packed FP4: byte offset uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M; uint32_t sfa_k_idx = k_block_idx; @@ -738,7 +740,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // TMA copy tokens and SFA, then arrive at full barrier if (cute::elect_one_sync()) { - tma::copy( + tma::copy( tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2); tma::copy( tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); @@ -775,13 +777,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Compute weight offset uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; - uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_idx = k_block_idx * (BLOCK_K / 2); // packed FP4: byte offset uint32_t sfb_n_idx = n_block_idx * BLOCK_N; uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx; // TMA copy weights with SF if (cute::elect_one_sync()) { - tma::copy( + // NVFP4: weights are packed E2M1, BLOCK_K elements = BLOCK_K/2 bytes + tma::copy( tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2); tma::copy( tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); @@ -810,8 +813,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, auto sf_desc = mma::sm100::make_sf_desc(nullptr); DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); - auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); - auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + // NVFP4: UMMA descriptors use packed byte dimensions (BLOCK_K/2, uint8_t) + // because sizeof(float_e2m1_t)=1 but real stride is BLOCK_K/2 bytes per K-row + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; @@ -884,9 +889,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, const auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); a_desc.lo = mma::sm100::advance_umma_desc_lo< - cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K); + cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(a_desc_base_lo, 0, k * (UMMA_K / 2)); b_desc.lo = mma::sm100::advance_umma_desc_lo< - cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K); + cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(b_desc_base_lo, 0, k * (UMMA_K / 2)); // NVFP4: use mxf4nvf4 instruction with UE4M3 scales ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma( b_desc, a_desc, accum_stage_idx * UMMA_N, @@ -1066,65 +1071,66 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, ptx::tma_store_wait(); ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); - // Quantize to E2M1 FP4 and STSM into shared memory (unpacked, 1 byte/element) - // NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is E2M1 - // SMEM layout: unpacked (1 byte per E2M1 element), same byte count as FP8 - // 4 BF16 → 4 E2M1 bytes → 1 STSM uint32 write, with XOR swizzle + // Quantize to E2M1 FP4 and write packed SMEM (2 values per byte) + // NVFP4: mxf4nvf4 requires FP4×FP4, so L1 output is packed E2M1 + // SMEM layout: packed, L1_OUT_BLOCK_N/2 bytes per row, no swizzle (v1) + auto quant_e2m1_byte = [](float v) -> uint8_t { + v = fmaxf(-6.0f, fminf(6.0f, v)); + uint8_t sign = (v < 0.0f) ? 1 : 0; + float aq = fabsf(v); + uint8_t idx; + if (aq < 0.25f) idx = 0; + else if (aq < 0.75f) idx = 1; + else if (aq < 1.25f) idx = 2; + else if (aq < 1.75f) idx = 3; + else if (aq < 2.5f) idx = 4; + else if (aq < 3.5f) idx = 5; + else if (aq < 5.0f) idx = 6; + else idx = 7; + return uint8_t((sign << 3) | idx); + }; + + // Lane mapping: row_in_atom = lane_idx/4, col_pair = lane_idx%4 + // Each warp owns L1_OUT_BLOCK_N/8 bytes per row (= 16 N-elements) + constexpr uint32_t kWarpBytesPerRow = L1_OUT_BLOCK_N / 8; + const uint32_t row_in_atom = lane_idx >> 2; + const uint32_t col_pair = lane_idx & 3; + #pragma unroll for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { - // Reduce amax + // Reduce amax across warp pair const float2 wp_amax = smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx % 4]; amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x); amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y); - // Calculate SF for E2M1: scale = amax / 6.0 (E2M1 max = 6) + // E2M1: scale = amax / 6.0 (E2M1 max magnitude) float2 sf, sf_inv; sf.x = fmaxf(amax_values[i].x / 6.0f, 1e-8f); sf.y = fmaxf(amax_values[i].y / 6.0f, 1e-8f); sf_inv.x = 1.0f / sf.x; sf_inv.y = 1.0f / sf.y; - // E2M1 FP4 quantization: find nearest from [0, 0.5, 1, 1.5, 2, 3, 4, 6] - // Store as 4-bit in low nibble of each byte (unpacked SMEM: 1 byte/element) - auto quant_e2m1_byte = [](float v) -> uint8_t { - v = fmaxf(-6.0f, fminf(6.0f, v)); - uint8_t sign = (v < 0.0f) ? 1 : 0; - float aq = fabsf(v); - uint8_t idx; - if (aq < 0.25f) idx = 0; - else if (aq < 0.75f) idx = 1; - else if (aq < 1.25f) idx = 2; - else if (aq < 1.75f) idx = 3; - else if (aq < 2.5f) idx = 4; - else if (aq < 3.5f) idx = 5; - else if (aq < 5.0f) idx = 6; - else idx = 7; - return (sign << 3) | idx; // 4-bit E2M1 in low nibble - }; - - // Quantize 4 BF16 values → 4 E2M1 bytes → pack into uint32 for STSM + // 4 SwiGLU floats → 4 E2M1 nibbles → 2 bytes (uint16) for this lane const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); - uint8_t e0 = quant_e2m1_byte(upper.x); - uint8_t e1 = quant_e2m1_byte(upper.y); - uint8_t e2 = quant_e2m1_byte(lower.x); - uint8_t e3 = quant_e2m1_byte(lower.y); - // Pack 4 bytes into uint32 for one STSM atom - uint32_t packed = (uint32_t)e0 | ((uint32_t)e1 << 8) - | ((uint32_t)e2 << 16) | ((uint32_t)e3 << 24); + const uint8_t e0 = quant_e2m1_byte(upper.x); + const uint8_t e1 = quant_e2m1_byte(upper.y); + const uint8_t e2 = quant_e2m1_byte(lower.x); + const uint8_t e3 = quant_e2m1_byte(lower.y); + const uint16_t packed16 = + uint16_t(((e1 & 0xF) << 4) | (e0 & 0xF)) + | uint16_t((((e3 & 0xF) << 4) | (e2 & 0xF)) << 8); - // STSM with XOR swizzle (same layout as FP8, 1 byte/element) - uint32_t row = lane_idx; - uint32_t col = warp_idx_in_wg; - const auto smem_ptr = smem_cd[tma_stage_idx] - + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N - + i * ATOM_M * L1_OUT_BLOCK_N - + row * L1_OUT_BLOCK_N - + (col ^ (row / 2)) * kNumBankGroupBytes; - ptx::SM100_U32x1_STSM_T::copy(packed, smem_ptr); + // Direct shared store (st.shared.u16) — no STSM, no swizzle + uint8_t* smem_byte_ptr = smem_cd[tma_stage_idx] + + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) + + (i * ATOM_M + row_in_atom) * (L1_OUT_BLOCK_N / 2) + + warp_idx_in_wg * kWarpBytesPerRow + + col_pair * 2; + *reinterpret_cast(smem_byte_ptr) = packed16; - // Store SF to `l2_sf_buffer` as UE4M3 (MN-major layout) + // SF store to l2_sf_buffer as UE4M3 (MN-major layout) if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) { const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2; const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; @@ -1134,26 +1140,26 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, __builtin_assume(token_base_idx < BLOCK_M); const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; - const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; + const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * uint32_t(sizeof(uint32_t)) + byte_idx; auto to_ue4m3 = [](float v) -> uint8_t { v = fmaxf(0.0f, fminf(v, 448.0f)); - cutlass::float_e4m3_t e4m3_val = cutlass::float_e4m3_t(v); - return reinterpret_cast(e4m3_val) & 0x7F; + cutlass::float_e4m3_t e = cutlass::float_e4m3_t(v); + return reinterpret_cast(e) & 0x7F; }; sf_base_ptr[sf_addr] = to_ue4m3(sf.x); - sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = to_ue4m3(sf.y); + sf_base_ptr[sf_addr + 4 * uint32_t(sizeof(uint32_t))] = to_ue4m3(sf.y); } __syncwarp(); } ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); - // Issue TMA store after all atoms in this store block + // TMA store SMEM → global (packed FP4 byte offsets) if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { - uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; // unpacked: 1 byte/element + const uint32_t out_n_idx = n_block_idx * (L1_OUT_BLOCK_N / 2); cute::tma_store_fence(); cute::SM90_TMA_STORE_2D::copy( &tensor_map_l1_output, - smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N), + smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2), out_n_idx, m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M); cute::tma_store_arrive();