[V1] Support MP Executor for multi node distributed inference (#23691)

Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Lucia Fang
2025-11-16 01:01:21 -08:00
committed by GitHub
parent a55b64635c
commit b316ac6589
10 changed files with 930 additions and 82 deletions

View File

@@ -210,6 +210,18 @@ class ParallelConfig:
class is dynamically inherited by the worker class. This is used to inject
new attributes and methods to the worker class for use in collective_rpc
calls."""
master_addr: str = "127.0.0.1"
"""distributed master address for multi-node distributed
inference when distributed_executor_backend is mp."""
master_port: int = 29501
"""distributed master port for multi-node distributed
inference when distributed_executor_backend is mp."""
node_rank: int = 0
"""distributed node rank for multi-node distributed
inference when distributed_executor_backend is mp."""
nnodes: int = 1
"""num of nodes for multi-node distributed
inference when distributed_executor_backend is mp."""
world_size: int = Field(init=False)
"""world_size is TPxPP, it affects the number of workers we create."""
@@ -387,6 +399,23 @@ class ParallelConfig:
and self.data_parallel_size > 1
)
@property
def node_rank_within_dp(self) -> int:
    """Return ``node_rank`` reduced modulo ``nnodes_within_dp``.

    The result lies in ``[0, nnodes_within_dp)`` whenever
    ``nnodes_within_dp`` is positive.
    """
    rank = self.node_rank
    group_nodes = self.nnodes_within_dp
    return rank % group_nodes
@property
def nnodes_within_dp(self) -> int:
    """Return ``nnodes`` floor-divided by the total-to-local DP size ratio.

    When ``nnodes`` is 1 the answer is always 1. Otherwise the value is
    ``nnodes // (data_parallel_size // data_parallel_size_local)``; both
    divisions are integer (floor) divisions.
    """
    if self.nnodes != 1:
        # Ratio of the global data-parallel size to the local one.
        dp_ratio = self.data_parallel_size // self.data_parallel_size_local
        return self.nnodes // dp_ratio
    return 1
@property
def local_world_size(self) -> int:
    """Return ``world_size`` floor-divided by ``nnodes_within_dp``."""
    total = self.world_size
    return total // self.nnodes_within_dp
@staticmethod
def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
@@ -528,6 +557,8 @@ class ParallelConfig:
ray_found = ray_utils.ray_is_available()
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
backend = "uni"
elif current_platform.is_cuda() and self.nnodes > 1:
backend = "mp"
elif (
current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size
@@ -565,6 +596,10 @@ class ParallelConfig:
"max_parallel_loading_workers is currently "
"not supported and will be ignored."
)
if self.distributed_executor_backend != "mp" and self.nnodes > 1:
raise ValueError(
"nnodes > 1 can only be set when distributed exectuor backend is mp."
)
@property
def use_ray(self) -> bool:
@@ -607,6 +642,11 @@ class ParallelConfig:
"Disabled the custom all-reduce kernel because it is not "
"supported on current platform."
)
if self.nnodes > 1:
self.disable_custom_all_reduce = True
logger.debug(
"Disabled the custom all-reduce since we are running on multi-node."
)
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError(
"Unable to use nsight profiling unless workers run with Ray."