fix: add kInt8 dtype support to TMA descriptor + change activation tensors to kInt8
- runtime_utils.hpp: added kInt8 -> CU_TENSOR_MAP_DATA_TYPE_UINT8 mapping - mega_nvfp4.hpp: changed activation tensor dtypes from kUInt8 to kInt8 (same byte layout, but kInt8 is recognized by the TMA dtype switch)
This commit is contained in:
@@ -96,7 +96,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
auto x = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
|
||||
{num_max_tokens_per_rank, hidden / 2}, // packed: hidden elements = hidden/2 bytes
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
|
||||
torch::TensorOptions().dtype(torch::kInt8).device(buffer.device()));
|
||||
// NVFP4 SF: K/16 bytes per token, packed as K/64 int32
|
||||
auto x_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
|
||||
@@ -114,7 +114,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
auto l1_acts = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
|
||||
{num_max_pool_tokens, hidden / 2}, // packed: hidden elements = hidden/2 bytes
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
|
||||
torch::TensorOptions().dtype(torch::kInt8).device(buffer.device()));
|
||||
// NVFP4 L1 SF: M-major, K/64 int32
|
||||
auto l1_acts_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
|
||||
@@ -125,7 +125,7 @@ get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
auto l2_acts = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
|
||||
{num_max_pool_tokens, intermediate_hidden / 2}, // packed: elements/2 bytes
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(buffer.device()));
|
||||
torch::TensorOptions().dtype(torch::kInt8).device(buffer.device()));
|
||||
// NVFP4 L2 SF: M-major, K/64 int32
|
||||
auto l2_acts_sf = torch::from_blob(
|
||||
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
|
||||
|
||||
@@ -82,6 +82,7 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
|
||||
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
case torch::kInt8: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
#if CUDA_VERSION >= 12080
|
||||
case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B
|
||||
: CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;
|
||||
|
||||
Reference in New Issue
Block a user