From 75f1c8544b6649b41f4b0908225cb954cfae9b13 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 17:14:44 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20remove=20smem=5Finner=5Fdim=20doubling?= =?UTF-8?q?=20for=20packed=20FP4=20TMA=20=E2=80=94=20must=20match=20MMA=20?= =?UTF-8?q?row=20width=20(BLOCK=5FK/2)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- csrc/jit_kernels/impls/runtime_utils.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index 72a76f0..62f06bd 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -124,9 +124,9 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, // Inner dim must be a multiple of 64B for .b4x16_p64 DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0); - // Fix FP4 packed smem - if (not fp4_unpacked_smem and swizzle_mode != 0) - smem_inner_dim = swizzle_mode * 2; + // For packed FP4 (mxf4nvf4): smem_inner_dim must match the MMA's expected + // SMEM row width (BLOCK_K/2 bytes). The default swizzle/elem_size gives the + // correct value — do NOT double it. The TMA and MMA must agree on row width. } CUtensorMap tensor_map;