fix: advance TMEM SF start column per UMMA atom for scale_vec::4X
This commit is contained in:
@@ -921,10 +921,17 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
b_desc.lo = mma::sm100::advance_umma_desc_lo<
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(b_desc_base_lo, 0, k * (UMMA_K / 2));
|
||||
// NVFP4: use mxf4nvf4 instruction with UE4M3 scales
|
||||
// sf_id must be 0 for scale_vec::4X — hardware reads 4 SF positions
|
||||
// per atom implicitly. Advance the TMEM start column for each UMMA
|
||||
// atom so k=1 reads from the kk=1 SF region loaded by UTCCP.
|
||||
// Per-K-column TMEM stride = (SF_BLOCK_{M,N} / 128) * 4 columns (SF_BLOCK_M for SFA, SF_BLOCK_N for SFB).
|
||||
constexpr uint32_t kSFAKStride = (SF_BLOCK_M / 128) * 4;
|
||||
constexpr uint32_t kSFBKStride = (SF_BLOCK_N / 128) * 4;
|
||||
ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma(
|
||||
b_desc, a_desc, accum_stage_idx * UMMA_N,
|
||||
k_block_idx > 0 or k > 0, runtime_instr_desc,
|
||||
kTmemStartColOfSFB, kTmemStartColOfSFA);
|
||||
kTmemStartColOfSFB + k * kSFBKStride,
|
||||
kTmemStartColOfSFA + k * kSFAKStride);
|
||||
}
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
Reference in New Issue
Block a user