fix: double SMEM SF allocation for NVFP4 group=16 + clean stale comments

- SMEM_SFA/SFB_SIZE_PER_STAGE doubled: group=16 needs 8 SFs per token
  per BLOCK_K=128 (vs 4 for group=32)
- arrive_and_expect_tx updated to use SMEM_SFA/SFB constants
- Removed stale comments about 8/16 TMEM columns
This commit is contained in:
2026-05-11 23:58:07 +00:00
parent aa97a3f949
commit af092fa7ba

View File

@@ -218,8 +218,9 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Packed FP4: 4 bits/element → LOAD_BLOCK_M * BLOCK_K / 2 bytes
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K / 2;
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K / 2;
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
// NVFP4 (group=16): 8 SFs per token per BLOCK_K=128, vs 4 for MXFP4 (group=32)
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t) * 2;
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t) * 2;
// L1 output: packed E2M1 in SMEM (2 per byte), L1_OUT_BLOCK_N/2 bytes per row
constexpr uint32_t SMEM_CD_L1_SIZE =
kNumEpilogueWarpgroups * STORE_BLOCK_M * (L1_OUT_BLOCK_N / 2) * sizeof(uint8_t) * kNumTMAStoreStages;
@@ -236,12 +237,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Tensor memory size
constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
// NVFP4: scale_vec::4X → 4 SF per UMMA atom row → 4 TMEM cols per SF row
// For bM=128, SFA uses 4 rows × 4 cols = 16 TMEM columns
// SFB uses BLOCK_N/32 rows × 4 cols
// UTCCP 4x32dp128bit always writes 4 TMEM cols per 128-element group regardless of
// 1X vs 4X — the 4X only changes how the MMA interprets the scale factors, not how
// many TMEM columns UTCCP occupies. Same column count as MXFP4 (1X).
// NVFP4 scale_vec::4X: UTCCP writes 4 TMEM cols per 128-elem group (same as 1X).
// The 4X flag only changes MMA scale interpretation, not UTCCP layout.
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
@@ -748,7 +745,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
tma::copy<SF_BLOCK_M, 1, 0>(
tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2);
if (is_leader_cta) {
full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SF_BLOCK_M * sizeof(uint32_t) * 2);
full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SMEM_SFA_SIZE_PER_STAGE * 2);
} else {
full_barriers[stage_idx]->arrive(0u);
}
@@ -792,7 +789,7 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
tma::copy<BLOCK_N, 1, 0>(
tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2);
if (is_leader_cta) {
full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + BLOCK_N * sizeof(uint32_t) * 2);
full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE * 2);
} else {
full_barriers[stage_idx]->arrive(0u);
}
@@ -876,7 +873,6 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
// NVFP4 4X: 8 TMEM columns per 128-element SF group
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
#pragma unroll