From af092fa7ba339881a7c587888421f9cd1fc984d9 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 23:58:07 +0000 Subject: [PATCH] fix: double SMEM SF allocation for NVFP4 group=16 + clean stale comments - SMEM_SFA/SFB_SIZE_PER_STAGE doubled: group=16 needs 8 SFs per token per BLOCK_K=128 (vs 4 for group=32) - arrive_and_expect_tx updated to use SMEM_SFA/SFB constants - Removed stale comments about 8/16 TMEM columns --- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 518ce4c..36a1b1b 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -218,8 +218,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Packed FP4: 4 bits/element → LOAD_BLOCK_M * BLOCK_K / 2 bytes constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K / 2; constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K / 2; - constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); - constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + // NVFP4 (group=16): 8 SFs per token per BLOCK_K=128, vs 4 for MXFP4 (group=32) + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t) * 2; + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t) * 2; // L1 output: packed E2M1 in SMEM (2 per byte), L1_OUT_BLOCK_N/2 bytes per row constexpr uint32_t SMEM_CD_L1_SIZE = kNumEpilogueWarpgroups * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) * sizeof(uint8_t) * kNumTMAStoreStages; @@ -236,12 +237,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Tensor memory size constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; - // NVFP4: scale_vec::4X → 4 SF per UMMA atom row → 4 TMEM cols per SF row - // For bM=128, SFA uses 4 rows × 4 cols = 16 TMEM columns - // SFB uses BLOCK_N/32 rows × 4 cols - // UTCCP 4x32dp128bit always writes 4 TMEM cols per 128-element group regardless of - // 1X vs 4X — the 4X only changes how the MMA interprets the scale factors, not how - // many TMEM columns UTCCP occupies. Same column count as MXFP4 (1X). + // NVFP4 scale_vec::4X: UTCCP writes 4 TMEM cols per 128-elem group (same as 1X). + // The 4X flag only changes MMA scale interpretation, not UTCCP layout. constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); @@ -748,7 +745,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, tma::copy( tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); if (is_leader_cta) { - full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SF_BLOCK_M * sizeof(uint32_t) * 2); + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SMEM_SFA_SIZE_PER_STAGE * 2); } else { full_barriers[stage_idx]->arrive(0u); } @@ -792,7 +789,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, tma::copy( tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); if (is_leader_cta) { - full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + BLOCK_N * sizeof(uint32_t) * 2); + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE * 2); } else { full_barriers[stage_idx]->arrive(0u); } @@ -876,7 +873,6 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); - // NVFP4 4X: 8 TMEM columns per 128-element SF group cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); } #pragma unroll