fix: TMA dimensions for packed FP4 must be in individual FP4 values (not bytes)

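CUDA docs: 'Dimension for the packed data types must reflect the number
of individual U# values.' For 16U4_ALIGN8B, the gmem/smem inner dims must
therefore be FP4 value counts, not byte counts. Each byte packs two FP4
values, so double the byte-oriented inner dimensions passed by callers
(applied only when fp4_unpacked_smem is false). Outer dims are row counts
and are left unchanged. gmem_outer_stride stays in bytes.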
This commit is contained in:
2026-05-12 17:39:07 +00:00
parent 75f1c8544b
commit 48b5b2b702

View File

@@ -116,17 +116,32 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false,
const bool& fp4_unpacked_smem = true) {
const auto elem_size = static_cast<int>(t.element_size());
if (swizzle_mode != 0)
smem_inner_dim = swizzle_mode / elem_size;
if (t.scalar_type() == kPackedFP4) {
// Inner dim must be a multiple of 64B for .b4x16_p64
DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0);
// For packed FP4 (mxf4nvf4): smem_inner_dim must match the MMA's expected
// SMEM row width (BLOCK_K/2 bytes). The default swizzle/elem_size gives the
// correct value — do NOT double it. The TMA and MMA must agree on row width.
// For packed FP4 with 16U4_ALIGN8B: CUDA TMA dimensions must be in
// individual FP4 values (not bytes). The caller passes byte-oriented
// dimensions, so we double them. The gmem_outer_stride stays in bytes.
// swizzle_mode is also byte-oriented, so we double it for the smem_inner_dim
// calculation below (which divides by elem_size=1, giving byte count).
if (not fp4_unpacked_smem) {
gmem_inner_dim *= 2;
smem_inner_dim *= 2;
gmem_outer_dim *= 1; // outer dim is in rows, not FP4 values
smem_outer_dim *= 1; // same
// smem_inner_dim will be recalculated from swizzle below if swizzle != 0
}
}
if (swizzle_mode != 0) {
if (t.scalar_type() == kPackedFP4 and not fp4_unpacked_smem) {
// For packed FP4: swizzle_mode is in bytes, but smem_inner_dim is in FP4 values
// swizzle_mode / elem_size gives bytes; *2 for FP4 values
smem_inner_dim = (swizzle_mode / elem_size) * 2;
} else {
smem_inner_dim = swizzle_mode / elem_size;
}
}
CUtensorMap tensor_map;