[core] set up data parallel communication (#13591)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao committed via GitHub on 2025-02-22 19:28:59 +08:00
parent 7f6bae561c
commit 3e472d882a
17 changed files with 416 additions and 28 deletions


@@ -16,6 +16,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
import vllm.envs as envs
@@ -1296,6 +1297,11 @@ class ParallelConfig:
pipeline_parallel_size: int = 1 # Number of pipeline parallel groups.
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
data_parallel_size: int = 1 # Number of data parallel groups.
data_parallel_rank: int = 0 # Rank of the data parallel group.
# IP of the data parallel master.
data_parallel_master_ip: str = "127.0.0.1"
data_parallel_master_port: int = 29500 # Port of the data parallel master.
# Maximum number of multiple batches
# when load model sequentially. To avoid RAM OOM when using tensor
@@ -1329,10 +1335,55 @@ class ParallelConfig:
worker_cls: str = "auto"
sd_worker_cls: str = "auto"
# world_size is TPxPP, it affects the number of workers we create.
world_size: int = field(init=False)
# world_size_across_dp is TPxPPxDP, it is the size of the world
# including data parallelism.
world_size_across_dp: int = field(init=False)
rank: int = 0
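
(A hedged aside, not part of the diff: the size relationships above can be illustrated with concrete numbers. With tensor_parallel_size=2, pipeline_parallel_size=1 and data_parallel_size=2, each DP replica spans TPxPP = 2 workers, and the world including data parallelism spans TPxPPxDP = 4. The values below are illustrative only.)

    # illustrative values, not taken from the diff
    tensor_parallel_size, pipeline_parallel_size, data_parallel_size = 2, 1, 2
    world_size = pipeline_parallel_size * tensor_parallel_size    # 2 workers per DP replica
    world_size_across_dp = world_size * data_parallel_size        # 4 processes in total
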
def get_next_dp_init_port(self) -> int:
"""
We might need to initialize process groups in multiple
processes that are related to data parallelism,
e.g. both in the worker and in the engine, which
can live in different processes. To avoid port conflicts, we
increment the port number each time we need to initialize a
new process group related to data parallelism.
"""
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
return answer
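
(A minimal usage sketch, assuming a ParallelConfig instance cfg already exists and uses the default master port; the values are illustrative. Each call hands out the current port and bumps the counter, so the engine-side and worker-side DP process groups never race for the same TCP endpoint.)

    # assumes cfg is a ParallelConfig with data_parallel_master_port = 29500
    engine_port = cfg.get_next_dp_init_port()   # 29500, used for one DP process group
    worker_port = cfg.get_next_dp_init_port()   # 29501, used for a second DP process group
    assert worker_port == engine_port + 1       # no port conflict between the two groups
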
def stateless_init_dp_group(self) -> "ProcessGroup":
from vllm.distributed.utils import (
stateless_init_torch_distributed_process_group)
# use gloo since the engine process might not have a CUDA device
dp_group = stateless_init_torch_distributed_process_group(
self.data_parallel_master_ip,
self.get_next_dp_init_port(),
self.data_parallel_rank,
self.data_parallel_size,
backend="gloo")
return dp_group
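
(A hedged sketch of intended usage: one engine process per DP rank calls this method, each with its own data_parallel_rank but the same master ip/port, yielding a gloo group that is independent of the default torch.distributed world, hence "stateless". The cfg instance below is assumed to exist.)

    # assumes cfg is this process's ParallelConfig, with cfg.data_parallel_rank
    # set per process and the master ip/port shared by all DP ranks
    dp_group = cfg.stateless_init_dp_group()   # gloo-backed group across DP engine processes
    # the group can later be passed to has_unfinished_dp(...) below
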
@staticmethod
def has_unfinished_dp(dp_group: "ProcessGroup",
has_unfinished: bool) -> bool:
tensor = torch.tensor([has_unfinished],
dtype=torch.int32,
device="cpu")
# dp rank 0: has_unfinished_seqs=True
# dp rank 1: has_unfinished_seqs=False
# aggregated: has_unfinished_seqs=True
# so this is an OR operation, i.e. MAX in integers
torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
aggregated_has_unfinished = bool(tensor.item())
return aggregated_has_unfinished
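
(A standalone sketch of the same OR-as-MAX trick outside vLLM, runnable with e.g. torchrun --nproc_per_node=2; the rank assignment and the gloo backend are illustrative.)

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")        # CPU-only group, like the DP group above
    has_unfinished = dist.get_rank() == 0          # rank 0 reports True, the others False
    tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)  # MAX over 0/1 values == logical OR
    print(f"rank {dist.get_rank()}: aggregated = {bool(tensor.item())}")  # True on every rank
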
def compute_hash(self):
"""
Provide a hash that uniquely identifies all the configs
@@ -1350,6 +1401,12 @@ class ParallelConfig:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
self.world_size_across_dp = self.world_size * self.data_parallel_size
ray_only_devices = ["tpu"]
from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices
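
(A hedged sketch of how a launcher for a single DP rank might populate the environment variables read in __post_init__ above; the VLLM_DP_* names come from the diff, the concrete values are illustrative.)

    import os

    os.environ["VLLM_DP_SIZE"] = "2"               # two data parallel replicas
    os.environ["VLLM_DP_RANK"] = "1"               # this process drives DP rank 1
    os.environ["VLLM_DP_MASTER_IP"] = "10.0.0.1"   # must be reachable from every DP rank
    os.environ["VLLM_DP_MASTER_PORT"] = "29500"    # base port; incremented per process group
    # ...then construct the engine as usual; ParallelConfig.__post_init__ reads
    # these values and sets world_size_across_dp = TP * PP * DP.
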