[distributed][misc] add specialized method for cuda platform (#7249)
@@ -11,7 +11,8 @@ from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless, is_full_nvlink
+from vllm.platforms import current_platform
+from vllm.utils import cuda_device_count_stateless
 
 try:
     assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
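
The import change above is the core of the commit: the NVLink query leaves the platform-agnostic `vllm.utils` module, and callers go through `current_platform` instead. A minimal sketch of that specialization pattern, using toy names rather than vLLM's actual class layout:

    from typing import List


    class Platform:
        """Base class; platform-specific queries live on subclasses."""


    class CudaPlatform(Platform):

        @staticmethod
        def is_full_nvlink(physical_device_ids: List[int]) -> bool:
            # NVML-based pairwise P2P check; only meaningful on NVIDIA GPUs.
            ...


    # Selected once, based on detected hardware (sketch, not vLLM's logic).
    current_platform: Platform = CudaPlatform()
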
@@ -113,7 +114,10 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        full_nvlink = is_full_nvlink(physical_device_ids)
+        assert current_platform.is_cuda()
+        from vllm.platforms.cuda import CudaPlatform
+        cuda_platform: CudaPlatform = current_platform
+        full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
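
Two details are worth noting in the hunk above: `CudaPlatform` is imported inside the method to avoid a circular import at module load time, and the annotated assignment narrows `current_platform` for type checkers after the `is_cuda()` assertion. The same narrowing can be written with `typing.cast`; this variant is only illustrative (the diff uses the annotated assignment instead):

    from typing import cast

    from vllm.platforms import current_platform

    physical_device_ids = [0, 1]  # stand-in; the real method computes these

    assert current_platform.is_cuda()
    from vllm.platforms.cuda import CudaPlatform  # deferred to dodge cycles
    cuda_platform = cast(CudaPlatform, current_platform)
    full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
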
@@ -4,12 +4,21 @@ pynvml. However, it should not initialize cuda context.
 
 import os
 from functools import lru_cache, wraps
-from typing import Tuple
+from typing import List, Tuple
 
 import pynvml
 
+from vllm.logger import init_logger
+
 from .interface import Platform, PlatformEnum
 
+logger = init_logger(__name__)
+
+# NVML utils
+# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
+# all the related functions work on real physical device ids.
+# the major benefit of using NVML is that it will not initialize CUDA
+
 
 def with_nvml_context(fn):
 
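
The `with_nvml_context` decorator retained here (its body is visible in the `vllm/utils.py` removal hunk below) brackets each call with `pynvml.nvmlInit()` / `pynvml.nvmlShutdown()`, so NVML state never leaks past the call and CUDA is never initialized. The same idea as a context manager, shown purely as an illustrative alternative:

    from contextlib import contextmanager

    import pynvml


    @contextmanager
    def nvml_context():
        # Pair init/shutdown around the block; nvmlInit is reference-counted
        # in current drivers, so nested use is safe.
        pynvml.nvmlInit()
        try:
            yield
        finally:
            pynvml.nvmlShutdown()
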
@@ -47,3 +56,29 @@ class CudaPlatform(Platform):
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_capability(physical_device_id)
+
+    @staticmethod
+    @with_nvml_context
+    def is_full_nvlink(physical_device_ids: List[int]) -> bool:
+        """
+        query if the set of gpus are fully connected by nvlink (1 hop)
+        """
+        handles = [
+            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
+        ]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                            handle, peer_handle,
+                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                            return False
+                    except pynvml.NVMLError as error:
+                        logger.error(
+                            "NVLink detection failed. This is normal if your"
+                            " machine has no NVLink equipped.",
+                            exc_info=error)
+                        return False
+        return True
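
In the new method, the nested loops with `i < j` visit each unordered pair of devices exactly once (n*(n-1)/2 NVML queries), returning False on the first pair without a direct NVLink hop. `itertools.combinations` expresses the same traversal; a runnable toy with stand-in handles, no NVML required:

    from itertools import combinations

    handles = ["gpu0", "gpu1", "gpu2"]  # stand-ins for NVML device handles

    nested = [(handles[i], handles[j])
              for i in range(len(handles))
              for j in range(len(handles))
              if i < j]
    assert nested == list(combinations(handles, 2))
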
@@ -1034,56 +1034,6 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
 
 
-# NVML utils
-# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
-# all the related functions work on real physical device ids.
-# the major benefit of using NVML is that it will not initialize CUDA
-
-try:
-    import pynvml
-except ImportError:
-    # For non-NV devices
-    pynvml = None
-
-
-def with_nvml_context(fn):
-
-    @wraps(fn)
-    def wrapper(*args, **kwargs):
-        if pynvml is not None:
-            pynvml.nvmlInit()
-        try:
-            return fn(*args, **kwargs)
-        finally:
-            if pynvml is not None:
-                pynvml.nvmlShutdown()
-
-    return wrapper
-
-
-@with_nvml_context
-def is_full_nvlink(device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                        return False
-                except pynvml.NVMLError as error:
-                    logger.error(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped.",
-                        exc_info=error)
-                    return False
-    return True
-
-
 #From: https://stackoverflow.com/a/4104188/2749989
 def run_once(f):
 
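
The removal above is pure code motion: the NVML helpers now live in `vllm/platforms/cuda.py` (second and third hunks), with `device_ids` renamed to `physical_device_ids` and the `pynvml = None` import guard dropped, since the CUDA platform module may assume pynvml is installed. The surviving `run_once` context follows the linked Stack Overflow answer; a sketch of that classic pattern (the in-tree body may differ):

    from functools import wraps


    def run_once(f):

        @wraps(f)
        def wrapper(*args, **kwargs):
            # Execute f on the first call only; later calls return None.
            if not wrapper.has_run:
                wrapper.has_run = True
                return f(*args, **kwargs)

        wrapper.has_run = False
        return wrapper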