[Frontend] Move APIServerProcessManager target server fn (#38115)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -89,9 +89,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update):
|
||||
assert not proc.is_alive()
|
||||
|
||||
|
||||
@patch(
|
||||
"vllm.entrypoints.cli.serve.run_api_server_worker_proc", mock_run_api_server_worker
|
||||
)
|
||||
@patch("vllm.v1.utils.run_api_server_worker_proc", mock_run_api_server_worker)
|
||||
def test_wait_for_completion_or_failure(api_server_args):
|
||||
"""Test that wait_for_completion_or_failure works with failures."""
|
||||
global WORKER_RUNTIME_SECONDS
|
||||
|
||||
@@ -10,18 +10,13 @@ import uvloop
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
run_server,
|
||||
run_server_worker,
|
||||
setup_server,
|
||||
)
|
||||
from vllm.entrypoints.openai.api_server import run_server, setup_server
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.network_utils import get_tcp_uri
|
||||
from vllm.utils.system_utils import decorate_logs, set_process_title
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
@@ -293,7 +288,6 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
) as (local_engine_manager, coordinator, addresses, tensor_queue):
|
||||
# Construct common args for the APIServerProcessManager up-front.
|
||||
api_server_manager_kwargs = dict(
|
||||
target_server_fn=run_api_server_worker_proc,
|
||||
listen_address=listen_address,
|
||||
sock=sock,
|
||||
args=args,
|
||||
@@ -346,19 +340,3 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
local_engine_manager.shutdown(timeout=to_timeout(shutdown_by))
|
||||
if coordinator:
|
||||
coordinator.shutdown(timeout=to_timeout(shutdown_by))
|
||||
|
||||
|
||||
def run_api_server_worker_proc(
|
||||
listen_address, sock, args, client_config=None, **uvicorn_kwargs
|
||||
) -> None:
|
||||
"""Entrypoint for individual API server worker processes."""
|
||||
client_config = client_config or {}
|
||||
server_index = client_config.get("client_index", 0)
|
||||
|
||||
# Set process title and add process-specific prefix to stdout and stderr.
|
||||
set_process_title("APIServer", str(server_index))
|
||||
decorate_logs()
|
||||
|
||||
uvloop.run(
|
||||
run_server_worker(listen_address, sock, args, client_config, **uvicorn_kwargs)
|
||||
)
|
||||
|
||||
@@ -21,13 +21,14 @@ from typing import (
|
||||
)
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from torch.autograd.profiler import record_function
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
|
||||
from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri
|
||||
from vllm.utils.system_utils import kill_process_tree
|
||||
from vllm.utils.system_utils import decorate_logs, kill_process_tree, set_process_title
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -166,20 +167,20 @@ class APIServerProcessManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_server_fn: Callable,
|
||||
listen_address: str,
|
||||
sock: Any,
|
||||
args: argparse.Namespace,
|
||||
num_servers: int,
|
||||
input_addresses: list[str],
|
||||
output_addresses: list[str],
|
||||
target_server_fn: Callable | None = None,
|
||||
stats_update_address: str | None = None,
|
||||
tensor_queue: Queue | None = None,
|
||||
):
|
||||
"""Initialize and start API server worker processes.
|
||||
|
||||
Args:
|
||||
target_server_fn: Function to call for each API server process
|
||||
target_server_fn: Override function to call for each API server process
|
||||
listen_address: Address to listen for client connections
|
||||
sock: Socket for client connections
|
||||
args: Command line arguments
|
||||
@@ -212,7 +213,7 @@ class APIServerProcessManager:
|
||||
client_config["tensor_queue"] = tensor_queue
|
||||
|
||||
proc = spawn_context.Process(
|
||||
target=target_server_fn,
|
||||
target=target_server_fn or run_api_server_worker_proc,
|
||||
name=f"ApiServer_{i}",
|
||||
args=(listen_address, sock, args, client_config),
|
||||
)
|
||||
@@ -231,6 +232,25 @@ class APIServerProcessManager:
|
||||
shutdown(self.processes, timeout=timeout)
|
||||
|
||||
|
||||
def run_api_server_worker_proc(
|
||||
listen_address, sock, args, client_config=None, **uvicorn_kwargs
|
||||
) -> None:
|
||||
"""Entrypoint for individual API server worker processes."""
|
||||
|
||||
from vllm.entrypoints.openai.api_server import run_server_worker
|
||||
|
||||
client_config = client_config or {}
|
||||
server_index = client_config.get("client_index", 0)
|
||||
|
||||
# Set process title and add process-specific prefix to stdout and stderr.
|
||||
set_process_title("APIServer", str(server_index))
|
||||
decorate_logs()
|
||||
|
||||
uvloop.run(
|
||||
run_server_worker(listen_address, sock, args, client_config, **uvicorn_kwargs)
|
||||
)
|
||||
|
||||
|
||||
def wait_for_completion_or_failure(
|
||||
api_server_manager: APIServerProcessManager,
|
||||
engine_manager: Union["CoreEngineProcManager", "CoreEngineActorManager"]
|
||||
|
||||
Reference in New Issue
Block a user