Make engine core client handshake timeout configurable (#27444)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
This commit is contained in:
@@ -2,12 +2,14 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -24,7 +26,11 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient
|
||||
from vllm.v1.engine.core_client import (
|
||||
AsyncMPClient,
|
||||
EngineCoreClient,
|
||||
SyncMPClient,
|
||||
)
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
@@ -60,6 +66,91 @@ def make_request(
|
||||
)
|
||||
|
||||
|
||||
def _reload_envs_module():
    """Reload ``vllm.envs`` and return the freshly loaded module.

    When the module-level ``__getattr__`` hook carries a ``cache_clear``
    callable, that callable is invoked before the reload.
    """
    import vllm.envs as envs_mod

    module_getattr = getattr(envs_mod, "__getattr__", None)
    clear_cache = getattr(module_getattr, "cache_clear", None)
    if clear_cache is not None:
        clear_cache()
    return importlib.reload(envs_mod)
|
||||
|
||||
|
||||
def _reload_core_client_module():
    """Import ``vllm.v1.engine.core_client`` and return it after a reload."""
    return importlib.reload(
        importlib.import_module("vllm.v1.engine.core_client")
    )
|
||||
|
||||
|
||||
def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
    """Check that ``MPClient`` polls with ``VLLM_ENGINE_READY_TIMEOUT_S``.

    The environment variable is set to a distinctive value, the envs and
    core-client modules are reloaded, and all ZMQ sockets are replaced by
    stubs so that the timeout handed to ``poll`` can be recorded.
    """
    timeout_s = 654
    monkeypatch.setenv("VLLM_ENGINE_READY_TIMEOUT_S", str(timeout_s))

    # Reload so the new environment value is observed even if the envs
    # module caches its lookups.
    _reload_envs_module()
    core_client_mod = _reload_core_client_module()

    recorded_poll_timeouts: list[int] = []

    class ShadowSocket:
        """Stub returned by the patched ``zmq.Socket.shadow``."""

        def poll(self, timeout: int) -> int:
            # Remember the timeout of every poll call for the assertion below.
            recorded_poll_timeouts.append(timeout)
            return 1

        def recv_multipart(self):
            return (b"\x00\x00", b"ready")

    class DummySocket:
        """Stub returned by the patched ``make_zmq_socket``."""

        def send_multipart(self, _msg, *, copy: bool = False, track: bool = False):
            return SimpleNamespace(done=True) if track else None

        def recv_multipart(self, *, copy: bool = False):
            return (b"", b"")

        def close(self, *, linger: int = 0):
            return None

        def bind(self, _address):
            return None

        def connect(self, _address):
            return None

        def setsockopt(self, *_args, **_kwargs):
            return None

    def _shadow_factory(*_args):
        return ShadowSocket()

    def _socket_factory(*_args, **_kwargs):
        return DummySocket()

    monkeypatch.setattr(core_client_mod.zmq.Socket, "shadow", _shadow_factory)
    monkeypatch.setattr(core_client_mod, "make_zmq_socket", _socket_factory)

    vllm_config = SimpleNamespace(
        parallel_config=SimpleNamespace(
            data_parallel_size=1,
            data_parallel_rank=0,
            data_parallel_size_local=1,
            data_parallel_rank_local=None,
            data_parallel_hybrid_lb=False,
            data_parallel_external_lb=False,
        )
    )
    addresses = {
        "input_address": "inproc://input",
        "output_address": "inproc://output",
    }

    client = core_client_mod.MPClient(
        asyncio_mode=False,
        vllm_config=vllm_config,
        executor_class=object,
        log_stats=False,
        client_addresses=addresses,
    )
    try:
        # The env value is in seconds, whereas poll receives milliseconds.
        expected_timeout_ms = timeout_s * 1000
        assert recorded_poll_timeouts == [expected_timeout_ms]
    finally:
        client.shutdown()
|
||||
|
||||
|
||||
def loop_until_done(client: EngineCoreClient, outputs: dict):
|
||||
while True:
|
||||
engine_core_outputs = client.get_output().outputs
|
||||
|
||||
Reference in New Issue
Block a user