diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 7580ff4..7cc52d8 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -927,8 +927,14 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // Issue UMMA #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + // NVFP4 scale_vec::4X: sf_id must always be 0. + // The hardware implicitly reads 4 SF positions per UMMA atom + // from the single TMEM region [scale_A_tmem]/[scale_B_tmem]. + // Unlike scale_vec::1X (MXFP4), where each atom needs a unique sf_id + // to index sub-columns, scale_vec::4X accepts only sf_id = 0. + // Passing sf_id = k (k = 1 for the second UMMA atom) raised ILLEGAL_INSTRUCTION. const auto runtime_instr_desc = - mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, 0, 0); a_desc.lo = mma::sm100::advance_umma_desc_lo< cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(a_desc_base_lo, 0, k * (UMMA_K / 2)); b_desc.lo = mma::sm100::advance_umma_desc_lo<