From c0850a68596e98be4dbb517a819e73edd9834167 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 06:51:39 +0000 Subject: [PATCH] Fix weight TMA descriptors: packed E2M1 needs K/2, block_k/2, swizzle/2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Weights are packed E2M1 (2 values per byte), but the TMA descriptors were built with unpacked dimensions — the K dimension in elements instead of bytes, a 128B swizzle instead of 64B, and the full block_k instead of block_k/2. This caused out-of-bounds reads and a swizzle mismatch with the UMMA descriptor, producing illegal instruction traps. --- .../jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp index 44038dc..17948cd 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp @@ -148,10 +148,11 @@ static void sm100_fp8_nvfp4_mega_moe( config.sf_block_m, kGranK, 1, 0); const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights, - hidden, num_experts_per_rank * intermediate_hidden * 2, - config.block_k, config.load_block_n, + hidden / 2, num_experts_per_rank * intermediate_hidden * 2, + config.block_k / 2, config.load_block_n, static_cast(l1_weights.stride(-2)), - config.swizzle_weights_mode); + config.swizzle_weights_mode / 2, + 0, false, false); // fp4_unpacked_smem=false (packed!) 
const auto tensor_map_l1_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_weights_sf, intermediate_hidden * 2, hidden, config.block_n, kGranK, @@ -176,10 +177,11 @@ static void sm100_fp8_nvfp4_mega_moe( config.sf_block_m, kGranK, 1, 0); const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights, - intermediate_hidden, num_experts_per_rank * hidden, - config.block_k, config.load_block_n, + intermediate_hidden / 2, num_experts_per_rank * hidden, + config.block_k / 2, config.load_block_n, static_cast(l2_weights.stride(-2)), - config.swizzle_weights_mode); + config.swizzle_weights_mode / 2, + 0, false, false); // fp4_unpacked_smem=false (packed!) const auto tensor_map_l2_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_weights_sf, hidden, intermediate_hidden, config.block_n, kGranK,