Fix version

This commit is contained in:
Chenggang Zhao
2025-10-01 20:31:27 +08:00
parent 07b82fb8cd
commit c1bf4cae4b

View File

@@ -273,10 +273,10 @@ __device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::Tma
__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) {
auto smem_int_desc = __cvta_generic_to_shared(smem_desc);
asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim));
#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 5)))
#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3)))
asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
#else
DG_DEVICE_ASSERT(false and "Invalid CUDA version");
DG_STATIC_ASSERT(false, "Invalid CUDA version");
#endif
}