From 80df24a641156915d0d770acbbd1c3019b80f21d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 22:54:47 +0000 Subject: [PATCH] fix: add kInt8 dtype support to TMA descriptor + change activation tensors to kInt8 - runtime_utils.hpp: added kInt8 -> CU_TENSOR_MAP_DATA_TYPE_UINT8 mapping - mega_nvfp4.hpp: changed activation tensor dtypes from kUInt8 to kInt8 (same byte layout, but kInt8 is recognized by the TMA dtype switch) --- csrc/apis/mega_nvfp4.hpp | 6 +++--- csrc/jit_kernels/impls/runtime_utils.hpp | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/apis/mega_nvfp4.hpp b/csrc/apis/mega_nvfp4.hpp index e7fa81f..f986a28 100644 --- a/csrc/apis/mega_nvfp4.hpp +++ b/csrc/apis/mega_nvfp4.hpp @@ -96,7 +96,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( auto x = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)), {num_max_tokens_per_rank, hidden / 2}, // packed: hidden elements = hidden/2 bytes - torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); + torch::TensorOptions().dtype(torch::kInt8).device(buffer.device())); // NVFP4 SF: K/16 bytes per token, packed as K/64 int32 auto x_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)), @@ -114,7 +114,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( auto l1_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_token_buffer.base)), {num_max_pool_tokens, hidden / 2}, // packed: hidden elements = hidden/2 bytes - torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); + torch::TensorOptions().dtype(torch::kInt8).device(buffer.device())); // NVFP4 L1 SF: M-major, K/64 int32 auto l1_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), @@ -125,7 +125,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe( auto l2_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), {num_max_pool_tokens, intermediate_hidden / 2}, // packed: elements/2 bytes - torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device())); + torch::TensorOptions().dtype(torch::kInt8).device(buffer.device())); // NVFP4 L2 SF: M-major, K/64 int32 auto l2_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index 72a76f0..a2fd276 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -82,6 +82,7 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8; + case torch::kInt8: return CU_TENSOR_MAP_DATA_TYPE_UINT8; #if CUDA_VERSION >= 12080 case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B : CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;