fix: correct TMEM column layout for scale_vec::4X
UTCCP 4x32dp128bit always writes 4 TMEM cols per 128-element group regardless of 1X vs 4X. The 4X only changes MMA interpretation, not UTCCP column count. Reverted from (*4, stride i*8) to (same as 1X, stride i*4): - kNumSFATmemCols: SF_BLOCK_M/32 (was SF_BLOCK_M/32*4) - kNumSFBTmemCols: SF_BLOCK_N/32 (was SF_BLOCK_N/32*4) - UTCCP stride: i*4 (was i*8)
This commit is contained in:
@@ -239,8 +239,11 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
// NVFP4: scale_vec::4X → 4 SF per UMMA atom row → 4 TMEM cols per SF row
|
||||
// For bM=128, SFA uses 4 rows × 4 cols = 16 TMEM columns
|
||||
// SFB uses BLOCK_N/32 rows × 4 cols
|
||||
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32 * 4;
|
||||
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32 * 4;
|
||||
// UTCCP 4x32dp128bit always writes 4 TMEM cols per 128-element group regardless of
|
||||
// 1X vs 4X — the 4X only changes how the MMA interprets the scale factors, not how
|
||||
// many TMEM columns UTCCP occupies. Same column count as MXFP4 (1X).
|
||||
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
|
||||
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
|
||||
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
|
||||
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
|
||||
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
|
||||
@@ -874,13 +877,13 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
// NVFP4 4X: 8 TMEM columns per 128-element SF group
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 8);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
|
||||
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 8);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
|
||||
}
|
||||
|
||||
// Issue UMMA
|
||||
|
||||
Reference in New Issue
Block a user