[V1] DP scale-out (2/N): Decouple engine process management and comms (#15977)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-05-13 10:48:21 -07:00
committed by GitHub
parent 0b217da646
commit 55aa7af994
10 changed files with 516 additions and 243 deletions

View File

@@ -1,14 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import signal
import uvloop
import vllm.envs as envs
from vllm import AsyncEngineArgs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.utils import FlexibleArgumentParser
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__)
class ServeSubcommand(CLISubcommand):
@@ -24,7 +34,10 @@ class ServeSubcommand(CLISubcommand):
if hasattr(args, 'model_tag') and args.model_tag is not None:
args.model = args.model_tag
uvloop.run(run_server(args))
if args.headless:
run_headless(args)
else:
uvloop.run(run_server(args))
def validate(self, args: argparse.Namespace) -> None:
validate_parsed_serve_args(args)
@@ -42,6 +55,18 @@ class ServeSubcommand(CLISubcommand):
nargs='?',
help="The model tag to serve "
"(optional if specified in config)")
serve_parser.add_argument(
"--headless",
action='store_true',
default=False,
help="Run in headless mode. See multi-node data parallel "
"documentation for more details.")
serve_parser.add_argument(
'--data-parallel-start-rank',
'-dpr',
type=int,
default=0,
help='Starting data parallel rank for secondary nodes.')
serve_parser.add_argument(
"--config",
type=str,
@@ -57,3 +82,55 @@ class ServeSubcommand(CLISubcommand):
def cmd_init() -> list[CLISubcommand]:
return [ServeSubcommand()]
def run_headless(args: argparse.Namespace):
# Create the EngineConfig.
engine_args = AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
if not envs.VLLM_USE_V1:
raise RuntimeError("Headless mode is only supported for V1")
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
input_address = get_tcp_uri(host, port)
if local_engine_count <= 0:
raise RuntimeError("data_parallel_size_local must be > 0 in "
"headless mode")
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):
logger.debug("Received %d signal.", signum)
raise SystemExit
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
logger.info(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.", local_engine_count, input_address)
# Create the engines.
engine_manager = CoreEngineProcManager(
target_fn=EngineCoreProc.run_engine_core,
local_engine_count=local_engine_count,
start_index=args.data_parallel_start_rank,
local_start_index=0,
vllm_config=vllm_config,
on_head_node=False,
input_address=input_address,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)
try:
engine_manager.join_first()
finally:
logger.info("Shutting down.")
engine_manager.close()