diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index 72a76f0..62f06bd 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -124,9 +124,9 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, // Inner dim must be a multiple of 64B for .b4x16_p64 DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0); - // Fix FP4 packed smem - if (not fp4_unpacked_smem and swizzle_mode != 0) - smem_inner_dim = swizzle_mode * 2; + // For packed FP4 (mxf4nvf4): smem_inner_dim must match the MMA's expected + // SMEM row width (BLOCK_K/2 bytes). The default swizzle/elem_size gives the + // correct value — do NOT double it. The TMA and MMA must agree on row width. } CUtensorMap tensor_map;