From 26a8ab75a11b8d26e210a666cb81a83a8446d368 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 08:08:17 +0000 Subject: [PATCH] =?UTF-8?q?NVFP4:=20fix=20SF=20pipeline=20=E2=80=94=202=20?= =?UTF-8?q?K-cols=20per=20BLOCK=5FK=20for=20group=3D16?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - TMA: issue two tma::copy calls per K-block (K_box=1, 2 SF K-columns) - UTCCP: double loop for 2 K-columns, correct SMEM offsets - TMEM: double SFA/SFB column counts (SF_BLOCK_M/32 * 2) - Heuristic: fix smem_size (2× SF, packed FP4 A/B, packed send buffers, no amax) - Staging kernel: fix double-count bug in packed_k_mask --- csrc/jit_kernels/heuristics/mega_moe.hpp | 21 ++++--- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 57 ++++++++++++------- 2 files changed, 51 insertions(+), 27 deletions(-) diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index b1ba6bd..26be5fd 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -138,32 +138,37 @@ static std::pair get_pipeline_config_for_mega_moe( // Dispatch region const int smem_expert_count_size = align( num_experts * static_cast(sizeof(uint32_t)), kSmemAlignment); + // NVFP4: dispatch send buffers use packed E2M1 tokens (hidden/2 bytes per token) const int smem_send_buffers_size = align( - static_cast(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()), + static_cast(layout::Buffer(layout::Data(hidden / 2), num_dispatch_warps, 1).get_num_bytes()), kSmemAlignment); const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size; - // C/D output region: max of L1 FP8 (2 TMA stages, BLOCK_N/2 post-SwiGLU) and L2 BF16 (1 stage) + // C/D output region: max of L1 packed E2M1 (2 TMA stages, BLOCK_N/4 bytes per row) and L2 BF16 (1 stage) + // NVFP4 L1 output: packed E2M1 (2 per byte), L1_OUT_BLOCK_N = block_n/2, bytes = L1_OUT_BLOCK_N/2 = block_n/4 const auto num_epilogue_warpgroups = num_epilogue_warps / 4; - const int smem_cd_l1 = num_epilogue_warpgroups * store_block_m * (block_n / 2) * kNumTMAStoreStages; + const int smem_cd_l1 = num_epilogue_warpgroups * store_block_m * (block_n / 4) * kNumTMAStoreStages; const int smem_cd_l2 = num_epilogue_warpgroups * store_block_m * block_n * static_cast(sizeof(nv_bfloat16)); const int smem_cd = std::max(smem_cd_l1, smem_cd_l2); // Barriers (stage-independent): dispatch + tensor memory full/empty + combine (2 per epilogue warp) const int smem_barriers = (num_dispatch_warps + kNumEpilogueStages * 2 + num_epilogue_warps * 2) * 8; - // Amax reduction - const int smem_amax_reduction = store_block_m * num_epilogue_warps * static_cast(sizeof(float)); + // NVFP4: no SMEM amax reduction needed (each warp computes its own amax) + const int smem_amax_reduction = 0; // Tensor memory pointer const int smem_tmem_ptr = 4; // SF is aligned to UTCCP 128-element granularity - const int smem_sfa_per_stage = sf_block_m * 4; - const int smem_sfb_per_stage = sf_block_n * 4; + // NVFP4: group=16 → 2 SF K-columns per BLOCK_K (128/16/4=2) + // Each K-column: sf_block_m * 4 bytes (uint32), total = 2× the MXFP4 SF size + const int smem_sfa_per_stage = sf_block_m * 4 * 2; // 2 K-cols for NVFP4 + const int smem_sfb_per_stage = sf_block_n * 4 * 2; // 2 K-cols for NVFP4 // Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty barriers - const int smem_per_stage = load_block_m * block_k + block_n * block_k + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8; + // NVFP4: packed E2M1 (2 per byte), so A/B tiles use BLOCK_K/2 bytes per row + const int smem_per_stage = load_block_m * (block_k / 2) + block_n * (block_k / 2) + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8; // Fixed total const int smem_fixed = smem_dispatch_size + smem_cd + smem_amax_reduction + smem_barriers + smem_tmem_ptr; diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index be0410c..05b43c3 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -237,10 +237,10 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Tensor memory size constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; - // NVFP4 scale_vec::4X: UTCCP writes 4 TMEM cols per 128-elem group (same as 1X). - // The 4X flag only changes MMA scale interpretation, not UTCCP layout. - constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; - constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + // NVFP4 group=16: 2 SF K-columns per BLOCK_K, so 2× the TMEM cols vs MXFP4 + // Each 128-elem UTCCP group → 4 TMEM cols, and we have 2 groups per BLOCK_K + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32 * 2; // 2 K-cols + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32 * 2; // 2 K-cols constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; @@ -734,7 +734,6 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, uint32_t m_idx = pool_block_idx * BLOCK_M; uint32_t k_idx = k_block_idx * (BLOCK_K / 2); // packed FP4: byte offset uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M; - uint32_t sfa_k_idx = k_block_idx; // Add 2 CTA offsets for non-leader CTA if (not is_leader_cta) @@ -744,8 +743,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, if (cute::elect_one_sync()) { tma::copy( tensor_map_a_ptr, full_barriers[stage_idx], reinterpret_cast(smem_a[stage_idx]), k_idx, m_idx, 2); - tma::copy( - tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); + // NVFP4 group=16: 2 SF K-columns per BLOCK_K (128/16/4=2), but TMA K_box=1 + // Issue two TMA copies per K-block to load both SF columns + for (uint32_t kk = 0; kk < 2; ++kk) { + const uint32_t sfa_k_idx_full = k_block_idx * 2 + kk; + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], + smem_sfa[stage_idx] + kk * SF_BLOCK_M, // uint32* offset: SF_BLOCK_M ints per call + sfa_m_idx, sfa_k_idx_full, 2); + } if (is_leader_cta) { full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SMEM_SFA_SIZE_PER_STAGE * 2); } else { @@ -781,15 +787,20 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; uint32_t k_idx = k_block_idx * (BLOCK_K / 2); // packed FP4: byte offset uint32_t sfb_n_idx = n_block_idx * BLOCK_N; - uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx; // TMA copy weights with SF if (cute::elect_one_sync()) { // NVFP4: weights are packed E2M1, BLOCK_K elements = BLOCK_K/2 bytes tma::copy( tensor_map_b_ptr, full_barriers[stage_idx], reinterpret_cast(smem_b[stage_idx]), k_idx, n_idx, 2); - tma::copy( - tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); + // NVFP4 group=16: 2 SF K-columns per BLOCK_K, but TMA K_box=1 + for (uint32_t kk = 0; kk < 2; ++kk) { + const uint32_t sfb_k_idx_full = local_expert_idx * shape_sfb_k + k_block_idx * 2 + kk; + tma::copy( + tensor_map_sfb_ptr, full_barriers[stage_idx], + smem_sfb[stage_idx] + kk * BLOCK_N, // uint32* offset: BLOCK_N ints per call + sfb_n_idx, sfb_k_idx_full, 2); + } if (is_leader_cta) { full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE * 2); } else { @@ -868,20 +879,28 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); if (cute::elect_one_sync()) { // UTCCP copy SFA and SFB to TMEM - // NVFP4: scale_vec::4X, each 128-element block → 8 TMEM cols + // NVFP4: group=16 → 2 SF K-columns per BLOCK_K (128/16/4=2) + // Each UTCCP call moves 128 int32s → 4 TMEM cols + // We need 2 UTCCP calls per SF: one per K-column using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; #pragma unroll - for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { - auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; - mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); - cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + for (uint32_t kk = 0; kk < 2; ++kk) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + kk * SF_BLOCK_M + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + (kk * (SF_BLOCK_M / kNumUTCCPAlignedElems) + i) * 4); + } } #pragma unroll - for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { - auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; - mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); - cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + for (uint32_t kk = 0; kk < 2; ++kk) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + kk * BLOCK_N + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + (kk * (SF_BLOCK_N / kNumUTCCPAlignedElems) + i) * 4); + } } // Issue UMMA