Use CUDA runtime API to get device prop instead of ATen
This commit is contained in:
@@ -36,8 +36,13 @@ public:
|
||||
}
|
||||
|
||||
std::shared_ptr<cudaDeviceProp> get_prop() {
|
||||
if (cached_prop == nullptr)
|
||||
cached_prop = std::make_shared<cudaDeviceProp>(*at::cuda::getCurrentDeviceProperties());
|
||||
if (cached_prop == nullptr) {
|
||||
int device_idx;
|
||||
cudaDeviceProp prop;
|
||||
DG_CUDA_RUNTIME_CHECK(cudaGetDevice(&device_idx));
|
||||
DG_CUDA_RUNTIME_CHECK(cudaGetDeviceProperties(&prop, device_idx));
|
||||
cached_prop = std::make_shared<cudaDeviceProp>(prop);
|
||||
}
|
||||
return cached_prop;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user