[distributed][misc] add specialized method for cuda platform (#7249)
This commit is contained in:
@@ -11,7 +11,8 @@ from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless, is_full_nvlink
+from vllm.platforms import current_platform
+from vllm.utils import cuda_device_count_stateless
 
 try:
     assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
@@ -113,7 +114,10 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        full_nvlink = is_full_nvlink(physical_device_ids)
+        assert current_platform.is_cuda()
+        from vllm.platforms.cuda import CudaPlatform
+        cuda_platform: CudaPlatform = current_platform
+        full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
Reference in New Issue
Block a user