fix: use UINT8 TMA for packed FP4 instead of 16U4_ALIGN8B

The 16U4_ALIGN8B TMA data type is not supported on this driver
(CUDA_ERROR_INVALID_VALUE). Use UINT8 TMA to load raw bytes and let
the UMMA descriptor interpret SMEM as packed FP4 for mxf4nvf4.
TMA dimensions stay in bytes (like UINT8).
This commit is contained in:
2026-05-12 18:05:11 +00:00
parent b0094175a2
commit c56f5dda7e

View File

@@ -83,8 +83,10 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
#if CUDA_VERSION >= 12080
case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B
: CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;
case kPackedFP4: // For mxf4nvf4 packed FP4: use UINT8 TMA instead of 16U4.
// The 16U4 type causes CUDA_ERROR_INVALID_VALUE on many drivers.
// UMMA descriptor handles FP4 interpretation of SMEM.
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
#endif
default: DG_HOST_UNREACHABLE("Unsupported dtype");
}
@@ -117,30 +119,16 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
const bool& allow_tf32 = false,
const bool& fp4_unpacked_smem = true) {
const auto elem_size = static_cast<int>(t.element_size());
if (swizzle_mode != 0)
smem_inner_dim = swizzle_mode / elem_size;
if (t.scalar_type() == kPackedFP4) {
// Inner dim must be a multiple of 64B for .b4x16_p64
DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0);
// For packed FP4 with 16U4_ALIGN8B: CUDA TMA dimensions must be in
// individual FP4 values (not bytes). The caller passes byte-oriented
// dimensions, so we double them. The gmem_outer_stride stays in bytes.
if (not fp4_unpacked_smem) {
gmem_inner_dim *= 2;
smem_inner_dim *= 2;
// outer dim is in rows, not FP4 values — no doubling
// smem_inner_dim will be recalculated from swizzle below if swizzle != 0
}
}
if (swizzle_mode != 0) {
if (t.scalar_type() == kPackedFP4 and not fp4_unpacked_smem) {
// For packed FP4: swizzle_mode is in bytes, but smem_inner_dim is in FP4 values
// swizzle_mode / elem_size gives bytes; *2 for FP4 values
smem_inner_dim = (swizzle_mode / elem_size) * 2;
} else {
smem_inner_dim = swizzle_mode / elem_size;
}
// For packed FP4 (mxf4nvf4): use UINT8 TMA instead of 16U4_ALIGN8B.
// The 16U4 TMA type is not widely supported (causes CUDA_ERROR_INVALID_VALUE).
// We load raw bytes via UINT8 and let the UMMA descriptor interpret
// the SMEM layout as packed FP4. Dimensions stay in bytes (like UINT8).
}
CUtensorMap tensor_map;