From 07b82fb8cd035330b8849ebc42ef9d6df2dfdde4 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 1 Oct 2025 20:29:15 +0800 Subject: [PATCH] Fix old CUDA compatibility --- csrc/jit_kernels/impls/runtime_utils.hpp | 3 +++ deep_gemm/include/deep_gemm/common/sm90_utils.cuh | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index 92685ef..3dc5229 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -68,11 +68,14 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& } static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) { +#if CUDA_VERSION >= 12080 if (base != 0) { DG_HOST_ASSERT(base == 32 and mode == 128); return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B; } +#endif + DG_HOST_ASSERT(base == 0); switch (mode) { case 0: case 16: return CU_TENSOR_MAP_SWIZZLE_NONE; diff --git a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh index d587f42..863d225 100644 --- a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -276,7 +276,7 @@ __device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_sm #if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 5))) asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); #else - DG_STATIC_ASSERT(false, "Invalid CUDA version") + DG_DEVICE_ASSERT(false and "Invalid CUDA version"); #endif }