fix: advance TMEM SF start column per UMMA atom for scale_vec::4X

This commit is contained in:
2026-05-12 20:56:35 +00:00
parent 74bf612771
commit 8b27e85ee5

View File

@@ -921,10 +921,17 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
b_desc.lo = mma::sm100::advance_umma_desc_lo<
cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(b_desc_base_lo, 0, k * (UMMA_K / 2));
// NVFP4: use mxf4nvf4 instruction with UE4M3 scales
// sf_id must be 0 for scale_vec::4X — hardware reads 4 SF positions
// per atom implicitly. Advance the TMEM start column for each UMMA
// atom so k=1 reads from the kk=1 SF region loaded by UTCCP.
// TMEM column stride per UMMA atom (index k): (SF_BLOCK_M / 128) * 4
// columns for SFA and (SF_BLOCK_N / 128) * 4 columns for SFB.
constexpr uint32_t kSFAKStride = (SF_BLOCK_M / 128) * 4;
constexpr uint32_t kSFBKStride = (SF_BLOCK_N / 128) * 4;
ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma(
b_desc, a_desc, accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, runtime_instr_desc,
kTmemStartColOfSFB, kTmemStartColOfSFA);
kTmemStartColOfSFB + k * kSFBKStride,
kTmemStartColOfSFA + k * kSFAKStride);
}
}
__syncwarp();