Fix old CUDA compatibility
This commit is contained in:
@@ -68,11 +68,14 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
|
||||
}
|
||||
|
||||
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
|
||||
#if CUDA_VERSION >= 12080
|
||||
if (base != 0) {
|
||||
DG_HOST_ASSERT(base == 32 and mode == 128);
|
||||
return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
|
||||
}
|
||||
#endif
|
||||
|
||||
DG_HOST_ASSERT(base == 0);
|
||||
switch (mode) {
|
||||
case 0:
|
||||
case 16: return CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
|
||||
@@ -276,7 +276,7 @@ __device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_sm
|
||||
#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 5)))
|
||||
asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
|
||||
#else
|
||||
DG_STATIC_ASSERT(false, "Invalid CUDA version")
|
||||
DG_DEVICE_ASSERT(false and "Invalid CUDA version");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user