From 48b5b2b702f487467024bfe62a271d0ba0b4fc15 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 17:39:07 +0000 Subject: [PATCH] fix: TMA dimensions for packed FP4 must be in individual FP4 values (not bytes) CUDA docs: 'Dimension for the packed data types must reflect the number of individual U# values.' For 16U4_ALIGN8B, gmem/smem inner dims must be FP4 value counts, not byte counts. Double the byte-oriented dimensions passed by callers. gmem_outer_stride stays in bytes. --- csrc/jit_kernels/impls/runtime_utils.hpp | 29 ++++++++++++++++++------ 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index 62f06bd..6f8a3d7 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -116,17 +116,32 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, const int& swizzle_mode, const int& swizzle_base = 0, const bool& allow_tf32 = false, const bool& fp4_unpacked_smem = true) { - const auto elem_size = static_cast(t.element_size()); - if (swizzle_mode != 0) - smem_inner_dim = swizzle_mode / elem_size; - if (t.scalar_type() == kPackedFP4) { // Inner dim must be a multiple of 64B for .b4x16_p64 DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0); - // For packed FP4 (mxf4nvf4): smem_inner_dim must match the MMA's expected - // SMEM row width (BLOCK_K/2 bytes). The default swizzle/elem_size gives the - // correct value — do NOT double it. The TMA and MMA must agree on row width. + // For packed FP4 with 16U4_ALIGN8B: CUDA TMA dimensions must be in + // individual FP4 values (not bytes). The caller passes byte-oriented + // dimensions, so we double them. The gmem_outer_stride stays in bytes. + // swizzle_mode is also byte-oriented, so we double it for the smem_inner_dim + // calculation below (which divides by elem_size=1, giving byte count). + if (not fp4_unpacked_smem) { + gmem_inner_dim *= 2; + smem_inner_dim *= 2; + gmem_outer_dim *= 1; // outer dim is in rows, not FP4 values + smem_outer_dim *= 1; // same + // smem_inner_dim will be recalculated from swizzle below if swizzle != 0 + } + } + + if (swizzle_mode != 0) { + if (t.scalar_type() == kPackedFP4 and not fp4_unpacked_smem) { + // For packed FP4: swizzle_mode is in bytes, but smem_inner_dim is in FP4 values + // swizzle_mode / elem_size gives bytes; *2 for FP4 values + smem_inner_dim = (swizzle_mode / elem_size) * 2; + } else { + smem_inner_dim = swizzle_mode / elem_size; + } } CUtensorMap tensor_map;