From b0094175a2fe8708a9a86575c2ed2846ccc05f20 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 17:40:25 +0000 Subject: [PATCH] fix: restore elem_size declaration for TMA desc build --- csrc/jit_kernels/impls/runtime_utils.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index 6f8a3d7..f425dc5 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -116,6 +116,8 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, const int& swizzle_mode, const int& swizzle_base = 0, const bool& allow_tf32 = false, const bool& fp4_unpacked_smem = true) { + const auto elem_size = static_cast(t.element_size()); + if (t.scalar_type() == kPackedFP4) { // Inner dim must be a multiple of 64B for .b4x16_p64 DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0); @@ -123,13 +125,10 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, // For packed FP4 with 16U4_ALIGN8B: CUDA TMA dimensions must be in // individual FP4 values (not bytes). The caller passes byte-oriented // dimensions, so we double them. The gmem_outer_stride stays in bytes. - // swizzle_mode is also byte-oriented, so we double it for the smem_inner_dim - // calculation below (which divides by elem_size=1, giving byte count). if (not fp4_unpacked_smem) { gmem_inner_dim *= 2; smem_inner_dim *= 2; - gmem_outer_dim *= 1; // outer dim is in rows, not FP4 values - smem_outer_dim *= 1; // same + // outer dim is in rows, not FP4 values — no doubling // smem_inner_dim will be recalculated from swizzle below if swizzle != 0 } }