[misc][cuda] use NVML to avoid accidental CUDA initialization (#6007)
This commit is contained in:
@@ -11,66 +11,18 @@ from vllm.distributed.device_communicators.custom_all_reduce_utils import (
|
||||
gpu_p2p_access_check)
|
||||
from vllm.distributed.parallel_state import is_in_the_same_node
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
from vllm.utils import cuda_device_count_stateless, is_full_nvlink
|
||||
|
||||
try:
|
||||
import pynvml
|
||||
|
||||
# Simulate ImportError if custom_ar ops are not supported.
|
||||
if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
|
||||
raise ImportError("custom_ar", __file__)
|
||||
|
||||
assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
|
||||
custom_ar = True
|
||||
|
||||
@contextmanager
def _nvml():
    """Keep NVML initialized for the duration of the managed block.

    Calls `pynvml.nvmlInit()` on entry and `pynvml.nvmlShutdown()` on exit,
    including when the managed block raises.

    Raises:
        Whatever `pynvml.nvmlInit()` raises. In that case no shutdown is
        attempted, so the initialization error reaches the caller unchanged.
    """
    # Initialize outside the try block: if initialization fails there is
    # nothing to shut down, and calling nvmlShutdown() from `finally` would
    # raise a second error that hides the real cause.
    pynvml.nvmlInit()
    try:
        yield
    finally:
        pynvml.nvmlShutdown()
|
||||
|
||||
except ImportError:
|
||||
# For AMD GPUs
|
||||
except Exception:
|
||||
# For AMD GPUs and CPUs
|
||||
custom_ar = False
|
||||
pynvml = None
|
||||
|
||||
@contextmanager
def _nvml():
    """No-op stand-in used when `pynvml` is unavailable.

    Performs no setup and no teardown; it only yields control to the
    managed block so callers can use it as a context manager or decorator.
    """
    yield
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@_nvml()
def _is_full_nvlink(device_ids: List[int]) -> bool:
    """
    Report whether every pair of the given GPUs is directly connected by
    NVLink (1 hop).

    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`, so
    `device_ids` must be real physical device ids.

    Returns False as soon as one pair reports a P2P NVLink status other
    than `NVML_P2P_STATUS_OK`, or when the query raises `NVMLError` (the
    error is logged). Returns True when all pairs pass.
    """
    nvml_handles = [
        pynvml.nvmlDeviceGetHandleByIndex(dev) for dev in device_ids
    ]
    # Visit each unordered pair once: pair every handle with the ones
    # that come after it in the list.
    for position, src in enumerate(nvml_handles):
        for dst in nvml_handles[position + 1:]:
            try:
                status = pynvml.nvmlDeviceGetP2PStatus(
                    src, dst, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                if status != pynvml.NVML_P2P_STATUS_OK:
                    return False
            except pynvml.NVMLError as error:
                logger.error(
                    "NVLink detection failed. This is normal if your"
                    " machine has no NVLink equipped.",
                    exc_info=error)
                return False
    return True
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
@@ -161,7 +113,7 @@ class CustomAllreduce:
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
# this checks hardware and driver support for NVLink
|
||||
full_nvlink = _is_full_nvlink(physical_device_ids)
|
||||
full_nvlink = is_full_nvlink(physical_device_ids)
|
||||
if world_size > 2 and not full_nvlink:
|
||||
logger.warning(
|
||||
"Custom allreduce is disabled because it's not supported on"
|
||||
|
||||
Reference in New Issue
Block a user