fix: correct TMEM column layout for scale_vec::4X

UTCCP 4x32dp128bit always writes 4 TMEM columns per 128-element group,
regardless of 1X vs 4X. The 4X mode only changes how the MMA interprets the
scale factors, not how many TMEM columns UTCCP occupies. Revert the 4X-specific
layout (column count *4, stride i*8) to the same layout as 1X (stride i*4):
- kNumSFATmemCols: SF_BLOCK_M/32 (was SF_BLOCK_M/32*4)
- kNumSFBTmemCols: SF_BLOCK_N/32 (was SF_BLOCK_N/32*4)
- UTCCP stride: i*4 (was i*8)
This commit is contained in:
2026-05-11 23:44:12 +00:00
parent d6551617c0
commit aa97a3f949

View File

@@ -239,8 +239,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// NVFP4: scale_vec::4X → 4 SF per UMMA atom row → 4 TMEM cols per SF row
// For bM=128, SFA uses 4 rows × 4 cols = 16 TMEM columns
// SFB uses BLOCK_N/32 rows × 4 cols
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32 * 4;
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32 * 4;
// UTCCP 4x32dp128bit always writes 4 TMEM cols per 128-element group regardless of
// 1X vs 4X — the 4X only changes how the MMA interprets the scale factors, not how
// many TMEM columns UTCCP occupies. Same column count as MXFP4 (1X).
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
@@ -874,13 +877,13 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
// NVFP4 4X: 8 TMEM columns per 128-element SF group
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 8);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 8);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
}
// Issue UMMA