# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
# Module-level logger used by the diagnostic helper in this file.
logger = logging.getLogger(__name__)

# Base environment settings for the tests.
BASE_TEST_ENV: dict[str, str] = {
    # The date vLLM first appeared on arxiv 🚀
    "VLLM_SYSTEM_START_DATE": "2023-09-12",
}

# Default attempt budget for the retry helpers defined in this module.
DEFAULT_MAX_RETRIES: int = 3
|
|
|
|
|
|
@pytest.fixture
def pairs_of_event_types() -> dict[str, str]:
    """Map each closing ('done') event type to the event type that opens it.

    The mapping is meant to cover every done <-> start pair. Tests that want
    to allow only a subset of events should filter the value returned by this
    fixture rather than re-declare the mapping; that avoids copy + paste
    mistakes and unexpected KeyErrors caused by missing events.
    """
    # fmt: off
    return {
        "response.completed": "response.created",
        "response.output_item.done": "response.output_item.added",
        "response.content_part.done": "response.content_part.added",
        "response.output_text.done": "response.output_text.delta",
        "response.reasoning_text.done": "response.reasoning_text.delta",
        "response.reasoning_part.done": "response.reasoning_part.added",
        "response.mcp_call_arguments.done": "response.mcp_call_arguments.delta",
        "response.mcp_call.completed": "response.mcp_call.in_progress",
        "response.function_call_arguments.done": "response.function_call_arguments.delta",  # noqa: E501
        "response.code_interpreter_call_code.done": "response.code_interpreter_call_code.delta",  # noqa: E501
        "response.web_search_call.completed": "response.web_search_call.in_progress",
    }
    # fmt: on
|
|
|
|
|
|
async def retry_for_tool_call(
    client,
    *,
    model: str,
    expected_tool_type: str,
    max_retries: int = DEFAULT_MAX_RETRIES,
    **create_kwargs: Any,
):
    """Retry ``client.responses.create`` until a tool item shows up.

    Makes at most *max_retries* calls, forwarding *model* and
    *create_kwargs* each time, and returns the first response whose
    ``output`` contains an item with ``type == expected_tool_type``.

    When no attempt matches, the response from the **final** attempt is
    returned instead, so that the caller's own assertions fail with a
    clear diagnostic.
    """
    latest = None
    for _ in range(max_retries):
        latest = await client.responses.create(model=model, **create_kwargs)
        # Stop at the first output item carrying the requested type.
        for item in latest.output:
            if getattr(item, "type", None) == expected_tool_type:
                return latest
    # Only reachable with ``latest`` unset when no attempt was made at all.
    assert latest is not None
    return latest
|
|
|
|
|
|
async def retry_streaming_for(
    client,
    *,
    model: str,
    validate_events: Callable[[list], bool],
    max_retries: int = DEFAULT_MAX_RETRIES,
    **create_kwargs: Any,
) -> list:
    """Retry a streaming ``client.responses.create`` call until it validates.

    Makes at most *max_retries* calls with ``stream=True``, draining each
    stream completely into a list of events. The first list for which
    *validate_events* returns a truthy value is returned. If none validates,
    the events of the final attempt are returned (an empty list when no
    attempt was made).
    """
    collected: list = []
    for _ in range(max_retries):
        stream = await client.responses.create(
            model=model, stream=True, **create_kwargs
        )
        # Consume the whole stream before judging it.
        collected = [event async for event in stream]
        if validate_events(collected):
            return collected
    return collected
|
|
|
|
|
|
def has_output_type(response, type_name: str) -> bool:
    """Tell whether ``response.output`` holds an item whose type is *type_name*.

    Items without a ``type`` attribute are treated as having type ``None``.
    """
    for item in response.output:
        if getattr(item, "type", None) == type_name:
            return True
    return False
|
|
|
|
|
|
def events_contain_type(events: list, type_substring: str) -> bool:
    """Return True if any event's type contains *type_substring*.

    An event with no ``type`` attribute is treated as having an empty type
    string. An event whose ``type`` is present but not a string (for example
    ``None``) is skipped rather than raising ``TypeError`` from the
    substring test.
    """
    # The getattr default only covers a *missing* attribute; a present-but-None
    # type still has to be filtered out before the ``in`` check.
    event_types = (getattr(event, "type", "") for event in events)
    return any(
        isinstance(event_type, str) and type_substring in event_type
        for event_type in event_types
    )
