diff --git a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp index 12a78d3..f99de14 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp @@ -157,10 +157,10 @@ static void sm100_fp8_nvfp4_mega_moe( intermediate_hidden * 2, hidden, config.block_n, kGranK, num_experts_per_rank, 0); - // L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/2 bytes (packed), no swizzle (v1) + // L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/4 bytes (SwiGLU halving × FP4 packing), no swizzle (v1) const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, intermediate_hidden / 2, config.num_max_pool_tokens, - config.block_n / 2, config.store_block_m, + config.block_n / 4, config.store_block_m, static_cast(l2_acts.stride(-2)), 0, 0, // no swizzle false, // allow_tf32