diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp index 6ffd26f..dc207c4 100644 --- a/csrc/jit/device_runtime.hpp +++ b/csrc/jit/device_runtime.hpp @@ -36,8 +36,13 @@ public: } std::shared_ptr get_prop() { - if (cached_prop == nullptr) - cached_prop = std::make_shared(*at::cuda::getCurrentDeviceProperties()); + if (cached_prop == nullptr) { + int device_idx; + cudaDeviceProp prop; + DG_CUDA_RUNTIME_CHECK(cudaGetDevice(&device_idx)); + DG_CUDA_RUNTIME_CHECK(cudaGetDeviceProperties(&prop, device_idx)); + cached_prop = std::make_shared(prop); + } return cached_prop; }