diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index f425dc5..388a2ac 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -83,8 +83,10 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8; #if CUDA_VERSION >= 12080 - case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B - : CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; + case kPackedFP4: // For mxf4nvf4 packed FP4: use UINT8 TMA instead of 16U4. + // The 16U4 type causes CUDA_ERROR_INVALID_VALUE on many drivers. + // UMMA descriptor handles FP4 interpretation of SMEM. + return CU_TENSOR_MAP_DATA_TYPE_UINT8; #endif default: DG_HOST_UNREACHABLE("Unsupported dtype"); } @@ -117,30 +119,16 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, const bool& allow_tf32 = false, const bool& fp4_unpacked_smem = true) { const auto elem_size = static_cast(t.element_size()); + if (swizzle_mode != 0) + smem_inner_dim = swizzle_mode / elem_size; if (t.scalar_type() == kPackedFP4) { // Inner dim must be a multiple of 64B for .b4x16_p64 DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0); - - // For packed FP4 with 16U4_ALIGN8B: CUDA TMA dimensions must be in - // individual FP4 values (not bytes). The caller passes byte-oriented - // dimensions, so we double them. The gmem_outer_stride stays in bytes. - if (not fp4_unpacked_smem) { - gmem_inner_dim *= 2; - smem_inner_dim *= 2; - // outer dim is in rows, not FP4 values — no doubling - // smem_inner_dim will be recalculated from swizzle below if swizzle != 0 - } - } - - if (swizzle_mode != 0) { - if (t.scalar_type() == kPackedFP4 and not fp4_unpacked_smem) { - // For packed FP4: swizzle_mode is in bytes, but smem_inner_dim is in FP4 values - // swizzle_mode / elem_size gives bytes; *2 for FP4 values - smem_inner_dim = (swizzle_mode / elem_size) * 2; - } else { - smem_inner_dim = swizzle_mode / elem_size; - } + // For packed FP4 (mxf4nvf4): use UINT8 TMA instead of 16U4_ALIGN8B. + // The 16U4 TMA type is not widely supported (causes CUDA_ERROR_INVALID_VALUE). + // We load raw bytes via UINT8 and let the UMMA descriptor interpret + // the SMEM layout as packed FP4. Dimensions stay in bytes (like UINT8). } CUtensorMap tensor_map;