From 8b27e85ee5905a5b31fd6e11e00ba5b4a957eafe Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 20:56:35 +0000 Subject: [PATCH] fix: advance TMEM SF start column per UMMA atom for scale_vec::4X --- .../include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 2673d7c..88221b0 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -921,10 +921,17 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, b_desc.lo = mma::sm100::advance_umma_desc_lo< cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, uint8_t>(b_desc_base_lo, 0, k * (UMMA_K / 2)); // NVFP4: use mxf4nvf4 instruction with UE4M3 scales + // sf_id must be 0 for scale_vec::4X — hardware reads 4 SF positions + // per atom implicitly. Advance the TMEM start column for each UMMA + // atom so k=1 reads from the kk=1 SF region loaded by UTCCP. + // Per-K-column TMEM stride = (SF_BLOCK / 128) * 4 columns. + constexpr uint32_t kSFAKStride = (SF_BLOCK_M / 128) * 4; + constexpr uint32_t kSFBKStride = (SF_BLOCK_N / 128) * 4; ptx::SM100_MMA_MXF4NVF4_2x1SM_SS::fma( b_desc, a_desc, accum_stage_idx * UMMA_N, k_block_idx > 0 or k > 0, runtime_instr_desc, - kTmemStartColOfSFB, kTmemStartColOfSFA); + kTmemStartColOfSFB + k * kSFBKStride, + kTmemStartColOfSFA + k * kSFAKStride); } } __syncwarp();