[V1] Support DP with Ray (#18779)

Author: Rui Qiao
Date: 2025-06-02 21:15:13 -07:00
Committer: GitHub
parent 9e6f61e8c3
commit bdce64f236
10 changed files with 551 additions and 120 deletions
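This commit adds a `--data-parallel-backend` engine argument so V1 data parallel can be coordinated through Ray instead of the default multiprocessing launcher. A hedged usage sketch of the new option (model name illustrative; assumes vLLM and Ray are installed):

    # Drive the new option programmatically; the CLI equivalent is
    # `--data-parallel-size 2 --data-parallel-backend ray`.
    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(
        model="facebook/opt-125m",  # illustrative model
        data_parallel_size=2,
        data_parallel_backend="ray",
    )
    config = engine_args.create_engine_config()
    print(config.parallel_config.data_parallel_backend)  # expected: "ray"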

@@ -39,7 +39,7 @@
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
-                        GiB_bytes, is_in_ray_actor)
+                        GiB_bytes, get_ip, is_in_ray_actor)
 # yapf: enable
@@ -292,6 +292,7 @@ class EngineArgs:
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
+    data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
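The new dataclass field follows the pattern already used on the surrounding lines: each EngineArgs attribute defaults to the matching class attribute on ParallelConfig, so the default lives in one place. A minimal standalone sketch of that pattern (abridged stand-ins; the "mp" default is taken from the CLI flag below):

    from dataclasses import dataclass

    @dataclass
    class ParallelConfig:  # abridged stand-in for vllm.config.ParallelConfig
        data_parallel_backend: str = "mp"  # single source of truth for the default

    @dataclass
    class EngineArgs:  # abridged stand-in
        data_parallel_backend: str = ParallelConfig.data_parallel_backend

    assert EngineArgs().data_parallel_backend == "mp"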
@@ -624,6 +625,12 @@
                                     type=int,
                                     help='Port for data parallel RPC '
                                     'communication.')
+        parallel_group.add_argument('--data-parallel-backend',
+                                    '-dpb',
+                                    type=str,
+                                    default='mp',
+                                    help='Backend for data parallel, either '
+                                    '"mp" or "ray".')
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
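The flag's behavior can be reproduced with plain argparse (a standalone sketch, not vLLM's FlexibleArgumentParser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data-parallel-backend', '-dpb',
                        type=str, default='mp',
                        help='Backend for data parallel, either "mp" or "ray".')

    assert parser.parse_args([]).data_parallel_backend == 'mp'  # default
    assert parser.parse_args(['-dpb', 'ray']).data_parallel_backend == 'ray'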
@@ -1059,9 +1066,20 @@
         # DP address, used in multi-node case for torch distributed group
         # and ZMQ sockets.
-        data_parallel_address = self.data_parallel_address if (
-            self.data_parallel_address
-            is not None) else ParallelConfig.data_parallel_master_ip
+        if self.data_parallel_address is None:
+            if self.data_parallel_backend == "ray":
+                host_ip = get_ip()
+                logger.info(
+                    "Using host IP %s as ray-based data parallel address",
+                    host_ip)
+                data_parallel_address = host_ip
+            else:
+                assert self.data_parallel_backend == "mp", (
+                    "data_parallel_backend can only be ray or mp, got %s",
+                    self.data_parallel_backend)
+                data_parallel_address = ParallelConfig.data_parallel_master_ip
+        else:
+            data_parallel_address = self.data_parallel_address
         # This port is only used when there are remote data parallel engines,
         # otherwise the local IPC transport is used.
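The address resolution above reads as a small pure function: an explicit address always wins; otherwise Ray advertises the host's routable IP (so remote engines can reach the torch distributed group and ZMQ sockets), while mp keeps the loopback default. A hedged standalone sketch (resolve_dp_address and the 127.0.0.1 default are illustrative; get_ip stands in for vllm.utils.get_ip):

    def resolve_dp_address(explicit_addr, backend, get_ip,
                           default_ip="127.0.0.1"):  # assumed mp default master IP
        if explicit_addr is not None:
            return explicit_addr  # user override always wins
        if backend == "ray":
            return get_ip()  # advertise this host's routable IP
        assert backend == "mp", \
            f"data_parallel_backend can only be ray or mp, got {backend}"
        return default_ip  # mp keeps ParallelConfig's default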
@@ -1069,6 +1087,8 @@
         data_parallel_rpc_port = self.data_parallel_rpc_port if (
             self.data_parallel_rpc_port
             is not None) else ParallelConfig.data_parallel_rpc_port
+        data_parallel_backend = self.data_parallel_backend
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
@@ -1076,6 +1096,7 @@
             data_parallel_size_local=data_parallel_size_local,
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
+            data_parallel_backend=data_parallel_backend,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
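With the backend threaded into ParallelConfig, downstream code can branch on parallel_config.data_parallel_backend without re-reading CLI state, and the default mp path is unchanged. A quick hedged check (illustrative model name; assumes vLLM is importable):

    from vllm.engine.arg_utils import EngineArgs

    config = EngineArgs(model="facebook/opt-125m").create_engine_config()
    assert config.parallel_config.data_parallel_backend == "mp"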