fix: add kInt8 dtype support to TMA descriptor + change activation tensors to kInt8

- runtime_utils.hpp: added kInt8 -> CU_TENSOR_MAP_DATA_TYPE_UINT8 mapping
- mega_nvfp4.hpp: changed activation tensor dtypes from kUInt8 to kInt8
  (same byte layout, but kInt8 is recognized by the TMA dtype switch)
This commit is contained in:
2026-05-11 22:54:47 +00:00
parent e608a20dec
commit 80df24a641
2 changed files with 4 additions and 3 deletions

View File

@@ -96,7 +96,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
auto x = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
{num_max_tokens_per_rank, hidden / 2}, // packed: hidden elements = hidden/2 bytes
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
torch::TensorOptions().dtype(torch::kInt8).device(buffer.device()));
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
auto x_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
@@ -114,7 +114,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
auto l1_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
{num_max_pool_tokens, hidden / 2}, // packed: hidden elements = hidden/2 bytes
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
torch::TensorOptions().dtype(torch::kInt8).device(buffer.device()));
// NVFP4 L1 SF: M-major, K/64 int32
auto l1_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
@@ -125,7 +125,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
auto l2_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
{num_max_pool_tokens, intermediate_hidden / 2}, // packed: elements/2 bytes
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
torch::TensorOptions().dtype(torch::kInt8).device(buffer.device()));
// NVFP4 L2 SF: M-major, K/64 int32
auto l2_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),

View File

@@ -82,6 +82,7 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
case torch::kInt8: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
#if CUDA_VERSION >= 12080
case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B
: CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;