[V1] Support DP with Ray (#18779)
@@ -39,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
-                        GiB_bytes, is_in_ray_actor)
+                        GiB_bytes, get_ip, is_in_ray_actor)
 
 # yapf: enable
 
@@ -292,6 +292,7 @@ class EngineArgs:
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
+    data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
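A minimal sketch (not part of the diff) of how the new field surfaces on the Python API, assuming the usual vllm.engine.arg_utils import path; the model name and data_parallel_size value are placeholders:

# Sketch: the new EngineArgs field added above. "facebook/opt-125m" is only a
# placeholder model; data_parallel_backend defaults to "mp" when omitted.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m",
                  data_parallel_size=2,
                  data_parallel_backend="ray")
print(args.data_parallel_backend)  # -> "ray"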
@@ -624,6 +625,12 @@ class EngineArgs:
                                     type=int,
                                     help='Port for data parallel RPC '
                                     'communication.')
+        parallel_group.add_argument('--data-parallel-backend',
+                                    '-dpb',
+                                    type=str,
+                                    default='mp',
+                                    help='Backend for data parallel, either '
+                                    '"mp" or "ray".')
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
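For illustration, a hedged sketch of exercising the new flag through the engine argument parser; EngineArgs.add_cli_args and FlexibleArgumentParser exist in vLLM, but this exact snippet is not part of the commit:

# Sketch: parsing the new --data-parallel-backend / -dpb flag. The argument
# defaults to "mp"; "ray" selects the Ray-based data parallel backend.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
ns = parser.parse_args(["--data-parallel-backend", "ray"])
print(ns.data_parallel_backend)  # -> "ray" (same as passing "-dpb ray")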
@@ -1059,9 +1066,20 @@ class EngineArgs:
 
         # DP address, used in multi-node case for torch distributed group
         # and ZMQ sockets.
-        data_parallel_address = self.data_parallel_address if (
-            self.data_parallel_address
-            is not None) else ParallelConfig.data_parallel_master_ip
+        if self.data_parallel_address is None:
+            if self.data_parallel_backend == "ray":
+                host_ip = get_ip()
+                logger.info(
+                    "Using host IP %s as ray-based data parallel address",
+                    host_ip)
+                data_parallel_address = host_ip
+            else:
+                assert self.data_parallel_backend == "mp", (
+                    "data_parallel_backend can only be ray or mp, got %s",
+                    self.data_parallel_backend)
+                data_parallel_address = ParallelConfig.data_parallel_master_ip
+        else:
+            data_parallel_address = self.data_parallel_address
 
         # This port is only used when there are remote data parallel engines,
         # otherwise the local IPC transport is used.
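The hunk above changes how the DP master address is defaulted: an explicit --data-parallel-address always wins; otherwise the Ray backend advertises this host's IP from get_ip(), while the mp backend falls back to ParallelConfig.data_parallel_master_ip. A standalone sketch of that rule, not part of the diff (the default_ip value here is an assumption standing in for the ParallelConfig default):

# Sketch of the address-resolution rule implemented in the hunk above.
from typing import Optional


def resolve_dp_address(address: Optional[str],
                       backend: str,
                       host_ip: str,
                       default_ip: str = "127.0.0.1") -> str:
    if address is not None:
        return address   # explicit address always wins
    if backend == "ray":
        return host_ip   # Ray backend: advertise this host's IP
    assert backend == "mp", (
        f"data_parallel_backend can only be ray or mp, got {backend}")
    return default_ip    # mp backend: keep the ParallelConfig default


print(resolve_dp_address(None, "ray", host_ip="10.0.0.5"))  # -> 10.0.0.5
print(resolve_dp_address(None, "mp", host_ip="10.0.0.5"))   # -> 127.0.0.1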
@@ -1069,6 +1087,8 @@ class EngineArgs:
             self.data_parallel_rpc_port
             is not None) else ParallelConfig.data_parallel_rpc_port
 
+        data_parallel_backend = self.data_parallel_backend
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
@@ -1076,6 +1096,7 @@ class EngineArgs:
             data_parallel_size_local=data_parallel_size_local,
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
+            data_parallel_backend=data_parallel_backend,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
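End to end, the new value lands on ParallelConfig. A minimal sketch, assuming ParallelConfig can be constructed directly with defaults for everything else:

# Sketch: the field the commit threads through to ParallelConfig; all other
# arguments keep their defaults.
from vllm.config import ParallelConfig

pc = ParallelConfig(data_parallel_backend="ray")
print(pc.data_parallel_backend)  # -> "ray"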