fix: sf_id must be 0 for scale_vec::4X — passing sf_id=k was ILLEGAL_INSTRUCTION root cause
This commit is contained in:
@@ -927,8 +927,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
// Issue UMMA
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
// NVFP4 scale_vec::4X: sf_id must always be 0.
|
||||
// The hardware implicitly reads 4 SF positions per UMMA atom
|
||||
// from the single TMEM region [scale_A_tmem]/[scale_B_tmem].
|
||||
// Unlike scale_vec::1X (MXFP4) where each atom needs a unique sf_id
|
||||
// to index sub-columns, scale_vec::4X ignores sf_id or requires 0.
|
||||
// Passing sf_id=k (k=1 for second UMMA atom) was the ILLEGAL_INSTRUCTION bug.
|
||||
const auto runtime_instr_desc =
|
||||
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k);
|
||||
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, 0, 0);
|
||||
a_desc.lo = mma::sm100::advance_umma_desc_lo<
|
||||
cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(a_desc_base_lo, 0, k * (UMMA_K / 2));
|
||||
b_desc.lo = mma::sm100::advance_umma_desc_lo<
|
||||
|
||||
Reference in New Issue
Block a user