[V1] Support DP with Ray (#18779)
vllm/v1/utils.py
@@ -27,6 +27,8 @@ from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
 from vllm.v1.executor.abstract import Executor

 if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
     from vllm.attention.layer import Attention
     from vllm.v1.engine.coordinator import DPCoordinator

@@ -112,6 +114,45 @@ def get_engine_client_zmq_addr(local_only: bool,
         host, port or get_open_port()))


+class CoreEngineState(Enum):
+    NEW = auto()
+    CONNECTED = auto()
+    READY = auto()
+
+
+class CoreEngine:
+    """One per data parallel rank."""
+
+    def __init__(self, index: int = 0, local: bool = True):
+        self.local = local
+        self.index = index
+        self.identity = index.to_bytes(2, "little")
+
+        self.state = CoreEngineState.NEW
+
+
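Aside (not part of the diff): the identity field above is just the DP rank packed into two little-endian bytes via plain int.to_bytes, so a few concrete values look like this:

>>> (0).to_bytes(2, "little")
b'\x00\x00'
>>> (3).to_bytes(2, "little")
b'\x03\x00'
>>> (258).to_bytes(2, "little")
b'\x02\x01'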
+@dataclass
+class EngineZmqAddresses:
+    # ZMQ input socket addresses for each front-end client (requests)
+    inputs: list[str]
+    # ZMQ output socket addresses for each front-end client (responses)
+    outputs: list[str]
+    # ZMQ input socket address of DP coordinator if applicable
+    coordinator_input: Optional[str] = None
+    # ZMQ output socket address of DP coordinator if applicable
+    coordinator_output: Optional[str] = None
+
+
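Illustration only (the socket addresses are made up, not taken from the diff): a minimal construction for a front-end with one input/output pair and no DP coordinator, sketched as:

addresses = EngineZmqAddresses(
    inputs=["ipc:///tmp/engine_input_0"],    # hypothetical IPC address
    outputs=["ipc:///tmp/engine_output_0"],  # hypothetical IPC address
)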
+@dataclass
+class EngineHandshakeMetadata:
+    """Metadata sent to each engine process during startup handshake,
+    including addresses of the front-end ZMQ queues that they should
+    connect to.
+    """
+    addresses: EngineZmqAddresses
+    parallel_config: dict[str, Union[int, str]]
+
+
 class APIServerProcessManager:
     """Manages a group of API server processes.

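Tying the two new dataclasses above together, a sketch of how a handshake payload might be populated. The parallel_config keys are assumptions based on the ParallelConfig attributes referenced elsewhere in this diff; the diff itself does not spell out the exact contents:

meta = EngineHandshakeMetadata(
    addresses=addresses,
    parallel_config={
        "data_parallel_master_ip": "10.0.0.1",  # assumed value
        "data_parallel_size": 2,                # assumed value
    },
)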
@@ -245,43 +286,168 @@ class CoreEngineProcManager:
         }


-class CoreEngineState(Enum):
-    NEW = auto()
-    CONNECTED = auto()
-    READY = auto()
-
-
-class CoreEngine:
-    """One per data parallel rank."""
-
-    def __init__(self, index: int = 0, local: bool = True):
-        self.local = local
-        self.index = index
-        self.identity = index.to_bytes(2, "little")
-
-        self.state = CoreEngineState.NEW
-
-
-@dataclass
-class EngineZmqAddresses:
-    # ZMQ input socket addresses for each front-end client (requests)
-    inputs: list[str]
-    # ZMQ output socket addresses for each front-end client (responses)
-    outputs: list[str]
-    # ZMQ input socket address of DP coordinator if applicable
-    coordinator_input: Optional[str] = None
-    # ZMQ output socket address of DP coordinator if applicable
-    coordinator_output: Optional[str] = None
-
-
-@dataclass
-class EngineHandshakeMetadata:
-    """Metadata sent to each engine process during startup handshake,
-    including addresses of the front-end ZMQ queues that they should
-    connect to.
+class CoreEngineActorManager:
     """
-    addresses: EngineZmqAddresses
-    parallel_config: dict[str, Union[int, str]]
+    Utility class to handle creation, readiness, and shutdown
+    of core engine Ray actors used by the AsyncLLM and LLMEngine.
+
+    Different from CoreEngineProcManager, this class manages
+    core engines for both local and remote nodes.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        addresses: EngineZmqAddresses,
+        executor_class: type[Executor],
+        log_stats: bool,
+        placement_groups: Optional[list["PlacementGroup"]] = None,
+        local_dp_ranks: Optional[list[int]] = None,
+    ):
+        import copy
+
+        import ray
+        from ray.util.scheduling_strategies import (
+            PlacementGroupSchedulingStrategy)
+
+        from vllm.v1.engine.core import DPEngineCoreActor
+
+        self.local_engine_actors: list[ray.ActorHandle] = []
+        self.remote_engine_actors: list[ray.ActorHandle] = []
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        local_engine_count = \
+            vllm_config.parallel_config.data_parallel_size_local
+        world_size = vllm_config.parallel_config.world_size
+
+        if ray.is_initialized():
+            logger.info(
+                "Ray is already initialized. Skipping Ray initialization.")
+        else:
+            ray.init()
+
+        if placement_groups is not None:
+            assert local_dp_ranks is not None, (
+                "local_dp_ranks must be provided if "
+                "placement_groups is provided")
+            assert len(placement_groups) == len(local_dp_ranks), (
+                "placement_groups and local_dp_ranks must "
+                "have the same length")
+            logger.info("Using provided placement groups")
+            # TODO(rui): validate passed-in placement groups
+            self.created_placement_groups = []
+        else:
+            placement_groups, local_dp_ranks = \
+                CoreEngineActorManager.create_dp_placement_groups(vllm_config)
+            self.created_placement_groups = placement_groups
+        assert len(placement_groups) == dp_size, (
+            "Number of placement groups must match data parallel size")
+
+        refs = []
+        for index in range(dp_size):
+            local_index = local_dp_ranks[index]
+            dp_vllm_config = copy.deepcopy(vllm_config)
+            pg = placement_groups[index]
+            dp_vllm_config.parallel_config.placement_group = pg
+            on_head_node = index < local_engine_count
+            actor = ray.remote(DPEngineCoreActor).options(
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=pg,
+                    placement_group_bundle_index=world_size,
+                )).remote(vllm_config=dp_vllm_config,
+                          executor_class=executor_class,
+                          log_stats=log_stats,
+                          on_head_node=on_head_node,
+                          addresses=addresses,
+                          dp_rank=index,
+                          local_dp_rank=local_index)
+            if on_head_node:
+                self.local_engine_actors.append(actor)
+            else:
+                self.remote_engine_actors.append(actor)
+            refs.append(actor.wait_for_init.remote())
+
+        ray.get(refs)
+        self.run_refs = []
+        for actor in self.local_engine_actors + self.remote_engine_actors:
+            self.run_refs.append(actor.run.remote())
+
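As an aside, a rough usage sketch of the new actor manager (not part of the diff; vllm_config and executor_class are assumed to exist already, and the socket addresses are placeholders):

manager = CoreEngineActorManager(
    vllm_config=vllm_config,
    addresses=EngineZmqAddresses(inputs=["ipc:///tmp/engine_input_0"],
                                 outputs=["ipc:///tmp/engine_output_0"]),
    executor_class=executor_class,
    log_stats=False,
)
# At this point each DP rank has a Ray actor whose wait_for_init() has
# completed, and the run() loops execute in the background; the returned
# object refs can be polled via manager.get_run_refs().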
+    @staticmethod
+    def create_dp_placement_groups(
+            vllm_config: VllmConfig
+    ) -> tuple[list["PlacementGroup"], list[int]]:
+
+        import ray
+        from ray._private.state import available_resources_per_node
+        from ray.util.state import list_nodes
+
+        logger.info("Creating placement groups for data parallel")
+        dp_master_ip = \
+            vllm_config.parallel_config.data_parallel_master_ip
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        local_engine_count = \
+            vllm_config.parallel_config.data_parallel_size_local
+
+        nodes = list_nodes()
+        nodes = sorted(list_nodes(),
+                       key=lambda node: node.node_ip != dp_master_ip)
+        assert nodes[0].node_ip == dp_master_ip, (
+            "The first node must be the head node")
+        assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
+            "There can only be one head node")
+
+        available_resources = available_resources_per_node()
+        world_size = vllm_config.parallel_config.world_size
+        placement_groups: list[PlacementGroup] = []
+        local_dp_ranks: list[int] = []
+
+        for node in nodes:
+            node_ip = node.node_ip
+            node_resources = available_resources[node.node_id]
+            # For now, each DP rank can only be assigned to one node
+            # TODO(rui): support allocating a single DP rank
+            # to multiple nodes
+            available_engine_count = node_resources["GPU"] // world_size
+            if node_ip == dp_master_ip:
+                assert available_engine_count >= local_engine_count, (
+                    "Not enough resources to allocate DP ranks "
+                    f"on DP master node {node_ip}")
+                for i in range(local_engine_count):
+                    bundles = [{
+                        "GPU": 1.0,
+                        "node:" + dp_master_ip: 0.001
+                    }] * world_size + [{
+                        "CPU": 1.0
+                    }]
+                    pg = ray.util.placement_group(
+                        name=f"dp_rank_{len(placement_groups)}",
+                        strategy="STRICT_PACK",
+                        bundles=bundles,
+                    )
+                    placement_groups.append(pg)
+                    local_dp_ranks.append(i)
+            else:
+                for i in range(available_engine_count):
+                    if len(placement_groups) == dp_size:
+                        break
+                    bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
+                    pg = ray.util.placement_group(
+                        name=f"dp_rank_{len(placement_groups)}",
+                        strategy="STRICT_PACK",
+                        bundles=bundles,
+                    )
+                    placement_groups.append(pg)
+                    local_dp_ranks.append(i)
+        return placement_groups, local_dp_ranks
+
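To make the bundle layout concrete, a small worked example under assumed values (world_size=2 and DP master IP 10.0.0.1, not taken from the diff):

world_size = 2
dp_master_ip = "10.0.0.1"
bundles = [{
    "GPU": 1.0,
    "node:" + dp_master_ip: 0.001
}] * world_size + [{
    "CPU": 1.0
}]
# bundles == [{"GPU": 1.0, "node:10.0.0.1": 0.001},
#             {"GPU": 1.0, "node:10.0.0.1": 0.001},
#             {"CPU": 1.0}]
# The trailing CPU-only bundle sits at index world_size, which lines up with
# the placement_group_bundle_index=world_size used when the engine actor is
# created in __init__ above.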
+    def get_run_refs(self):
+        return self.run_refs
+
+    def close(self):
+        import ray
+        for actor in self.local_engine_actors + self.remote_engine_actors:
+            ray.kill(actor)
+        for pg in self.created_placement_groups:
+            ray.util.remove_placement_group(pg)


 def wait_for_engine_startup(
@@ -383,11 +549,19 @@ def wait_for_engine_startup(
 def wait_for_completion_or_failure(
         api_server_manager: APIServerProcessManager,
-        local_engine_manager: Optional[CoreEngineProcManager] = None,
+        engine_manager: Optional[Union[CoreEngineProcManager,
+                                       CoreEngineActorManager]] = None,
         coordinator: Optional["DPCoordinator"] = None) -> None:
     """Wait for all processes to complete or detect if any fail.

     Raises an exception if any process exits with a non-zero status.
+
+    Args:
+        api_server_manager: The manager for API servers.
+        engine_manager: The manager for engine processes.
+            If CoreEngineProcManager, it manages local engines;
+            if CoreEngineActorManager, it manages all engines.
+        coordinator: The coordinator for data parallel.
     """

     try:
@@ -402,14 +576,18 @@ def wait_for_completion_or_failure(
         if coordinator:
             sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

-        if local_engine_manager:
-            for proc in local_engine_manager.processes:
+        actor_run_refs = []
+        if isinstance(engine_manager, CoreEngineProcManager):
+            for proc in engine_manager.processes:
                 sentinel_to_proc[proc.sentinel] = proc
+        elif isinstance(engine_manager, CoreEngineActorManager):
+            actor_run_refs = engine_manager.get_run_refs()

         # Check if any process terminates
-        while sentinel_to_proc:
+        while sentinel_to_proc or actor_run_refs:
             # Wait for any process to terminate
-            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
+            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
+                                                         timeout=5)

             # Process any terminated processes
             for sentinel in ready_sentinels:
@@ -420,6 +598,11 @@ def wait_for_completion_or_failure(
                     raise RuntimeError(
                         f"Process {proc.name} (PID: {proc.pid}) "
                         f"died with exit code {proc.exitcode}")
+
+            if actor_run_refs:
+                import ray
+                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)
+
     except KeyboardInterrupt:
         logger.info("Received KeyboardInterrupt, shutting down API servers...")
     except Exception as e:
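Aside (illustration only, not from the diff): the ray.wait call above follows the standard pattern of splitting object refs into finished and pending sets; a standalone sketch:

import ray

ray.init()

@ray.remote
def run() -> str:
    return "done"

refs = [run.remote() for _ in range(3)]
ready, pending = ray.wait(refs, timeout=5)
# `ready` holds at most num_returns (default 1) refs that finished within the
# timeout; `pending` holds the rest. Reassigning the pending list, as the loop
# above does with actor_run_refs, keeps polling only refs that have not yet
# completed.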
@@ -431,8 +614,8 @@ def wait_for_completion_or_failure(
         api_server_manager.close()
         if coordinator:
             coordinator.close()
-        if local_engine_manager:
-            local_engine_manager.close()
+        if engine_manager:
+            engine_manager.close()


 # Note(rob): shutdown function cannot be a bound method,