From aa97a3f9497b915881707df26fe25658055a8a7b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 23:44:12 +0000 Subject: [PATCH] fix: correct TMEM column layout for scale_vec::4X UTCCP 4x32dp128bit always writes 4 TMEM cols per 128-element group regardless of 1X vs 4X. The 4X only changes MMA interpretation, not UTCCP column count. Reverted from (*4, stride i*8) to (same as 1X, stride i*4): - kNumSFATmemCols: SF_BLOCK_M/32 (was SF_BLOCK_M/32*4) - kNumSFBTmemCols: SF_BLOCK_N/32 (was SF_BLOCK_N/32*4) - UTCCP stride: i*4 (was i*8) --- .../deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index dc77374..518ce4c 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -239,8 +239,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // NVFP4: scale_vec::4X → 4 SF per UMMA atom row → 4 TMEM cols per SF row // For bM=128, SFA uses 4 rows × 4 cols = 16 TMEM columns // SFB uses BLOCK_N/32 rows × 4 cols - constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32 * 4; - constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32 * 4; + // UTCCP 4x32dp128bit always writes 4 TMEM cols per 128-element group regardless of + // 1X vs 4X — the 4X only changes how the MMA interprets the scale factors, not how + // many TMEM columns UTCCP occupies. Same column count as MXFP4 (1X). + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; @@ -874,13 +877,13 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); // NVFP4 4X: 8 TMEM columns per 128-element SF group - cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 8); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); } #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); - cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 8); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); } // Issue UMMA