|
|
|
|
|
|
def validate_streaming_event_stack(
    events: list, pairs_of_event_types: dict[str, str]
) -> None:
    """Check that streamed events open and close in a properly nested order.

    Opening events (``response.created``, anything ending in ``added``,
    ``response.mcp_call.in_progress`` and ``delta`` events) are pushed onto a
    stack. A ``delta`` whose type is already on top of the stack is not
    pushed again. Each closing event (``response.completed``, anything
    ending in ``done``, ``response.mcp_call.completed``) must find its
    partner, looked up in *pairs_of_event_types*, on top of the stack.
    Event types matching none of these patterns are ignored.

    Fails with ``AssertionError`` on a mismatch, on a closing event missing
    from the mapping, or when events are still open at the end.
    """
    pending: list[str] = []
    for event in events:
        kind = event.type

        if kind == "response.completed":
            # Short-circuit: the mapping is only consulted on a non-empty stack.
            assert pending and pending[-1] == pairs_of_event_types[kind], (
                f"Unexpected stack top for {kind}: "
                f"got {pending[-1] if pending else '<empty>'}"
            )
            pending.pop()
            continue

        if (
            kind == "response.created"
            or kind.endswith("added")
            or kind == "response.mcp_call.in_progress"
        ):
            pending.append(kind)
            continue

        if kind.endswith("delta"):
            # Repeated deltas of the same type share a single stack entry.
            if not (pending and pending[-1] == kind):
                pending.append(kind)
            continue

        if kind.endswith("done") or kind == "response.mcp_call.completed":
            assert kind in pairs_of_event_types, f"Unknown done event: {kind}"
            opener = pairs_of_event_types[kind]
            assert pending and pending[-1] == opener, (
                f"Stack mismatch for {kind}: "
                f"expected {opener}, "
                f"got {pending[-1] if pending else '<empty>'}"
            )
            pending.pop()

    assert not pending, f"Unclosed events on stack: {pending}"
|
|
|
|
|
|
def log_response_diagnostics(
    response,
    *,
    label: str = "Response Diagnostics",
) -> dict[str, Any]:
    """Extract and log diagnostic info from a Responses API response.

    Logs reasoning, tool-call attempts, MCP items, and output types so
    that CI output (``pytest -s`` or ``--log-cli-level=INFO``) gives
    full visibility into model behaviour even on passing runs.

    The returned dict (also what gets logged, as JSON under *label*) has
    these keys:

    * ``reasoning`` - non-empty ``text`` values found in the ``content`` of
      output items of type ``"reasoning"``.
    * ``tool_call_attempts`` - ``recipient``/``channel`` of every entry in
      ``response.output_messages`` whose recipient starts with ``"python"``.
    * ``model_attempted_tool_calls`` - whether that list is non-empty.
    * ``mcp_items`` - ``name``/``status`` of output items of type
      ``"mcp_call"``.
    * ``output_text`` - ``response.output_text``.
    * ``output_types`` - the ``type`` of every output item.

    A reasoning item whose ``content`` is ``None`` and a response whose
    ``output_messages`` is ``None`` or absent are treated as empty, so this
    helper does not itself fail on such responses.

    Returns the extracted data so callers can make additional assertions
    if needed.
    """
    # A getattr default only covers a *missing* attribute; ``content`` may be
    # present but None, so normalise with ``or []`` before iterating.
    reasoning_texts = [
        text
        for item in response.output
        if getattr(item, "type", None) == "reasoning"
        for content in (getattr(item, "content", None) or [])
        if (text := getattr(content, "text", None))
    ]

    # ``output_messages`` may be None or absent; treat both as "no messages".
    output_messages = getattr(response, "output_messages", None) or []
    tool_call_attempts = [
        {
            "recipient": msg.get("recipient"),
            "channel": msg.get("channel"),
        }
        for msg in output_messages
        if (msg.get("recipient") or "").startswith("python")
    ]

    mcp_items = [
        {
            "name": getattr(item, "name", None),
            "status": getattr(item, "status", None),
        }
        for item in response.output
        if getattr(item, "type", None) == "mcp_call"
    ]

    output_types = [getattr(o, "type", None) for o in response.output]

    diagnostics = {
        "model_attempted_tool_calls": bool(tool_call_attempts),
        "tool_call_attempts": tool_call_attempts,
        "mcp_items": mcp_items,
        "reasoning": reasoning_texts,
        "output_text": response.output_text,
        "output_types": output_types,
    }

    # default=str keeps the dump from failing on non-JSON-native values.
    logger.info(
        "\n====== %s ======\n%s\n==============================",
        label,
        json.dumps(diagnostics, indent=2, default=str),
    )

    return diagnostics
|