diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index 1875d54..e05cf92 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -8,7 +8,7 @@ namespace deep_gemm { -#if CUDART_VERSION >= 12080 and not defined(DG_JIT_USE_DRIVER_API) +#if CUDART_VERSION >= 12080 and defined(DG_JIT_USE_RUNTIME_API) // Use CUDA runtime API using LibraryHandle = cudaLibrary_t; diff --git a/setup.py b/setup.py index 13f012f..1f29ad0 100644 --- a/setup.py +++ b/setup.py @@ -28,9 +28,9 @@ third_party_include_dirs = [ 'third-party/cutlass/include/cutlass', ] -# Use driver API for older CUDA compatibility -if int(os.environ.get('DG_JIT_USE_DRIVER_API', '0')): - cxx_flags.append('-DDG_JIT_USE_DRIVER_API') +# Use runtime API +if int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')): + cxx_flags.append('-DDG_JIT_USE_RUNTIME_API') class CustomBuildPy(build_py):