fix: remove duplicate kInt8 case — kPackedFP4 is already kInt8
kPackedFP4 is defined as torch::kInt8, so the separate kInt8 case was a duplicate of the kPackedFP4 case. The real fix was in mega_nvfp4.hpp: changing kUInt8 → kInt8 so that tensors match the existing kPackedFP4 path in the TMA switch.
This commit is contained in:
@@ -82,7 +82,6 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
|
||||
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
case torch::kInt8: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
#if CUDA_VERSION >= 12080
|
||||
case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B
|
||||
: CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;
|
||||
|
||||
Reference in New Issue
Block a user