[misc][cuda] use nvml to avoid accidentally cuda initialization (#6007)
This commit is contained in:
@@ -816,6 +816,63 @@ def cuda_device_count_stateless() -> int:
|
||||
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
|
||||
|
||||
|
||||
# NVML utils
|
||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||
# all the related functions work on real physical device ids.
|
||||
# the major benefit of using NVML is that it will not initialize CUDA
|
||||
|
||||
try:
|
||||
import pynvml
|
||||
except ImportError:
|
||||
# For non-NV devices
|
||||
pynvml = None
|
||||
|
||||
|
||||
def with_nvml_context(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if pynvml is not None:
|
||||
pynvml.nvmlInit()
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
if pynvml is not None:
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(device_ids: List[int]) -> bool:
|
||||
"""
|
||||
query if the set of gpus are fully connected by nvlink (1 hop)
|
||||
"""
|
||||
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
|
||||
for i, handle in enumerate(handles):
|
||||
for j, peer_handle in enumerate(handles):
|
||||
if i < j:
|
||||
try:
|
||||
p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
||||
handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
|
||||
if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
||||
return False
|
||||
except pynvml.NVMLError as error:
|
||||
logger.error(
|
||||
"NVLink detection failed. This is normal if your"
|
||||
" machine has no NVLink equipped.",
|
||||
exc_info=error)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
@with_nvml_context
|
||||
def get_device_capability_stateless(device_id: int = 0) -> Tuple[int, int]:
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
|
||||
return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
|
||||
|
||||
|
||||
#From: https://stackoverflow.com/a/4104188/2749989
|
||||
def run_once(f):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user