[DP] Support external DP Load Balancer mode (#19790)
Signed-off-by: Nick Hill <nhill@redhat.com>
@@ -318,6 +318,7 @@ class EngineArgs:
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
+    data_parallel_rank: Optional[int] = None
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
@@ -655,6 +656,12 @@ class EngineArgs:
                                     **parallel_kwargs["tensor_parallel_size"])
         parallel_group.add_argument("--data-parallel-size", "-dp",
                                     **parallel_kwargs["data_parallel_size"])
+        parallel_group.add_argument(
+            '--data-parallel-rank',
+            '-dpn',
+            type=int,
+            help='Data parallel rank of this instance. '
+            'When set, enables external load balancer mode.')
         parallel_group.add_argument('--data-parallel-size-local',
                                     '-dpl',
                                     type=int,
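For context, a minimal sketch of how the new flag is meant to be exercised (the parser setup follows vLLM's usual CLI plumbing; the model name and concrete values below are illustrative assumptions, not part of this diff): each engine instance is launched with the full --data-parallel-size plus its own --data-parallel-rank, and an external load balancer spreads requests across the per-rank instances.

# Hedged sketch: parsing the new --data-parallel-rank flag programmatically.
# Model name and values are illustrative, not taken from this commit.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)

# One process per DP rank; an external load balancer fronts all of them.
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--data-parallel-size", "2",
    "--data-parallel-rank", "0",  # presence of a rank enables external LB mode
])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.data_parallel_rank)  # -> 0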
@@ -1126,10 +1133,17 @@ class EngineArgs:
         # but we should not do this here.
         placement_group = ray.util.get_current_placement_group()
 
-        # Local DP size defaults to global DP size if not set.
-        data_parallel_size_local = self.data_parallel_size if (
-            self.data_parallel_size_local
-            is None) else self.data_parallel_size_local
+        data_parallel_external_lb = self.data_parallel_rank is not None
+        if data_parallel_external_lb:
+            assert self.data_parallel_size_local in (1, None), (
+                "data_parallel_size_local must be 1 when data_parallel_rank "
+                "is set")
+            data_parallel_size_local = 1
+        elif self.data_parallel_size_local is not None:
+            data_parallel_size_local = self.data_parallel_size_local
+        else:
+            # Local DP size defaults to global DP size if not set.
+            data_parallel_size_local = self.data_parallel_size
 
         # DP address, used in multi-node case for torch distributed group
         # and ZMQ sockets.
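To make the branching above easier to follow, here is the same local-DP-size resolution written as a standalone helper (the function and its name are illustrative, not part of vLLM):

from typing import Optional


def resolve_local_dp_size(data_parallel_size: int,
                          data_parallel_size_local: Optional[int],
                          data_parallel_rank: Optional[int]) -> tuple[int, bool]:
    """Illustrative mirror of the logic added above: returns
    (data_parallel_size_local, data_parallel_external_lb)."""
    external_lb = data_parallel_rank is not None
    if external_lb:
        # External LB mode: each process hosts exactly one DP rank locally.
        assert data_parallel_size_local in (1, None), (
            "data_parallel_size_local must be 1 when data_parallel_rank is set")
        return 1, True
    if data_parallel_size_local is not None:
        return data_parallel_size_local, False
    # Local DP size defaults to global DP size if not set.
    return data_parallel_size, False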
@@ -1154,16 +1168,16 @@ class EngineArgs:
             self.data_parallel_rpc_port
             is not None) else ParallelConfig.data_parallel_rpc_port
 
-        data_parallel_backend = self.data_parallel_backend
-
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
             data_parallel_size=self.data_parallel_size,
+            data_parallel_rank=self.data_parallel_rank or 0,
+            data_parallel_external_lb=data_parallel_external_lb,
             data_parallel_size_local=data_parallel_size_local,
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
-            data_parallel_backend=data_parallel_backend,
+            data_parallel_backend=self.data_parallel_backend,
             enable_expert_parallel=self.enable_expert_parallel,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.num_redundant_experts,
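Finally, a hedged sketch of the ParallelConfig that results for one instance in external LB mode (ParallelConfig and the keyword names come from this diff; the concrete rank, size, address, and port are illustrative assumptions):

from vllm.config import ParallelConfig

# Instance serving DP rank 1 of 4, one rank per process (values are illustrative).
parallel_config = ParallelConfig(
    data_parallel_size=4,
    data_parallel_rank=1,
    data_parallel_external_lb=True,   # derived from data_parallel_rank being set
    data_parallel_size_local=1,       # must be 1 in external LB mode
    data_parallel_master_ip="10.0.0.5",   # illustrative DP coordination address
    data_parallel_rpc_port=13345,         # illustrative RPC port
)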