fix: sf_id must be 0 for scale_vec::4X — passing sf_id=k was the root cause of the ILLEGAL_INSTRUCTION error

This commit is contained in:
2026-05-12 20:07:19 +00:00
parent 4442c06ba8
commit 698634dea5

View File

@@ -927,8 +927,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// Issue UMMA
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
// NVFP4 scale_vec::4X: sf_id must always be 0.
// The hardware implicitly reads 4 SF positions per UMMA atom
// from the single TMEM region [scale_A_tmem]/[scale_B_tmem].
// Unlike scale_vec::1X (MXFP4) where each atom needs a unique sf_id
// to index sub-columns, scale_vec::4X requires sf_id to be 0.
// Passing sf_id=k (non-zero from the second UMMA atom onward, e.g. k=1) caused the ILLEGAL_INSTRUCTION error.
const auto runtime_instr_desc =
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k);
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, 0, 0);
a_desc.lo = mma::sm100::advance_umma_desc_lo<
cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(a_desc_base_lo, 0, k * (UMMA_K / 2));
b_desc.lo = mma::sm100::advance_umma_desc_lo<