fix: restore elem_size declaration for TMA desc build
This commit is contained in:
@@ -116,6 +116,8 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false,
|
||||
const bool& fp4_unpacked_smem = true) {
|
||||
const auto elem_size = static_cast<int>(t.element_size());
|
||||
|
||||
if (t.scalar_type() == kPackedFP4) {
|
||||
// Inner dim must be a multiple of 64B for .b4x16_p64
|
||||
DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0);
|
||||
@@ -123,13 +125,10 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
// For packed FP4 with 16U4_ALIGN8B: CUDA TMA dimensions must be in
|
||||
// individual FP4 values (not bytes). The caller passes byte-oriented
|
||||
// dimensions, so we double them. The gmem_outer_stride stays in bytes.
|
||||
// swizzle_mode is also byte-oriented, so we double it for the smem_inner_dim
|
||||
// calculation below (which divides by elem_size=1, giving byte count).
|
||||
if (not fp4_unpacked_smem) {
|
||||
gmem_inner_dim *= 2;
|
||||
smem_inner_dim *= 2;
|
||||
gmem_outer_dim *= 1; // outer dim is in rows, not FP4 values
|
||||
smem_outer_dim *= 1; // same
|
||||
// outer dim is in rows, not FP4 values — no doubling
|
||||
// smem_inner_dim will be recalculated from swizzle_mode below if swizzle_mode != 0
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user