[misc][cuda] use nvml to avoid accidental cuda initialization (#6007)

Author: youkaichao (committed by GitHub)
Date: 2024-06-30 20:07:34 -07:00
parent af9ad46fca
commit 614aa51203
13 changed files with 86 additions and 68 deletions
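
The rationale behind the change: `torch.cuda` queries can initialize a CUDA context in the calling process, which breaks later fork-based multiprocessing, while NVML talks to the driver directly and creates no CUDA context. A minimal sketch of the pattern the commit adopts, using only documented `pynvml` calls (the `nvml_context` name is illustrative, not from the commit):

    from contextlib import contextmanager

    import pynvml  # NVML bindings; queries create no CUDA context


    @contextmanager
    def nvml_context():
        # Paired nvmlInit/nvmlShutdown; safe to run before forking workers.
        pynvml.nvmlInit()
        try:
            yield
        finally:
            pynvml.nvmlShutdown()


    with nvml_context():
        # Counts physical devices; unaffected by CUDA_VISIBLE_DEVICES.
        print(pynvml.nvmlDeviceGetCount())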

vllm/distributed/device_communicators/custom_all_reduce.py

@@ -11,66 +11,18 @@ from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import is_in_the_same_node
 from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils import cuda_device_count_stateless, is_full_nvlink
 
 try:
-    import pynvml
-
-    # Simulate ImportError if custom_ar ops are not supported.
-    if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
-        raise ImportError("custom_ar", __file__)
-
+    assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
     custom_ar = True
-
-    @contextmanager
-    def _nvml():
-        try:
-            pynvml.nvmlInit()
-            yield
-        finally:
-            pynvml.nvmlShutdown()
-
-except ImportError:
-    # For AMD GPUs
+except Exception:
+    # For AMD GPUs and CPUs
     custom_ar = False
-    pynvml = None
-
-    @contextmanager
-    def _nvml():
-        try:
-            yield
-        finally:
-            pass
 
 logger = init_logger(__name__)
 
-
-@_nvml()
-def _is_full_nvlink(device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
-    so it works on real physical device ids.
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                        return False
-                except pynvml.NVMLError as error:
-                    logger.error(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped.",
-                        exc_info=error)
-                    return False
-    return True
-
-
 def _can_p2p(rank: int, world_size: int) -> bool:
     for i in range(world_size):
         if i == rank:
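
An aside on the removed code: `@_nvml()` works as a function decorator because objects returned by `contextlib.contextmanager` factories are also `ContextDecorator`s, so every call to `_is_full_nvlink` ran inside an `nvmlInit()`/`nvmlShutdown()` pair. A standalone sketch of that stdlib behavior:

    from contextlib import contextmanager


    @contextmanager
    def _trace():
        print("enter")
        try:
            yield
        finally:
            print("exit")


    @_trace()  # the context manager doubles as a decorator
    def work():
        print("working")


    work()  # prints: enter, working, exit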
@@ -161,7 +113,7 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        full_nvlink = _is_full_nvlink(physical_device_ids)
+        full_nvlink = is_full_nvlink(physical_device_ids)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"