fix: restore elem_size declaration for TMA desc build

This commit is contained in:
2026-05-12 17:40:25 +00:00
parent 48b5b2b702
commit b0094175a2

View File

@@ -116,6 +116,8 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false,
const bool& fp4_unpacked_smem = true) {
const auto elem_size = static_cast<int>(t.element_size());
if (t.scalar_type() == kPackedFP4) {
// Inner dim must be a multiple of 64B for .b4x16_p64
DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0);
@@ -123,13 +125,10 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
// For packed FP4 with 16U4_ALIGN8B: CUDA TMA dimensions must be in
// individual FP4 values (not bytes). The caller passes byte-oriented
// dimensions, so we double them. The gmem_outer_stride stays in bytes.
// swizzle_mode is also byte-oriented, so we double it for the smem_inner_dim
// calculation below (which divides by elem_size=1, giving byte count).
if (not fp4_unpacked_smem) {
gmem_inner_dim *= 2;
smem_inner_dim *= 2;
gmem_outer_dim *= 1; // outer dim is in rows, not FP4 values
smem_outer_dim *= 1; // same
// outer dim is in rows, not FP4 values — no doubling
// smem_inner_dim will be recalculated from swizzle below if swizzle != 0
}
}