diff --git a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp index 44038dc..17948cd 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp @@ -148,10 +148,11 @@ static void sm100_fp8_nvfp4_mega_moe( config.sf_block_m, kGranK, 1, 0); const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights, - hidden, num_experts_per_rank * intermediate_hidden * 2, - config.block_k, config.load_block_n, + hidden / 2, num_experts_per_rank * intermediate_hidden * 2, + config.block_k / 2, config.load_block_n, static_cast(l1_weights.stride(-2)), - config.swizzle_weights_mode); + config.swizzle_weights_mode / 2, + 0, false, false); // fp4_unpacked_smem=false (packed!) const auto tensor_map_l1_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_weights_sf, intermediate_hidden * 2, hidden, config.block_n, kGranK, @@ -176,10 +177,11 @@ static void sm100_fp8_nvfp4_mega_moe( config.sf_block_m, kGranK, 1, 0); const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights, - intermediate_hidden, num_experts_per_rank * hidden, - config.block_k, config.load_block_n, + intermediate_hidden / 2, num_experts_per_rank * hidden, + config.block_k / 2, config.load_block_n, static_cast(l2_weights.stride(-2)), - config.swizzle_weights_mode); + config.swizzle_weights_mode / 2, + 0, false, false); // fp4_unpacked_smem=false (packed!) const auto tensor_map_l2_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_weights_sf, hidden, intermediate_hidden, config.block_n, kGranK,