[CI][MCP][Harmony] Heavy refactoring Harmony & MCP response tests and stabilizing with deterministic test infrastructure (#33949)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -1,7 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BASE_TEST_ENV = {
|
||||
# The day vLLM said "hello world" on arxiv 🚀
|
||||
"VLLM_SYSTEM_START_DATE": "2023-09-12",
|
||||
}
|
||||
DEFAULT_MAX_RETRIES = 3
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pairs_of_event_types() -> dict[str, str]:
|
||||
@@ -28,3 +43,159 @@ def pairs_of_event_types() -> dict[str, str]:
|
||||
}
|
||||
# fmt: on
|
||||
return event_pairs
|
||||
|
||||
|
||||
async def retry_for_tool_call(
|
||||
client,
|
||||
*,
|
||||
model: str,
|
||||
expected_tool_type: str,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
**create_kwargs: Any,
|
||||
):
|
||||
"""Call ``client.responses.create`` up to *max_retries* times, returning
|
||||
the first response that contains an output item of *expected_tool_type*.
|
||||
|
||||
Returns the **last** response if none match so the caller's assertions
|
||||
fire with a clear diagnostic.
|
||||
"""
|
||||
last_response = None
|
||||
for attempt in range(max_retries):
|
||||
response = await client.responses.create(model=model, **create_kwargs)
|
||||
last_response = response
|
||||
if any(
|
||||
getattr(item, "type", None) == expected_tool_type
|
||||
for item in response.output
|
||||
):
|
||||
return response
|
||||
assert last_response is not None
|
||||
return last_response
|
||||
|
||||
|
||||
async def retry_streaming_for(
|
||||
client,
|
||||
*,
|
||||
model: str,
|
||||
validate_events: Callable[[list], bool],
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
**create_kwargs: Any,
|
||||
) -> list:
|
||||
"""Call ``client.responses.create(stream=True)`` up to *max_retries*
|
||||
times, returning the first event list where *validate_events* returns
|
||||
``True``.
|
||||
"""
|
||||
last_events: list = []
|
||||
for attempt in range(max_retries):
|
||||
stream = await client.responses.create(
|
||||
model=model, stream=True, **create_kwargs
|
||||
)
|
||||
events: list = []
|
||||
async for event in stream:
|
||||
events.append(event)
|
||||
last_events = events
|
||||
if validate_events(events):
|
||||
return events
|
||||
return last_events
|
||||
|
||||
|
||||
def has_output_type(response, type_name: str) -> bool:
|
||||
"""Return True if *response* has at least one output item of *type_name*."""
|
||||
return any(getattr(item, "type", None) == type_name for item in response.output)
|
||||
|
||||
|
||||
def events_contain_type(events: list, type_substring: str) -> bool:
|
||||
"""Return True if any event's type contains *type_substring*."""
|
||||
return any(type_substring in getattr(e, "type", "") for e in events)
|
||||
|
||||
|
||||
def validate_streaming_event_stack(
|
||||
events: list, pairs_of_event_types: dict[str, str]
|
||||
) -> None:
|
||||
"""Validate that streaming events are properly nested/paired."""
|
||||
stack: list[str] = []
|
||||
for event in events:
|
||||
etype = event.type
|
||||
if etype == "response.created":
|
||||
stack.append(etype)
|
||||
elif etype == "response.completed":
|
||||
assert stack and stack[-1] == pairs_of_event_types[etype], (
|
||||
f"Unexpected stack top for {etype}: "
|
||||
f"got {stack[-1] if stack else '<empty>'}"
|
||||
)
|
||||
stack.pop()
|
||||
elif etype.endswith("added") or etype == "response.mcp_call.in_progress":
|
||||
stack.append(etype)
|
||||
elif etype.endswith("delta"):
|
||||
if stack and stack[-1] == etype:
|
||||
continue
|
||||
stack.append(etype)
|
||||
elif etype.endswith("done") or etype == "response.mcp_call.completed":
|
||||
assert etype in pairs_of_event_types, f"Unknown done event: {etype}"
|
||||
expected_start = pairs_of_event_types[etype]
|
||||
assert stack and stack[-1] == expected_start, (
|
||||
f"Stack mismatch for {etype}: "
|
||||
f"expected {expected_start}, "
|
||||
f"got {stack[-1] if stack else '<empty>'}"
|
||||
)
|
||||
stack.pop()
|
||||
assert len(stack) == 0, f"Unclosed events on stack: {stack}"
|
||||
|
||||
|
||||
def log_response_diagnostics(
|
||||
response,
|
||||
*,
|
||||
label: str = "Response Diagnostics",
|
||||
) -> dict[str, Any]:
|
||||
"""Extract and log diagnostic info from a Responses API response.
|
||||
|
||||
Logs reasoning, tool-call attempts, MCP items, and output types so
|
||||
that CI output (``pytest -s`` or ``--log-cli-level=INFO``) gives
|
||||
full visibility into model behaviour even on passing runs.
|
||||
|
||||
Returns the extracted data so callers can make additional assertions
|
||||
if needed.
|
||||
"""
|
||||
reasoning_texts = [
|
||||
text
|
||||
for item in response.output
|
||||
if getattr(item, "type", None) == "reasoning"
|
||||
for content in getattr(item, "content", [])
|
||||
if (text := getattr(content, "text", None))
|
||||
]
|
||||
|
||||
tool_call_attempts = [
|
||||
{
|
||||
"recipient": msg.get("recipient"),
|
||||
"channel": msg.get("channel"),
|
||||
}
|
||||
for msg in response.output_messages
|
||||
if (msg.get("recipient") or "").startswith("python")
|
||||
]
|
||||
|
||||
mcp_items = [
|
||||
{
|
||||
"name": getattr(item, "name", None),
|
||||
"status": getattr(item, "status", None),
|
||||
}
|
||||
for item in response.output
|
||||
if getattr(item, "type", None) == "mcp_call"
|
||||
]
|
||||
|
||||
output_types = [getattr(o, "type", None) for o in response.output]
|
||||
|
||||
diagnostics = {
|
||||
"model_attempted_tool_calls": bool(tool_call_attempts),
|
||||
"tool_call_attempts": tool_call_attempts,
|
||||
"mcp_items": mcp_items,
|
||||
"reasoning": reasoning_texts,
|
||||
"output_text": response.output_text,
|
||||
"output_types": output_types,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"\n====== %s ======\n%s\n==============================",
|
||||
label,
|
||||
json.dumps(diagnostics, indent=2, default=str),
|
||||
)
|
||||
|
||||
return diagnostics
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Integration tests for MCP tool support in the Responses API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
@@ -10,11 +12,31 @@ from openai_harmony import ToolDescription, ToolNamespaceConfig
|
||||
from vllm.entrypoints.mcp.tool_server import MCPToolServer
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
from .conftest import (
|
||||
BASE_TEST_ENV,
|
||||
events_contain_type,
|
||||
log_response_diagnostics,
|
||||
retry_for_tool_call,
|
||||
retry_streaming_for,
|
||||
validate_streaming_event_stack,
|
||||
)
|
||||
|
||||
MODEL_NAME = "openai/gpt-oss-20b"
|
||||
|
||||
_BASE_SERVER_ARGS = [
|
||||
"--enforce-eager",
|
||||
"--tool-server",
|
||||
"demo",
|
||||
"--max_model_len",
|
||||
"5000",
|
||||
]
|
||||
|
||||
def test_get_tool_description():
|
||||
_PYTHON_TOOL_INSTRUCTION = (
|
||||
"You must use the Python tool to execute code. Never simulate execution."
|
||||
)
|
||||
|
||||
|
||||
class TestMCPToolServerUnit:
|
||||
"""Test MCPToolServer.get_tool_description filtering logic.
|
||||
|
||||
Note: The wildcard "*" is normalized to None by
|
||||
@@ -22,283 +44,240 @@ def test_get_tool_description():
|
||||
so we only test None and specific tool filtering here.
|
||||
See test_serving_responses.py for "*" normalization tests.
|
||||
"""
|
||||
pytest.importorskip("mcp")
|
||||
|
||||
server = MCPToolServer()
|
||||
tool1 = ToolDescription.new(
|
||||
name="tool1", description="First", parameters={"type": "object"}
|
||||
)
|
||||
tool2 = ToolDescription.new(
|
||||
name="tool2", description="Second", parameters={"type": "object"}
|
||||
)
|
||||
tool3 = ToolDescription.new(
|
||||
name="tool3", description="Third", parameters={"type": "object"}
|
||||
)
|
||||
def test_get_tool_description(self):
|
||||
pytest.importorskip("mcp")
|
||||
|
||||
server.harmony_tool_descriptions = {
|
||||
"test_server": ToolNamespaceConfig(
|
||||
name="test_server", description="test", tools=[tool1, tool2, tool3]
|
||||
server = MCPToolServer()
|
||||
tool1 = ToolDescription.new(
|
||||
name="tool1", description="First", parameters={"type": "object"}
|
||||
)
|
||||
tool2 = ToolDescription.new(
|
||||
name="tool2", description="Second", parameters={"type": "object"}
|
||||
)
|
||||
tool3 = ToolDescription.new(
|
||||
name="tool3", description="Third", parameters={"type": "object"}
|
||||
)
|
||||
}
|
||||
|
||||
# Nonexistent server
|
||||
assert server.get_tool_description("nonexistent") is None
|
||||
server.harmony_tool_descriptions = {
|
||||
"test_server": ToolNamespaceConfig(
|
||||
name="test_server",
|
||||
description="test",
|
||||
tools=[tool1, tool2, tool3],
|
||||
)
|
||||
}
|
||||
|
||||
# None (no filter) - returns all tools
|
||||
result = server.get_tool_description("test_server", allowed_tools=None)
|
||||
assert len(result.tools) == 3
|
||||
# Nonexistent server
|
||||
assert server.get_tool_description("nonexistent") is None
|
||||
|
||||
# Filter to specific tools
|
||||
result = server.get_tool_description(
|
||||
"test_server", allowed_tools=["tool1", "tool3"]
|
||||
)
|
||||
assert len(result.tools) == 2
|
||||
assert result.tools[0].name == "tool1"
|
||||
assert result.tools[1].name == "tool3"
|
||||
# None (no filter) - returns all tools
|
||||
result = server.get_tool_description("test_server", allowed_tools=None)
|
||||
assert len(result.tools) == 3
|
||||
|
||||
# Single tool
|
||||
result = server.get_tool_description(
|
||||
"test_server",
|
||||
allowed_tools=["tool2"],
|
||||
)
|
||||
assert len(result.tools) == 1
|
||||
assert result.tools[0].name == "tool2"
|
||||
# Filter to specific tools
|
||||
result = server.get_tool_description(
|
||||
"test_server", allowed_tools=["tool1", "tool3"]
|
||||
)
|
||||
assert len(result.tools) == 2
|
||||
assert result.tools[0].name == "tool1"
|
||||
assert result.tools[1].name == "tool3"
|
||||
|
||||
# No matching tools - returns None
|
||||
result = server.get_tool_description("test_server", allowed_tools=["nonexistent"])
|
||||
assert result is None
|
||||
# Single tool
|
||||
result = server.get_tool_description("test_server", allowed_tools=["tool2"])
|
||||
assert len(result.tools) == 1
|
||||
assert result.tools[0].name == "tool2"
|
||||
|
||||
# Empty list - returns None
|
||||
assert server.get_tool_description("test_server", allowed_tools=[]) is None
|
||||
# No matching tools - returns None
|
||||
result = server.get_tool_description(
|
||||
"test_server", allowed_tools=["nonexistent"]
|
||||
)
|
||||
assert result is None
|
||||
|
||||
# Empty list - returns None
|
||||
assert server.get_tool_description("test_server", allowed_tools=[]) is None
|
||||
|
||||
def test_builtin_tools_consistency(self):
|
||||
"""MCP_BUILTIN_TOOLS must match _BUILTIN_TOOL_TO_MCP_SERVER_LABEL values."""
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
_BUILTIN_TOOL_TO_MCP_SERVER_LABEL,
|
||||
MCP_BUILTIN_TOOLS,
|
||||
)
|
||||
|
||||
assert set(_BUILTIN_TOOL_TO_MCP_SERVER_LABEL.values()) == MCP_BUILTIN_TOOLS, (
|
||||
f"MCP_BUILTIN_TOOLS {MCP_BUILTIN_TOOLS} does not match "
|
||||
f"_BUILTIN_TOOL_TO_MCP_SERVER_LABEL values "
|
||||
f"{set(_BUILTIN_TOOL_TO_MCP_SERVER_LABEL.values())}"
|
||||
)
|
||||
|
||||
|
||||
class TestMCPEnabled:
|
||||
"""Tests that require MCP tools to be enabled via environment variable."""
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def monkeypatch_class(self):
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
mpatch = MonkeyPatch()
|
||||
yield mpatch
|
||||
mpatch.undo()
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def mcp_enabled_server(self, monkeypatch_class: pytest.MonkeyPatch):
|
||||
args = ["--enforce-eager", "--tool-server", "demo"]
|
||||
|
||||
with monkeypatch_class.context() as m:
|
||||
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
|
||||
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
|
||||
m.setenv(
|
||||
"VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container"
|
||||
)
|
||||
# Helps the model follow instructions better
|
||||
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
def mcp_enabled_server(self):
|
||||
env_dict = {
|
||||
**BASE_TEST_ENV,
|
||||
"VLLM_ENABLE_RESPONSES_API_STORE": "1",
|
||||
"PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
|
||||
"VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS": ("code_interpreter,container"),
|
||||
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": "1",
|
||||
}
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME, list(_BASE_SERVER_ARGS), env_dict=env_dict
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mcp_enabled_client(self, mcp_enabled_server):
|
||||
async def client(self, mcp_enabled_server):
|
||||
async with mcp_enabled_server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
@staticmethod
|
||||
def _mcp_tools_payload(*, allowed_tools: list[str] | None = None) -> list[dict]:
|
||||
tool: dict = {
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
"server_url": "http://localhost:8888",
|
||||
}
|
||||
if allowed_tools is not None:
|
||||
tool["allowed_tools"] = allowed_tools
|
||||
return [tool]
|
||||
|
||||
@staticmethod
|
||||
def _python_exec_input(code: str = "") -> str:
|
||||
if not code:
|
||||
code = "import random; print(random.randint(1, 1000000))"
|
||||
return f"Execute the following code: {code}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_mcp_tool_env_flag_enabled(
|
||||
self, mcp_enabled_client: OpenAI, model_name: str
|
||||
):
|
||||
response = await mcp_enabled_client.responses.create(
|
||||
async def test_mcp_tool_env_flag_enabled(self, client: OpenAI, model_name: str):
|
||||
response = await retry_for_tool_call(
|
||||
client,
|
||||
model=model_name,
|
||||
input=(
|
||||
"Execute the following code: "
|
||||
"import random; print(random.randint(1, 1000000))"
|
||||
),
|
||||
instructions=(
|
||||
"You must use the Python tool to execute code. "
|
||||
"Never simulate execution."
|
||||
),
|
||||
tools=[
|
||||
{
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
# URL unused for DemoToolServer
|
||||
"server_url": "http://localhost:8888",
|
||||
}
|
||||
],
|
||||
expected_tool_type="mcp_call",
|
||||
input=self._python_exec_input(),
|
||||
instructions=_PYTHON_TOOL_INSTRUCTION,
|
||||
tools=self._mcp_tools_payload(),
|
||||
temperature=0.0,
|
||||
extra_body={"enable_response_messages": True},
|
||||
)
|
||||
assert response is not None
|
||||
|
||||
assert response.status == "completed"
|
||||
# Verify output messages: Tool calls and responses on analysis channel
|
||||
log_response_diagnostics(response, label="MCP Enabled")
|
||||
|
||||
tool_call_found = False
|
||||
tool_response_found = False
|
||||
for message in response.output_messages:
|
||||
recipient = message.get("recipient")
|
||||
if recipient and recipient.startswith("python"):
|
||||
tool_call_found = True
|
||||
assert message.get("channel") == "analysis", (
|
||||
"Tool call should be on analysis channel"
|
||||
)
|
||||
assert message.get("channel") == "analysis"
|
||||
author = message.get("author", {})
|
||||
if (
|
||||
author.get("role") == "tool"
|
||||
and author.get("name")
|
||||
and author.get("name").startswith("python")
|
||||
if author.get("role") == "tool" and (author.get("name") or "").startswith(
|
||||
"python"
|
||||
):
|
||||
tool_response_found = True
|
||||
assert message.get("channel") == "analysis", (
|
||||
"Tool response should be on analysis channel"
|
||||
)
|
||||
assert message.get("channel") == "analysis"
|
||||
|
||||
assert tool_call_found, "Should have found at least one Python tool call"
|
||||
assert tool_response_found, (
|
||||
"Should have found at least one Python tool response"
|
||||
assert tool_call_found, (
|
||||
f"No Python tool call found. "
|
||||
f"Output types: "
|
||||
f"{[getattr(o, 'type', None) for o in response.output]}"
|
||||
)
|
||||
for message in response.input_messages:
|
||||
assert message.get("author").get("role") != "developer", (
|
||||
"No developer messages should be present with valid mcp tool"
|
||||
)
|
||||
assert tool_response_found, "No Python tool response found"
|
||||
|
||||
for message in response.input_messages:
|
||||
assert message.get("author", {}).get("role") != "developer"
|
||||
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_mcp_tool_with_allowed_tools_star(
|
||||
self, mcp_enabled_client: OpenAI, model_name: str
|
||||
self, client: OpenAI, model_name: str
|
||||
):
|
||||
"""Test MCP tool with allowed_tools=['*'] to select all available
|
||||
tools.
|
||||
|
||||
This E2E test verifies that the "*" wildcard works end-to-end.
|
||||
See test_serving_responses.py for detailed unit tests of "*"
|
||||
normalization.
|
||||
"""
|
||||
response = await mcp_enabled_client.responses.create(
|
||||
response = await retry_for_tool_call(
|
||||
client,
|
||||
model=model_name,
|
||||
input=(
|
||||
"Execute the following code: "
|
||||
"import random; print(random.randint(1, 1000000))"
|
||||
),
|
||||
instructions=(
|
||||
"You must use the Python tool to execute code. "
|
||||
"Never simulate execution."
|
||||
),
|
||||
tools=[
|
||||
{
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
"server_url": "http://localhost:8888",
|
||||
# Using "*" to allow all tools from this MCP server
|
||||
"allowed_tools": ["*"],
|
||||
}
|
||||
],
|
||||
expected_tool_type="mcp_call",
|
||||
input=self._python_exec_input(),
|
||||
instructions=_PYTHON_TOOL_INSTRUCTION,
|
||||
tools=self._mcp_tools_payload(allowed_tools=["*"]),
|
||||
temperature=0.0,
|
||||
extra_body={"enable_response_messages": True},
|
||||
)
|
||||
assert response is not None
|
||||
|
||||
assert response.status == "completed"
|
||||
# Verify tool calls work with allowed_tools=["*"]
|
||||
tool_call_found = False
|
||||
for message in response.output_messages:
|
||||
recipient = message.get("recipient")
|
||||
if recipient and recipient.startswith("python"):
|
||||
tool_call_found = True
|
||||
break
|
||||
log_response_diagnostics(response, label="MCP Allowed Tools *")
|
||||
|
||||
tool_call_found = any(
|
||||
(msg.get("recipient") or "").startswith("python")
|
||||
for msg in response.output_messages
|
||||
)
|
||||
assert tool_call_found, (
|
||||
"Should have found at least one Python tool call with '*'"
|
||||
f"No Python tool call with '*'. "
|
||||
f"Output types: "
|
||||
f"{[getattr(o, 'type', None) for o in response.output]}"
|
||||
)
|
||||
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_mcp_tool_calling_streaming_types(
|
||||
self,
|
||||
pairs_of_event_types: dict[str, str],
|
||||
mcp_enabled_client: OpenAI,
|
||||
client: OpenAI,
|
||||
model_name: str,
|
||||
):
|
||||
tools = [
|
||||
{
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
}
|
||||
]
|
||||
input_text = "What is 123 * 456? Use python to calculate the result."
|
||||
def _has_mcp_events(events: list) -> bool:
|
||||
return events_contain_type(events, "mcp_call")
|
||||
|
||||
stream_response = await mcp_enabled_client.responses.create(
|
||||
events = await retry_streaming_for(
|
||||
client,
|
||||
model=model_name,
|
||||
input=input_text,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
instructions=(
|
||||
"You must use the Python tool to execute code. "
|
||||
"Never simulate execution."
|
||||
),
|
||||
validate_events=_has_mcp_events,
|
||||
input=("What is 123 * 456? Use Python to calculate the result."),
|
||||
tools=[{"type": "mcp", "server_label": "code_interpreter"}],
|
||||
instructions=_PYTHON_TOOL_INSTRUCTION,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
stack_of_event_types = []
|
||||
saw_mcp_type = False
|
||||
async for event in stream_response:
|
||||
if event.type == "response.created":
|
||||
stack_of_event_types.append(event.type)
|
||||
elif event.type == "response.completed":
|
||||
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
|
||||
stack_of_event_types.pop()
|
||||
elif (
|
||||
event.type.endswith("added")
|
||||
or event.type == "response.mcp_call.in_progress"
|
||||
):
|
||||
stack_of_event_types.append(event.type)
|
||||
elif event.type.endswith("delta"):
|
||||
if stack_of_event_types[-1] == event.type:
|
||||
continue
|
||||
stack_of_event_types.append(event.type)
|
||||
elif (
|
||||
event.type.endswith("done")
|
||||
or event.type == "response.mcp_call.completed"
|
||||
):
|
||||
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
|
||||
if "mcp_call" in event.type:
|
||||
saw_mcp_type = True
|
||||
stack_of_event_types.pop()
|
||||
validate_streaming_event_stack(events, pairs_of_event_types)
|
||||
|
||||
assert len(stack_of_event_types) == 0
|
||||
assert saw_mcp_type, "Should have seen at least one mcp call"
|
||||
assert events_contain_type(events, "mcp_call"), (
|
||||
f"No mcp_call events after retries. "
|
||||
f"Event types: {sorted({e.type for e in events})}"
|
||||
)
|
||||
|
||||
|
||||
class TestMCPDisabled:
|
||||
"""Tests that verify behavior when MCP tools are disabled."""
|
||||
"""Tests that MCP tools are not executed when the env flag is unset."""
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def monkeypatch_class(self):
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
mpatch = MonkeyPatch()
|
||||
yield mpatch
|
||||
mpatch.undo()
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def mcp_disabled_server(self, monkeypatch_class: pytest.MonkeyPatch):
|
||||
args = ["--enforce-eager", "--tool-server", "demo"]
|
||||
|
||||
with monkeypatch_class.context() as m:
|
||||
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
|
||||
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
|
||||
# Helps the model follow instructions better
|
||||
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
def mcp_disabled_server(self):
|
||||
env_dict = {
|
||||
**BASE_TEST_ENV,
|
||||
"VLLM_ENABLE_RESPONSES_API_STORE": "1",
|
||||
"PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
|
||||
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": "1",
|
||||
}
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME, list(_BASE_SERVER_ARGS), env_dict=env_dict
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mcp_disabled_client(self, mcp_disabled_server):
|
||||
async def client(self, mcp_disabled_server):
|
||||
async with mcp_disabled_server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_mcp_tool_env_flag_disabled(
|
||||
self, mcp_disabled_client: OpenAI, model_name: str
|
||||
async def test_mcp_disabled_server_does_not_execute(
|
||||
self, client: OpenAI, model_name: str
|
||||
):
|
||||
response = await mcp_disabled_client.responses.create(
|
||||
"""When MCP is disabled the model may still attempt tool calls
|
||||
(tool descriptions can remain in the prompt), but the server
|
||||
must NOT execute them."""
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input=(
|
||||
"Execute the following code if the tool is present: "
|
||||
@@ -308,38 +287,35 @@ class TestMCPDisabled:
|
||||
{
|
||||
"type": "mcp",
|
||||
"server_label": "code_interpreter",
|
||||
# URL unused for DemoToolServer
|
||||
"server_url": "http://localhost:8888",
|
||||
}
|
||||
],
|
||||
temperature=0.0,
|
||||
extra_body={"enable_response_messages": True},
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
# Verify output messages: No tool calls and responses
|
||||
tool_call_found = False
|
||||
tool_response_found = False
|
||||
|
||||
log_response_diagnostics(response, label="MCP Disabled")
|
||||
|
||||
# Server must not have executed any tool calls
|
||||
for message in response.output_messages:
|
||||
recipient = message.get("recipient")
|
||||
if recipient and recipient.startswith("python"):
|
||||
tool_call_found = True
|
||||
assert message.get("channel") == "analysis", (
|
||||
"Tool call should be on analysis channel"
|
||||
)
|
||||
author = message.get("author", {})
|
||||
if (
|
||||
assert not (
|
||||
author.get("role") == "tool"
|
||||
and author.get("name")
|
||||
and author.get("name").startswith("python")
|
||||
):
|
||||
tool_response_found = True
|
||||
assert message.get("channel") == "analysis", (
|
||||
"Tool response should be on analysis channel"
|
||||
and (author.get("name") or "").startswith("python")
|
||||
), (
|
||||
"Server executed a python tool call even though MCP is "
|
||||
f"disabled. Message: {message}"
|
||||
)
|
||||
|
||||
# No completed mcp_call output items
|
||||
for item in response.output:
|
||||
if getattr(item, "type", None) == "mcp_call":
|
||||
assert getattr(item, "status", None) != "completed", (
|
||||
"MCP call should not be completed when MCP is disabled"
|
||||
)
|
||||
|
||||
assert not tool_call_found, "Should not have a python call"
|
||||
assert not tool_response_found, "Should not have a tool response"
|
||||
# No developer messages injected
|
||||
for message in response.input_messages:
|
||||
assert message.get("author").get("role") != "developer", (
|
||||
"No developer messages should be present without a valid tool"
|
||||
)
|
||||
assert message.get("author", {}).get("role") != "developer"
|
||||
|
||||
@@ -3,15 +3,29 @@
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from openai import OpenAI
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
from .conftest import (
|
||||
BASE_TEST_ENV,
|
||||
has_output_type,
|
||||
log_response_diagnostics,
|
||||
retry_for_tool_call,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen3-8B"
|
||||
|
||||
_PYTHON_TOOL_INSTRUCTION = (
|
||||
"You must use the Python tool to execute code. "
|
||||
"Never simulate execution. You must print the final answer."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
@@ -32,12 +46,12 @@ def server():
|
||||
"--tool-server",
|
||||
"demo",
|
||||
]
|
||||
env_dict = dict(
|
||||
VLLM_ENABLE_RESPONSES_API_STORE="1",
|
||||
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1",
|
||||
PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
|
||||
)
|
||||
|
||||
env_dict = {
|
||||
**BASE_TEST_ENV,
|
||||
"VLLM_ENABLE_RESPONSES_API_STORE": "1",
|
||||
"VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": "1",
|
||||
"PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
|
||||
}
|
||||
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
@@ -54,6 +68,7 @@ async def test_basic(client: OpenAI, model_name: str):
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input="What is 123 * 456?",
|
||||
temperature=0.0,
|
||||
)
|
||||
assert response is not None
|
||||
print("response: ", response)
|
||||
@@ -99,10 +114,15 @@ async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
# make sure we get a reasoning and text output
|
||||
assert response.output[0].type == "reasoning"
|
||||
assert response.output[1].type == "message"
|
||||
assert type(response.output[1].content[0].text) is str
|
||||
|
||||
output_types = [getattr(o, "type", None) for o in response.output]
|
||||
assert "reasoning" in output_types, (
|
||||
f"Expected reasoning in output, got: {output_types}"
|
||||
)
|
||||
assert "message" in output_types, f"Expected message in output, got: {output_types}"
|
||||
|
||||
msg = next(o for o in response.output if o.type == "message")
|
||||
assert type(msg.content[0].text) is str
|
||||
|
||||
|
||||
def get_horoscope(sign):
|
||||
@@ -110,10 +130,10 @@ def get_horoscope(sign):
|
||||
|
||||
|
||||
def call_function(name, args):
|
||||
logger.info("Calling function %s with args %s", name, args)
|
||||
if name == "get_horoscope":
|
||||
return get_horoscope(**args)
|
||||
else:
|
||||
raise ValueError(f"Unknown function: {name}")
|
||||
raise ValueError(f"Unknown function: {name}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -136,61 +156,111 @@ async def test_function_call_first_turn(client: OpenAI, model_name: str):
|
||||
}
|
||||
]
|
||||
|
||||
response = await client.responses.create(
|
||||
response = await retry_for_tool_call(
|
||||
client,
|
||||
model=model_name,
|
||||
expected_tool_type="function_call",
|
||||
input="What is the horoscope for Aquarius today?",
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
assert len(response.output) == 2
|
||||
assert response.output[0].type == "reasoning"
|
||||
assert response.output[1].type == "function_call"
|
||||
|
||||
function_call = response.output[1]
|
||||
output_types = [getattr(o, "type", None) for o in response.output]
|
||||
assert "reasoning" in output_types, (
|
||||
f"Expected reasoning in output, got: {output_types}"
|
||||
)
|
||||
assert has_output_type(response, "function_call"), (
|
||||
f"Expected function_call in output, got: {output_types}"
|
||||
)
|
||||
|
||||
function_call = next(o for o in response.output if o.type == "function_call")
|
||||
assert function_call.name == "get_horoscope"
|
||||
assert function_call.call_id is not None
|
||||
|
||||
args = json.loads(function_call.arguments)
|
||||
assert "sign" in args
|
||||
|
||||
# the multi turn function call is tested above in
|
||||
# test_reasoning_and_function_items
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_mcp_tool_call(client: OpenAI, model_name: str):
|
||||
response = await client.responses.create(
|
||||
"""MCP tool calling with code_interpreter.
|
||||
|
||||
The model may make one or more tool calls before producing a final
|
||||
message. We validate server invariants (mcp_call items have correct
|
||||
fields) with hard assertions. Output indices are never hardcoded
|
||||
since the model can produce multiple tool-call rounds.
|
||||
"""
|
||||
# MCP + container init + code execution can be slow
|
||||
client_with_timeout = client.with_options(timeout=client.timeout * 3)
|
||||
|
||||
response = await retry_for_tool_call(
|
||||
client_with_timeout,
|
||||
model=model_name,
|
||||
input="What is 123 * 456? Use python to calculate the result.",
|
||||
expected_tool_type="mcp_call",
|
||||
input=(
|
||||
"What is 123 * 456? Use python to calculate the result. "
|
||||
"Print the result with print()."
|
||||
),
|
||||
tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
|
||||
extra_body={"enable_response_messages": True},
|
||||
instructions=_PYTHON_TOOL_INSTRUCTION,
|
||||
temperature=0.0,
|
||||
extra_body={"enable_response_messages": True},
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.status == "completed"
|
||||
|
||||
# The model may produce multiple reasoning/mcp_call rounds before the
|
||||
# final message, so validate structurally rather than by exact index.
|
||||
output_types = [o.type for o in response.output]
|
||||
assert "reasoning" in output_types
|
||||
mcp_calls = [o for o in response.output if o.type == "mcp_call"]
|
||||
assert len(mcp_calls) >= 1
|
||||
assert type(mcp_calls[0].arguments) is str
|
||||
assert type(mcp_calls[0].output) is str
|
||||
output_types = [getattr(o, "type", None) for o in response.output]
|
||||
log_response_diagnostics(response, label="test_mcp_tool_call")
|
||||
|
||||
# The final output should be a message containing the correct answer
|
||||
assert response.output[-1].type == "message"
|
||||
assert any(s in response.output[-1].content[0].text for s in ("56088", "56,088"))
|
||||
assert response.status == "completed", (
|
||||
f"Response status={response.status} "
|
||||
f"(details={getattr(response, 'incomplete_details', None)}). "
|
||||
f"Output types: {output_types}."
|
||||
)
|
||||
|
||||
# Test raw input_messages / output_messages
|
||||
assert len(response.input_messages) == 1
|
||||
assert len(response.output_messages) >= 3
|
||||
assert "reasoning" in output_types, (
|
||||
f"Expected reasoning in output, got: {output_types}"
|
||||
)
|
||||
assert "mcp_call" in output_types, (
|
||||
f"Expected mcp_call in output, got: {output_types}"
|
||||
)
|
||||
|
||||
# Every mcp_call item must have well-typed fields
|
||||
for item in response.output:
|
||||
if getattr(item, "type", None) == "mcp_call":
|
||||
assert type(item.arguments) is str, (
|
||||
f"mcp_call.arguments should be str, got {type(item.arguments)}"
|
||||
)
|
||||
assert type(item.output) is str, (
|
||||
f"mcp_call.output should be str, got {type(item.output)}"
|
||||
)
|
||||
|
||||
# The model may make 1+ tool-call rounds but must still produce
|
||||
# a final message for a trivial calculation like 123 * 456.
|
||||
message_outputs = [
|
||||
o for o in response.output if getattr(o, "type", None) == "message"
|
||||
]
|
||||
assert message_outputs, (
|
||||
f"Model did not produce a final message. Output types: {output_types}"
|
||||
)
|
||||
|
||||
final_message = message_outputs[-1]
|
||||
assert any(s in final_message.content[0].text for s in ("56088", "56,088")), (
|
||||
f"Expected 56088 in final message, got: {final_message.content[0].text!r}"
|
||||
)
|
||||
|
||||
# Validate raw input_messages / output_messages
|
||||
assert len(response.input_messages) >= 1, "Expected at least 1 input message"
|
||||
assert len(response.output_messages) >= 1, "Expected at least 1 output message"
|
||||
assert any(
|
||||
s in response.output_messages[-1]["message"] for s in ("56088", "56,088")
|
||||
any(s in str(msg) for s in ("56088", "56,088"))
|
||||
for msg in response.output_messages
|
||||
), (
|
||||
f"Expected 56088 in at least one output_message, "
|
||||
f"got {len(response.output_messages)} messages"
|
||||
)
|
||||
|
||||
|
||||
@@ -202,6 +272,7 @@ async def test_max_tokens(client: OpenAI, model_name: str):
|
||||
input="What is the first paragraph of Moby Dick?",
|
||||
reasoning={"effort": "low"},
|
||||
max_output_tokens=30,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert response is not None
|
||||
assert response.status == "incomplete"
|
||||
|
||||
@@ -12,13 +12,15 @@ MODEL_NAME = "Qwen/Qwen3-8B"
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
|
||||
env_dict = dict(
|
||||
VLLM_ENABLE_RESPONSES_API_STORE="1",
|
||||
# uncomment for tool calling
|
||||
# PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
|
||||
)
|
||||
from .conftest import BASE_TEST_ENV
|
||||
|
||||
args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
|
||||
env_dict = {
|
||||
**BASE_TEST_ENV,
|
||||
"VLLM_ENABLE_RESPONSES_API_STORE": "1",
|
||||
# uncomment for tool calling
|
||||
# PYTHON_EXECUTION_BACKEND: "dangerously_use_uv",
|
||||
}
|
||||
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
153
tests/utils.py
153
tests/utils.py
@@ -128,6 +128,9 @@ class RemoteOpenAIServer:
|
||||
env=env,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
# Create a dedicated process group so we can kill
|
||||
# the entire tree (parent + EngineCore + workers) at once.
|
||||
start_new_session=True,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -189,6 +192,15 @@ class RemoteOpenAIServer:
|
||||
model_loader = get_model_loader(load_config)
|
||||
model_loader.download_model(model_config)
|
||||
|
||||
# Record GPU memory before server start so we know what
|
||||
# "released" looks like.
|
||||
self._pre_server_gpu_memory = self._get_gpu_memory_used()
|
||||
if self._pre_server_gpu_memory is not None:
|
||||
pre_gb = self._pre_server_gpu_memory / 1e9
|
||||
print(
|
||||
f"[RemoteOpenAIServer] GPU memory before server start: {pre_gb:.2f} GB"
|
||||
)
|
||||
|
||||
self._start_server(model, vllm_serve_args, env_dict)
|
||||
max_wait_seconds = max_wait_seconds or 360
|
||||
self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds)
|
||||
@@ -198,27 +210,69 @@ class RemoteOpenAIServer:
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pid = self.proc.pid
|
||||
# Graceful shutdown
|
||||
self.proc.terminate()
|
||||
|
||||
# Get the process group ID. Because we used
|
||||
# start_new_session=True the pgid equals the server's pid.
|
||||
try:
|
||||
pgid = os.getpgid(pid)
|
||||
except (ProcessLookupError, OSError):
|
||||
pgid = None
|
||||
|
||||
# Phase 1: graceful SIGTERM to the entire process group
|
||||
if pgid is not None:
|
||||
with contextlib.suppress(ProcessLookupError, OSError):
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
print(f"[RemoteOpenAIServer] Sent SIGTERM to process group {pgid}")
|
||||
else:
|
||||
self.proc.terminate()
|
||||
|
||||
try:
|
||||
self.proc.wait(timeout=15)
|
||||
print(f"[RemoteOpenAIServer] Server {pid} terminated gracefully")
|
||||
except subprocess.TimeoutExpired:
|
||||
# Phase 2: SIGKILL the entire process group
|
||||
print(
|
||||
f"[RemoteOpenAIServer] Server {pid} did not respond "
|
||||
"to SIGTERM, sending SIGKILL"
|
||||
"to SIGTERM, sending SIGKILL to process group"
|
||||
)
|
||||
self.proc.kill()
|
||||
if pgid is not None:
|
||||
with contextlib.suppress(ProcessLookupError, OSError):
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
else:
|
||||
self.proc.kill()
|
||||
|
||||
try:
|
||||
self.proc.wait(timeout=5)
|
||||
self.proc.wait(timeout=10)
|
||||
print(f"[RemoteOpenAIServer] Server {pid} killed")
|
||||
except subprocess.TimeoutExpired as err:
|
||||
raise RuntimeError(
|
||||
f"[RemoteOpenAIServer] Failed to kill server process {pid}"
|
||||
) from err
|
||||
# Wait for GPU memory to be released
|
||||
except subprocess.TimeoutExpired:
|
||||
# Phase 3: last resort - find and kill any orphaned children
|
||||
self._kill_orphaned_children(pid)
|
||||
|
||||
# Wait for GPU memory to actually be *freed*, not just
|
||||
# "stabilized at whatever level it's at".
|
||||
self._wait_for_gpu_memory_release()
|
||||
|
||||
def _kill_orphaned_children(self, parent_pid: int) -> None:
|
||||
"""Best-effort cleanup of any lingering child processes."""
|
||||
try:
|
||||
import psutil
|
||||
|
||||
parent = psutil.Process(parent_pid)
|
||||
children = parent.children(recursive=True)
|
||||
for child in children:
|
||||
print(
|
||||
f"[RemoteOpenAIServer] Killing orphaned child "
|
||||
f"pid={child.pid} name={child.name()}"
|
||||
)
|
||||
child.kill()
|
||||
psutil.wait_procs(children, timeout=5)
|
||||
except Exception as e:
|
||||
# psutil may not be installed, or processes already gone
|
||||
print(f"[RemoteOpenAIServer] Orphan cleanup failed: {e}")
|
||||
# Fallback: try to kill by pgid one more time
|
||||
with contextlib.suppress(ProcessLookupError, OSError):
|
||||
os.killpg(parent_pid, signal.SIGKILL)
|
||||
|
||||
def _get_gpu_memory_used(self) -> float | None:
|
||||
"""Get total GPU memory used across all visible devices in bytes."""
|
||||
try:
|
||||
@@ -244,10 +298,26 @@ class RemoteOpenAIServer:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _wait_for_gpu_memory_release(self, timeout: float = 30.0):
|
||||
"""Poll GPU memory until it stabilizes, indicating cleanup is complete."""
|
||||
def _wait_for_gpu_memory_release(self, timeout: float = 60.0):
|
||||
"""Wait for GPU memory to drop back toward pre-server levels.
|
||||
|
||||
Two-phase strategy:
|
||||
1. Try to wait for memory to return close to pre-server baseline.
|
||||
2. If that doesn't happen, fall back to waiting for stabilization
|
||||
and log a warning (the next server might still OOM).
|
||||
"""
|
||||
baseline = self._pre_server_gpu_memory
|
||||
if baseline is None:
|
||||
# Can't query GPU memory - nothing to do
|
||||
return
|
||||
|
||||
# Allow up to 2 GiB overhead above baseline for driver/context state
|
||||
# that may persist between server instances.
|
||||
headroom_bytes = 2 * 1024 * 1024 * 1024
|
||||
target = baseline + headroom_bytes
|
||||
|
||||
start = time.time()
|
||||
prev_used: float | None = None
|
||||
last_used: float | None = None
|
||||
stable_count = 0
|
||||
|
||||
while time.time() - start < timeout:
|
||||
@@ -256,26 +326,49 @@ class RemoteOpenAIServer:
|
||||
if used is None:
|
||||
return # Can't query, assume ok
|
||||
|
||||
if prev_used is not None and abs(used - prev_used) < 100 * 1024 * 1024:
|
||||
stable_count += 1
|
||||
if stable_count >= 3:
|
||||
used_gb = used / 1e9
|
||||
print(
|
||||
f"[RemoteOpenAIServer] GPU memory stabilized "
|
||||
f"at {used_gb:.2f} GB"
|
||||
)
|
||||
return
|
||||
else:
|
||||
stable_count = 0
|
||||
used_gb = used / 1e9
|
||||
target_gb = target / 1e9
|
||||
elapsed = time.time() - start
|
||||
|
||||
prev_used = used
|
||||
time.sleep(0.1)
|
||||
# Phase 1: memory dropped to near baseline - we're done.
|
||||
if used <= target:
|
||||
print(
|
||||
f"[RemoteOpenAIServer] GPU memory released to "
|
||||
f"{used_gb:.2f} GB (target: {target_gb:.2f} GB) "
|
||||
f"in {elapsed:.1f}s"
|
||||
)
|
||||
return
|
||||
|
||||
last_reading = prev_used / 1e9 if prev_used is not None else 0.0
|
||||
# Phase 2 (after 40s): fall back to stabilization check.
|
||||
# This handles cases where another process is using GPU memory
|
||||
# and we'll never reach baseline.
|
||||
if elapsed > 40.0 and last_used is not None:
|
||||
delta = abs(used - last_used)
|
||||
if delta < 200 * 1024 * 1024: # 200 MB
|
||||
stable_count += 1
|
||||
if stable_count >= 3:
|
||||
print(
|
||||
f"[RemoteOpenAIServer] WARNING: GPU memory "
|
||||
f"stabilized at {used_gb:.2f} GB "
|
||||
f"(target was {target_gb:.2f} GB). "
|
||||
f"Proceeding - next server may OOM."
|
||||
)
|
||||
return
|
||||
else:
|
||||
stable_count = 0
|
||||
|
||||
last_used = used
|
||||
time.sleep(1.0)
|
||||
|
||||
# Timeout - log clearly so CI failures are diagnosable
|
||||
final_used = self._get_gpu_memory_used()
|
||||
final_gb = final_used / 1e9 if final_used else 0.0
|
||||
raise RuntimeError(
|
||||
f"[RemoteOpenAIServer] GPU memory did not stabilize within {timeout}s. "
|
||||
f"Last reading: {last_reading:.2f} GB. "
|
||||
"Child processes may still be holding GPU memory."
|
||||
f"[RemoteOpenAIServer] GPU memory did not release within "
|
||||
f"{timeout}s. Current: {final_gb:.2f} GB, "
|
||||
f"target: {target / 1e9:.2f} GB, "
|
||||
f"baseline: {baseline / 1e9:.2f} GB. "
|
||||
f"Child processes may still be holding GPU memory."
|
||||
)
|
||||
|
||||
def _poll(self) -> int | None:
|
||||
|
||||
@@ -48,8 +48,11 @@ from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponseInputOutputItem,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
REASONING_EFFORT = {
|
||||
"high": ReasoningEffort.HIGH,
|
||||
"medium": ReasoningEffort.MEDIUM,
|
||||
@@ -62,20 +65,15 @@ _harmony_encoding = None
|
||||
# they are available and requested by the user.
|
||||
# Tool args are provided by MCP tool descriptions. Output
|
||||
# of the tools are stringified.
|
||||
MCP_BUILTIN_TOOLS: set[str] = {
|
||||
"web_search_preview",
|
||||
"code_interpreter",
|
||||
"container",
|
||||
}
|
||||
|
||||
# Mapping from built-in tool recipient names to their MCP server labels.
|
||||
# This ensures consistency between streaming and non-streaming responses.
|
||||
_BUILTIN_TOOL_TO_MCP_SERVER_LABEL: dict[str, str] = {
|
||||
"python": "code_interpreter",
|
||||
"browser": "web_search_preview",
|
||||
"container": "container",
|
||||
}
|
||||
|
||||
# Derive MCP_BUILTIN_TOOLS from the canonical mapping
|
||||
MCP_BUILTIN_TOOLS: set[str] = set(_BUILTIN_TOOL_TO_MCP_SERVER_LABEL.values())
|
||||
|
||||
|
||||
def has_custom_tools(tool_types: set[str]) -> bool:
|
||||
"""
|
||||
@@ -116,8 +114,11 @@ def get_system_message(
|
||||
REASONING_EFFORT[reasoning_effort]
|
||||
)
|
||||
if start_date is None:
|
||||
# NOTE(woosuk): This brings non-determinism in vLLM. Be careful.
|
||||
start_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
# NOTE(woosuk): This brings non-determinism in vLLM.
|
||||
# Set VLLM_SYSTEM_START_DATE to pin it.
|
||||
start_date = envs.VLLM_SYSTEM_START_DATE or datetime.datetime.now().strftime(
|
||||
"%Y-%m-%d"
|
||||
)
|
||||
sys_msg_content = sys_msg_content.with_conversation_start_date(start_date)
|
||||
if browser_description is not None:
|
||||
sys_msg_content = sys_msg_content.with_tools(browser_description)
|
||||
@@ -398,15 +399,60 @@ def parse_chat_input_to_harmony_message(
|
||||
|
||||
|
||||
def parse_input_to_harmony_message(chat_msg) -> list[Message]:
|
||||
"""
|
||||
Parse a message from request.previous_input_messages in the Responsees API to
|
||||
Harmony messages.
|
||||
"""Parse a message from request.previous_input_messages
|
||||
into Harmony messages.
|
||||
|
||||
Supports both OpenAI chat format ({"role": "..."}) and
|
||||
Harmony format ({"author": {"role": "..."}}).
|
||||
"""
|
||||
if not isinstance(chat_msg, dict):
|
||||
# Handle Pydantic models
|
||||
chat_msg = chat_msg.model_dump(exclude_none=True)
|
||||
|
||||
if "author" in chat_msg and isinstance(chat_msg.get("author"), dict):
|
||||
return [_parse_harmony_format_message(chat_msg)]
|
||||
|
||||
return _parse_chat_format_message(chat_msg)
|
||||
|
||||
|
||||
def _parse_harmony_format_message(chat_msg: dict) -> Message:
|
||||
"""Reconstruct a Message from Harmony-format dict,
|
||||
preserving channel, recipient, and content_type."""
|
||||
author_dict = chat_msg["author"]
|
||||
role = author_dict.get("role")
|
||||
name = author_dict.get("name")
|
||||
|
||||
raw_content = chat_msg.get("content", "")
|
||||
if isinstance(raw_content, list):
|
||||
# TODO: Support refusal and non-text content types.
|
||||
contents = [TextContent(text=c.get("text", "")) for c in raw_content]
|
||||
elif isinstance(raw_content, str):
|
||||
contents = [TextContent(text=raw_content)]
|
||||
else:
|
||||
contents = [TextContent(text="")]
|
||||
|
||||
if name:
|
||||
msg = Message.from_author_and_contents(Author.new(Role(role), name), contents)
|
||||
else:
|
||||
msg = Message.from_role_and_contents(Role(role), contents)
|
||||
|
||||
channel = chat_msg.get("channel")
|
||||
if channel:
|
||||
msg = msg.with_channel(channel)
|
||||
recipient = chat_msg.get("recipient")
|
||||
if recipient:
|
||||
msg = msg.with_recipient(recipient)
|
||||
content_type = chat_msg.get("content_type")
|
||||
if content_type:
|
||||
msg = msg.with_content_type(content_type)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
def _parse_chat_format_message(chat_msg: dict) -> list[Message]:
|
||||
"""Parse an OpenAI chat-format dict into Harmony messages."""
|
||||
role = chat_msg.get("role")
|
||||
if role is None:
|
||||
raise ValueError(f"Message has no 'role' key: {chat_msg}")
|
||||
|
||||
# Assistant message with tool calls
|
||||
tool_calls = chat_msg.get("tool_calls")
|
||||
@@ -426,15 +472,21 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
|
||||
# Tool role message (tool output)
|
||||
if role == "tool":
|
||||
name = chat_msg.get("name", "")
|
||||
if name and not name.startswith("functions."):
|
||||
name = f"functions.{name}"
|
||||
content = chat_msg.get("content", "") or ""
|
||||
content = flatten_chat_text_content(content)
|
||||
|
||||
msg = Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{name}"), content
|
||||
).with_channel("commentary")
|
||||
# NOTE: .with_recipient("assistant") is required on tool messages
|
||||
# to match parse_chat_input_to_harmony_message behavior and ensure
|
||||
# proper routing in the Harmony protocol.
|
||||
msg = (
|
||||
Message.from_author_and_content(Author.new(Role.TOOL, name), content)
|
||||
.with_channel("commentary")
|
||||
.with_recipient("assistant")
|
||||
)
|
||||
return [msg]
|
||||
|
||||
# Default: user/assistant/system messages with content
|
||||
# Default: user/assistant/system messages
|
||||
content = chat_msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
contents = [TextContent(text=content)]
|
||||
@@ -497,6 +549,10 @@ def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutput
|
||||
try:
|
||||
browser_call = json.loads(content.text)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Invalid JSON in browser tool call, using error placeholder: %s",
|
||||
content.text,
|
||||
)
|
||||
json_retry_output_message = (
|
||||
f"Invalid JSON args, caught and retried: {content.text}"
|
||||
)
|
||||
@@ -730,22 +786,7 @@ def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]:
|
||||
)
|
||||
]
|
||||
|
||||
if parser.current_channel == "commentary":
|
||||
return [
|
||||
ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(
|
||||
text=parser.current_content, type="reasoning_text"
|
||||
)
|
||||
],
|
||||
status=None,
|
||||
)
|
||||
]
|
||||
|
||||
if parser.current_channel == "analysis":
|
||||
if parser.current_channel in ("commentary", "analysis"):
|
||||
return [
|
||||
ResponseReasoningItem(
|
||||
id=f"rs_{random_uuid()}",
|
||||
|
||||
@@ -346,17 +346,17 @@ class ParsableContext(ConversationContext):
|
||||
self.parser.response_messages.extend(output)
|
||||
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
"""Return true if the last message is a MCP tool call"""
|
||||
"""Return true if the last message is a builtin tool call
|
||||
that the request has enabled."""
|
||||
last_message = self.parser.response_messages[-1]
|
||||
# TODO(qandrew): figure out which tools are MCP tools
|
||||
if last_message.type == "function_call": # noqa: SIM102
|
||||
if last_message.name in (
|
||||
"code_interpreter",
|
||||
"python",
|
||||
"web_search_preview",
|
||||
) or last_message.name.startswith("container"):
|
||||
return True
|
||||
|
||||
if last_message.type != "function_call":
|
||||
return False
|
||||
if last_message.name in ("code_interpreter", "python"):
|
||||
return "python" in self.available_tools
|
||||
if last_message.name == "web_search_preview":
|
||||
return "browser" in self.available_tools
|
||||
if last_message.name.startswith("container"):
|
||||
return "container" in self.available_tools
|
||||
return False
|
||||
|
||||
async def call_python_tool(
|
||||
@@ -665,11 +665,15 @@ class HarmonyContext(ConversationContext):
|
||||
def need_builtin_tool_call(self) -> bool:
|
||||
last_msg = self.messages[-1]
|
||||
recipient = last_msg.recipient
|
||||
return recipient is not None and (
|
||||
recipient.startswith("browser.")
|
||||
or recipient.startswith("python")
|
||||
or recipient.startswith("container.")
|
||||
)
|
||||
if recipient is None:
|
||||
return False
|
||||
if recipient.startswith("browser."):
|
||||
return "browser" in self.available_tools
|
||||
if recipient.startswith("python"):
|
||||
return "python" in self.available_tools
|
||||
if recipient.startswith("container."):
|
||||
return "container" in self.available_tools
|
||||
return False
|
||||
|
||||
async def call_tool(self) -> list[Message]:
|
||||
if not self.messages:
|
||||
|
||||
@@ -392,13 +392,27 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
max_model_len = self.model_config.max_model_len
|
||||
generators: list[AsyncGenerator[ConversationContext, None]] = []
|
||||
|
||||
# Only include builtin tools that the request actually asked for.
|
||||
# Without this filter, tools registered on the server (e.g. via
|
||||
# --tool-server demo) would be available for execution even when
|
||||
# the request didn't enable them.
|
||||
requested_tool_types = extract_tool_types(request.tools)
|
||||
builtin_tool_list: list[str] = []
|
||||
if self.tool_server is not None:
|
||||
if self.tool_server.has_tool("browser"):
|
||||
if (
|
||||
self.tool_server.has_tool("browser")
|
||||
and "web_search_preview" in requested_tool_types
|
||||
):
|
||||
builtin_tool_list.append("browser")
|
||||
if self.tool_server.has_tool("python"):
|
||||
if (
|
||||
self.tool_server.has_tool("python")
|
||||
and "code_interpreter" in requested_tool_types
|
||||
):
|
||||
builtin_tool_list.append("python")
|
||||
if self.tool_server.has_tool("container"):
|
||||
if (
|
||||
self.tool_server.has_tool("container")
|
||||
and "container" in requested_tool_types
|
||||
):
|
||||
builtin_tool_list.append("container")
|
||||
|
||||
if self.tool_server is not None:
|
||||
@@ -1049,9 +1063,15 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# FIXME(woosuk): Currently, request params like reasoning and
|
||||
# instructions are ignored.
|
||||
prev_msgs = self.msg_store[prev_response.id]
|
||||
# Remove the previous chain-of-thoughts if there is a new "final"
|
||||
# message. Note that this also removes these messages from the
|
||||
# msg_store.
|
||||
|
||||
# FIXME(woosuk): The slice-delete-reappend cycle below is
|
||||
# currently a no-op --- it removes messages then puts them all
|
||||
# back unfiltered. It may be intentionally deferred (see FIXME
|
||||
# above) or redundant if the Harmony encoder already strips
|
||||
# analysis messages at render time. If analysis messages need
|
||||
# to be dropped here, add a channel != "analysis" filter when
|
||||
# re-appending, similar to auto_drop_analysis_messages in
|
||||
# harmony_utils.py.
|
||||
if len(prev_msgs) > 0:
|
||||
last_msg = prev_msgs[-1]
|
||||
assert isinstance(last_msg, OpenAIHarmonyMessage)
|
||||
@@ -1072,7 +1092,11 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# Append the new input.
|
||||
# Responses API supports simple text inputs without chat format.
|
||||
if isinstance(request.input, str):
|
||||
messages.append(get_user_message(request.input))
|
||||
# Skip empty string input when previous_input_messages supplies
|
||||
# the full conversation history --- an empty trailing user message
|
||||
# confuses the model into thinking nothing was sent.
|
||||
if request.input or not request.previous_input_messages:
|
||||
messages.append(get_user_message(request.input))
|
||||
else:
|
||||
if prev_response is not None:
|
||||
prev_outputs = copy(prev_response.output)
|
||||
|
||||
@@ -209,6 +209,7 @@ if TYPE_CHECKING:
|
||||
VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set()
|
||||
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False
|
||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
||||
VLLM_SYSTEM_START_DATE: str | None = None
|
||||
VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False
|
||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
|
||||
@@ -1458,6 +1459,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool(
|
||||
int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))
|
||||
),
|
||||
# Pin the conversation start date injected into the Harmony system
|
||||
# message. When unset the current date is used, which introduces
|
||||
# non-determinism (different tokens -> different model behaviour at
|
||||
# temperature=0). Set to an ISO date string, e.g. "2023-09-12",
|
||||
# for reproducible inference or testing.
|
||||
"VLLM_SYSTEM_START_DATE": lambda: os.getenv("VLLM_SYSTEM_START_DATE", None),
|
||||
# Enable automatic retry when tool call JSON parsing fails
|
||||
# If enabled, returns an error message to the model to retry
|
||||
# If disabled (default), raises an exception and fails the request
|
||||
|
||||
Reference in New Issue
Block a user