fix: TMA dimensions for packed FP4 must be in individual FP4 values (not bytes)
CUDA docs: 'Dimension for the packed data types must reflect the number of individual U# values.' For 16U4_ALIGN8B, gmem/smem inner dims must be FP4 value counts, not byte counts. Double the byte-oriented dimensions passed by callers. gmem_outer_stride stays in bytes.
This commit is contained in:
@@ -116,17 +116,32 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false,
|
||||
const bool& fp4_unpacked_smem = true) {
|
||||
const auto elem_size = static_cast<int>(t.element_size());
|
||||
if (swizzle_mode != 0)
|
||||
smem_inner_dim = swizzle_mode / elem_size;
|
||||
|
||||
if (t.scalar_type() == kPackedFP4) {
|
||||
// Inner dim must be a multiple of 64B for .b4x16_p64
|
||||
DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0);
|
||||
|
||||
// For packed FP4 (mxf4nvf4): smem_inner_dim must match the MMA's expected
|
||||
// SMEM row width (BLOCK_K/2 bytes). The default swizzle/elem_size gives the
|
||||
// correct value — do NOT double it. The TMA and MMA must agree on row width.
|
||||
// For packed FP4 with 16U4_ALIGN8B: CUDA TMA dimensions must be in
|
||||
// individual FP4 values (not bytes). The caller passes byte-oriented
|
||||
// dimensions, so we double them. The gmem_outer_stride stays in bytes.
|
||||
// swizzle_mode is also byte-oriented, so we double it for the smem_inner_dim
|
||||
// calculation below (which divides by elem_size=1, giving byte count).
|
||||
if (not fp4_unpacked_smem) {
|
||||
gmem_inner_dim *= 2;
|
||||
smem_inner_dim *= 2;
|
||||
gmem_outer_dim *= 1; // outer dim is in rows, not FP4 values
|
||||
smem_outer_dim *= 1; // same
|
||||
// smem_inner_dim will be recalculated from swizzle below if swizzle != 0
|
||||
}
|
||||
}
|
||||
|
||||
if (swizzle_mode != 0) {
|
||||
if (t.scalar_type() == kPackedFP4 and not fp4_unpacked_smem) {
|
||||
// For packed FP4: swizzle_mode is in bytes, but smem_inner_dim is in FP4 values
|
||||
// swizzle_mode / elem_size gives bytes; *2 for FP4 values
|
||||
smem_inner_dim = (swizzle_mode / elem_size) * 2;
|
||||
} else {
|
||||
smem_inner_dim = swizzle_mode / elem_size;
|
||||
}
|
||||
}
|
||||
|
||||
CUtensorMap tensor_map;
|
||||
|
||||
Reference in New Issue
Block a user