Use CUDA runtime API to get device prop instead of ATen

This commit is contained in:
Chenggang Zhao
2025-10-11 09:14:00 +08:00
parent 9f196058ae
commit f8f41145da

View File

@@ -36,8 +36,13 @@ public:
}
// Returns the properties of the current CUDA device, querying them through
// the CUDA runtime API on first use and caching the result in `cached_prop`.
//
// The device is looked up only while `cached_prop` is null, so later calls
// return the same cached object even if the current device has changed in
// the meantime (unless `cached_prop` is reset elsewhere — not visible here).
//
// NOTE(review): `DG_CUDA_RUNTIME_CHECK` is defined outside this block; it is
// presumably what reports a failing runtime call — confirm its behaviour.
std::shared_ptr<cudaDeviceProp> get_prop() {
    if (cached_prop == nullptr) {
        // Ask the runtime which device is current for the calling host thread.
        int device_idx;
        DG_CUDA_RUNTIME_CHECK(cudaGetDevice(&device_idx));

        // Query into a local first so the cache is only populated after both
        // runtime calls have been passed through the check macro.
        cudaDeviceProp prop;
        DG_CUDA_RUNTIME_CHECK(cudaGetDeviceProperties(&prop, device_idx));

        cached_prop = std::make_shared<cudaDeviceProp>(prop);
    }
    return cached_prop;
}