[Refactor] Relocate entrypoint tests to match serving code structure (#37593)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
740
tests/entrypoints/openai/responses/test_responses_utils.py
Normal file
740
tests/entrypoints/openai/responses/test_responses_utils.py
Normal file
@@ -0,0 +1,740 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
|
||||
from openai.types.responses.response_function_tool_call_output_item import (
|
||||
ResponseFunctionToolCallOutputItem,
|
||||
)
|
||||
from openai.types.responses.response_output_message import ResponseOutputMessage
|
||||
from openai.types.responses.response_output_text import ResponseOutputText
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content,
|
||||
ResponseReasoningItem,
|
||||
Summary,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.constants import MCP_PREFIX
|
||||
from vllm.entrypoints.openai.responses.utils import (
|
||||
_construct_single_message_from_response_item,
|
||||
_maybe_combine_reasoning_and_tool_call,
|
||||
construct_chat_messages_with_tool_call,
|
||||
convert_tool_responses_to_completions_format,
|
||||
should_continue_final_message,
|
||||
)
|
||||
|
||||
|
||||
class TestResponsesUtils:
|
||||
"""Tests for convert_tool_responses_to_completions_format function."""
|
||||
|
||||
def test_convert_tool_responses_to_completions_format(self):
|
||||
"""Test basic conversion of a flat tool schema to nested format."""
|
||||
input_tool = {
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location", "unit"],
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_tool_responses_to_completions_format(input_tool)
|
||||
|
||||
assert result == {"type": "function", "function": input_tool}
|
||||
|
||||
def test_construct_chat_messages_with_tool_call(self):
|
||||
"""Test construction of chat messages with tool calls."""
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id="lol",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="Leroy Jenkins",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
mcp_tool_item = ResponseFunctionToolCall(
|
||||
id="mcp_123",
|
||||
call_id="call_123",
|
||||
type="function_call",
|
||||
status="completed",
|
||||
name="python",
|
||||
arguments='{"code": "123+456"}',
|
||||
)
|
||||
input_items = [reasoning_item, mcp_tool_item]
|
||||
messages = construct_chat_messages_with_tool_call(input_items)
|
||||
|
||||
assert len(messages) == 1
|
||||
message = messages[0]
|
||||
assert message["role"] == "assistant"
|
||||
assert message["reasoning"] == "Leroy Jenkins"
|
||||
assert message["tool_calls"][0]["id"] == "call_123"
|
||||
assert message["tool_calls"][0]["function"]["name"] == "python"
|
||||
assert (
|
||||
message["tool_calls"][0]["function"]["arguments"] == '{"code": "123+456"}'
|
||||
)
|
||||
|
||||
def test_construct_single_message_from_response_item(self):
|
||||
item = ResponseReasoningItem(
|
||||
id="lol",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="Leroy Jenkins",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
formatted_item = _construct_single_message_from_response_item(item)
|
||||
assert formatted_item["role"] == "assistant"
|
||||
assert formatted_item["reasoning"] == "Leroy Jenkins"
|
||||
|
||||
item = ResponseReasoningItem(
|
||||
id="lol",
|
||||
summary=[
|
||||
Summary(
|
||||
text='Hmm, the user has just started with a simple "Hello,"',
|
||||
type="summary_text",
|
||||
)
|
||||
],
|
||||
type="reasoning",
|
||||
content=None,
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
|
||||
formatted_item = _construct_single_message_from_response_item(item)
|
||||
assert formatted_item["role"] == "assistant"
|
||||
assert (
|
||||
formatted_item["reasoning"]
|
||||
== 'Hmm, the user has just started with a simple "Hello,"'
|
||||
)
|
||||
|
||||
tool_call_output = ResponseFunctionToolCallOutputItem(
|
||||
id="temp_id",
|
||||
type="function_call_output",
|
||||
call_id="temp",
|
||||
output="1234",
|
||||
status="completed",
|
||||
)
|
||||
formatted_item = _construct_single_message_from_response_item(tool_call_output)
|
||||
assert formatted_item["role"] == "tool"
|
||||
assert formatted_item["content"] == "1234"
|
||||
assert formatted_item["tool_call_id"] == "temp"
|
||||
|
||||
item = ResponseReasoningItem(
|
||||
id="lol",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=None,
|
||||
encrypted_content="TOP_SECRET_MESSAGE",
|
||||
status=None,
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
_construct_single_message_from_response_item(item)
|
||||
|
||||
output_item = ResponseOutputMessage(
|
||||
id="msg_bf585bbbe3d500e0",
|
||||
content=[
|
||||
ResponseOutputText(
|
||||
annotations=[],
|
||||
text="dongyi",
|
||||
type="output_text",
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
|
||||
formatted_item = _construct_single_message_from_response_item(output_item)
|
||||
assert formatted_item["role"] == "assistant"
|
||||
assert formatted_item["content"] == "dongyi"
|
||||
|
||||
|
||||
class TestReasoningItemContentPriority:
|
||||
"""Tests that content is prioritized over summary for reasoning items."""
|
||||
|
||||
def test_content_preferred_over_summary(self):
|
||||
"""When both content and summary are present, content should win."""
|
||||
item = ResponseReasoningItem(
|
||||
id="reasoning_1",
|
||||
summary=[
|
||||
Summary(
|
||||
text="This is a summary",
|
||||
type="summary_text",
|
||||
)
|
||||
],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="This is the actual content",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
formatted = _construct_single_message_from_response_item(item)
|
||||
assert formatted["reasoning"] == "This is the actual content"
|
||||
|
||||
def test_content_only(self):
|
||||
"""When only content is present (no summary), content is used."""
|
||||
item = ResponseReasoningItem(
|
||||
id="reasoning_2",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="Content without summary",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
formatted = _construct_single_message_from_response_item(item)
|
||||
assert formatted["reasoning"] == "Content without summary"
|
||||
|
||||
@patch("vllm.entrypoints.openai.responses.utils.logger")
|
||||
def test_summary_fallback_when_no_content(self, mock_logger):
|
||||
"""When content is absent, summary is used as fallback with warning."""
|
||||
item = ResponseReasoningItem(
|
||||
id="reasoning_3",
|
||||
summary=[
|
||||
Summary(
|
||||
text="Fallback summary text",
|
||||
type="summary_text",
|
||||
)
|
||||
],
|
||||
type="reasoning",
|
||||
content=None,
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
formatted = _construct_single_message_from_response_item(item)
|
||||
assert formatted["reasoning"] == "Fallback summary text"
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert (
|
||||
"summary text as reasoning content" in mock_logger.warning.call_args[0][0]
|
||||
)
|
||||
|
||||
@patch("vllm.entrypoints.openai.responses.utils.logger")
|
||||
def test_summary_fallback_when_content_empty(self, mock_logger):
|
||||
"""When content is an empty list, summary is used as fallback."""
|
||||
item = ResponseReasoningItem(
|
||||
id="reasoning_4",
|
||||
summary=[
|
||||
Summary(
|
||||
text="Summary when content empty",
|
||||
type="summary_text",
|
||||
)
|
||||
],
|
||||
type="reasoning",
|
||||
content=[],
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
formatted = _construct_single_message_from_response_item(item)
|
||||
assert formatted["reasoning"] == "Summary when content empty"
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert (
|
||||
"summary text as reasoning content" in mock_logger.warning.call_args[0][0]
|
||||
)
|
||||
|
||||
def test_neither_content_nor_summary(self):
|
||||
"""When neither content nor summary is present, reasoning is empty."""
|
||||
item = ResponseReasoningItem(
|
||||
id="reasoning_5",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=None,
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
formatted = _construct_single_message_from_response_item(item)
|
||||
assert formatted["reasoning"] == ""
|
||||
|
||||
def test_encrypted_content_raises(self):
|
||||
"""Encrypted content should still raise ValueError."""
|
||||
item = ResponseReasoningItem(
|
||||
id="reasoning_6",
|
||||
summary=[
|
||||
Summary(
|
||||
text="Some summary",
|
||||
type="summary_text",
|
||||
)
|
||||
],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="Some content",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content="ENCRYPTED",
|
||||
status=None,
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
_construct_single_message_from_response_item(item)
|
||||
|
||||
@patch("vllm.entrypoints.openai.responses.utils.logger")
|
||||
def test_summary_with_multiple_entries_uses_first(self, mock_logger):
|
||||
"""When multiple summary entries exist, the first one is used."""
|
||||
item = ResponseReasoningItem(
|
||||
id="reasoning_7",
|
||||
summary=[
|
||||
Summary(
|
||||
text="First summary",
|
||||
type="summary_text",
|
||||
),
|
||||
Summary(
|
||||
text="Second summary",
|
||||
type="summary_text",
|
||||
),
|
||||
],
|
||||
type="reasoning",
|
||||
content=None,
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
formatted = _construct_single_message_from_response_item(item)
|
||||
assert formatted["reasoning"] == "First summary"
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert (
|
||||
"summary text as reasoning content" in mock_logger.warning.call_args[0][0]
|
||||
)
|
||||
|
||||
@patch("vllm.entrypoints.openai.responses.utils.logger")
|
||||
def test_no_warning_when_content_used(self, mock_logger):
|
||||
"""No warning should be emitted when content is available."""
|
||||
item = ResponseReasoningItem(
|
||||
id="reasoning_8",
|
||||
summary=[
|
||||
Summary(
|
||||
text="Summary text",
|
||||
type="summary_text",
|
||||
)
|
||||
],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="Content text",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
_construct_single_message_from_response_item(item)
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
|
||||
class TestShouldContinueFinalMessage:
|
||||
"""Tests for should_continue_final_message function.
|
||||
|
||||
This function enables Anthropic-style partial message completion, where
|
||||
users can provide an incomplete assistant message and have the model
|
||||
continue from where it left off.
|
||||
"""
|
||||
|
||||
def test_string_input_returns_false(self):
|
||||
"""String input is always a user message, so should not continue."""
|
||||
assert should_continue_final_message("Hello, world!") is False
|
||||
|
||||
def test_empty_list_returns_false(self):
|
||||
"""Empty list should not continue."""
|
||||
assert should_continue_final_message([]) is False
|
||||
|
||||
def test_completed_message_returns_false(self):
|
||||
"""Completed message should not be continued."""
|
||||
output_item = ResponseOutputMessage(
|
||||
id="msg_123",
|
||||
content=[
|
||||
ResponseOutputText(
|
||||
annotations=[],
|
||||
text="The answer is 42.",
|
||||
type="output_text",
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
assert should_continue_final_message([output_item]) is False
|
||||
|
||||
def test_in_progress_message_returns_true(self):
|
||||
"""In-progress message should be continued.
|
||||
|
||||
This is the key use case for partial message completion.
|
||||
Example: The user provides "The best answer is (" and wants
|
||||
the model to continue from there.
|
||||
"""
|
||||
output_item = ResponseOutputMessage(
|
||||
id="msg_123",
|
||||
content=[
|
||||
ResponseOutputText(
|
||||
annotations=[],
|
||||
text="The best answer is (",
|
||||
type="output_text",
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="in_progress",
|
||||
type="message",
|
||||
)
|
||||
assert should_continue_final_message([output_item]) is True
|
||||
|
||||
def test_incomplete_message_returns_true(self):
|
||||
"""Incomplete message should be continued."""
|
||||
output_item = ResponseOutputMessage(
|
||||
id="msg_123",
|
||||
content=[
|
||||
ResponseOutputText(
|
||||
annotations=[],
|
||||
text="The answer",
|
||||
type="output_text",
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="incomplete",
|
||||
type="message",
|
||||
)
|
||||
assert should_continue_final_message([output_item]) is True
|
||||
|
||||
def test_in_progress_reasoning_returns_true(self):
|
||||
"""In-progress reasoning should be continued."""
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id="reasoning_123",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="Let me think about this...",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content=None,
|
||||
status="in_progress",
|
||||
)
|
||||
assert should_continue_final_message([reasoning_item]) is True
|
||||
|
||||
def test_incomplete_reasoning_returns_true(self):
|
||||
"""Incomplete reasoning should be continued."""
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id="reasoning_123",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="Let me think",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content=None,
|
||||
status="incomplete",
|
||||
)
|
||||
assert should_continue_final_message([reasoning_item]) is True
|
||||
|
||||
reasoning_item = {
|
||||
"id": "reasoning_123",
|
||||
"summary": [],
|
||||
"type": "reasoning",
|
||||
"content": [],
|
||||
"status": "incomplete",
|
||||
}
|
||||
assert should_continue_final_message([reasoning_item]) is True
|
||||
|
||||
def test_completed_reasoning_returns_false(self):
|
||||
"""Completed reasoning should not be continued."""
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id="reasoning_123",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="I have thought about this.",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content=None,
|
||||
status="completed",
|
||||
)
|
||||
assert should_continue_final_message([reasoning_item]) is False
|
||||
|
||||
def test_reasoning_with_none_status_returns_false(self):
|
||||
"""Reasoning with None status should not be continued."""
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id="reasoning_123",
|
||||
summary=[],
|
||||
type="reasoning",
|
||||
content=[
|
||||
Content(
|
||||
text="Some reasoning",
|
||||
type="reasoning_text",
|
||||
)
|
||||
],
|
||||
encrypted_content=None,
|
||||
status=None,
|
||||
)
|
||||
assert should_continue_final_message([reasoning_item]) is False
|
||||
|
||||
def test_only_last_item_matters(self):
|
||||
"""Only the last item in the list determines continuation."""
|
||||
completed_item = ResponseOutputMessage(
|
||||
id="msg_1",
|
||||
content=[
|
||||
ResponseOutputText(
|
||||
annotations=[],
|
||||
text="Complete message.",
|
||||
type="output_text",
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
in_progress_item = ResponseOutputMessage(
|
||||
id="msg_2",
|
||||
content=[
|
||||
ResponseOutputText(
|
||||
annotations=[],
|
||||
text="Partial message...",
|
||||
type="output_text",
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="in_progress",
|
||||
type="message",
|
||||
)
|
||||
|
||||
# In-progress as last item -> should continue
|
||||
assert should_continue_final_message([completed_item, in_progress_item]) is True
|
||||
|
||||
# Completed as last item -> should not continue
|
||||
assert (
|
||||
should_continue_final_message([in_progress_item, completed_item]) is False
|
||||
)
|
||||
|
||||
def test_tool_call_returns_false(self):
|
||||
"""Tool calls should not trigger continuation."""
|
||||
tool_call = ResponseFunctionToolCall(
|
||||
id="fc_123",
|
||||
call_id="call_123",
|
||||
type="function_call",
|
||||
status="in_progress",
|
||||
name="get_weather",
|
||||
arguments='{"location": "NYC"}',
|
||||
)
|
||||
assert should_continue_final_message([tool_call]) is False
|
||||
|
||||
tool_call = {
|
||||
"id": "msg_123",
|
||||
"call_id": "call_123",
|
||||
"type": "function_call",
|
||||
"status": "in_progress",
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "NYC"}',
|
||||
}
|
||||
assert should_continue_final_message([tool_call]) is False
|
||||
|
||||
# Tests for dict inputs (e.g., from curl requests)
|
||||
def test_dict_in_progress_message_returns_true(self):
|
||||
"""Dict with in_progress status should be continued (curl input)."""
|
||||
dict_item = {
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "in_progress",
|
||||
"content": [{"type": "output_text", "text": "The answer is ("}],
|
||||
}
|
||||
assert should_continue_final_message([dict_item]) is True
|
||||
|
||||
def test_dict_incomplete_message_returns_true(self):
|
||||
"""Dict with incomplete status should be continued (curl input)."""
|
||||
dict_item = {
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "incomplete",
|
||||
"content": [{"type": "output_text", "text": "Partial answer"}],
|
||||
}
|
||||
assert should_continue_final_message([dict_item]) is True
|
||||
|
||||
def test_dict_completed_message_returns_false(self):
|
||||
"""Dict with completed status should not be continued (curl input)."""
|
||||
dict_item = {
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": [{"type": "output_text", "text": "Complete answer."}],
|
||||
}
|
||||
assert should_continue_final_message([dict_item]) is False
|
||||
|
||||
def test_dict_reasoning_in_progress_returns_true(self):
|
||||
"""Dict reasoning item with in_progress status should be continued."""
|
||||
dict_item = {
|
||||
"id": "reasoning_123",
|
||||
"type": "reasoning",
|
||||
"status": "in_progress",
|
||||
"content": [{"type": "reasoning_text", "text": "Let me think..."}],
|
||||
}
|
||||
assert should_continue_final_message([dict_item]) is True
|
||||
|
||||
def test_dict_without_status_returns_false(self):
|
||||
"""Dict without status field should not be continued."""
|
||||
dict_item = {
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "Some text"}],
|
||||
}
|
||||
assert should_continue_final_message([dict_item]) is False
|
||||
|
||||
def test_dict_with_none_status_returns_false(self):
|
||||
"""Dict with None status should not be continued."""
|
||||
dict_item = {
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": None,
|
||||
"content": [{"type": "output_text", "text": "Some text"}],
|
||||
}
|
||||
assert should_continue_final_message([dict_item]) is False
|
||||
|
||||
|
||||
class TestMaybeCombineReasoningAndToolCall:
|
||||
"""Tests for _maybe_combine_reasoning_and_tool_call function."""
|
||||
|
||||
def test_returns_none_when_item_id_is_none(self):
|
||||
"""
|
||||
Test fix from PR #31999: when item.id is None, should return None
|
||||
instead of raising TypeError on startswith().
|
||||
"""
|
||||
item = ResponseFunctionToolCall(
|
||||
type="function_call",
|
||||
id=None, # This was causing TypeError before the fix
|
||||
call_id="call_123",
|
||||
name="test_function",
|
||||
arguments="{}",
|
||||
)
|
||||
messages: list[ChatCompletionMessageParam] = []
|
||||
|
||||
result = _maybe_combine_reasoning_and_tool_call(item, messages)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_id_does_not_start_with_mcp_prefix(self):
|
||||
"""Test that non-MCP tool calls are not combined."""
|
||||
item = ResponseFunctionToolCall(
|
||||
type="function_call",
|
||||
id="regular_id", # Does not start with MCP_PREFIX
|
||||
call_id="call_123",
|
||||
name="test_function",
|
||||
arguments="{}",
|
||||
)
|
||||
messages = [{"role": "assistant", "reasoning": "some reasoning"}]
|
||||
|
||||
result = _maybe_combine_reasoning_and_tool_call(item, messages)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_last_message_is_not_assistant(self):
|
||||
"""Test that non-assistant last message returns None."""
|
||||
item = ResponseFunctionToolCall(
|
||||
type="function_call",
|
||||
id=f"{MCP_PREFIX}tool_id",
|
||||
call_id="call_123",
|
||||
name="test_function",
|
||||
arguments="{}",
|
||||
)
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
|
||||
result = _maybe_combine_reasoning_and_tool_call(item, messages)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_last_message_has_no_reasoning(self):
|
||||
"""Test that assistant message without reasoning returns None."""
|
||||
item = ResponseFunctionToolCall(
|
||||
type="function_call",
|
||||
id=f"{MCP_PREFIX}tool_id",
|
||||
call_id="call_123",
|
||||
name="test_function",
|
||||
arguments="{}",
|
||||
)
|
||||
messages = [{"role": "assistant", "content": "some content"}]
|
||||
|
||||
result = _maybe_combine_reasoning_and_tool_call(item, messages)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_combines_reasoning_and_mcp_tool_call(self):
|
||||
"""Test successful combination of reasoning message and MCP tool call."""
|
||||
item = ResponseFunctionToolCall(
|
||||
type="function_call",
|
||||
id=f"{MCP_PREFIX}tool_id",
|
||||
call_id="call_123",
|
||||
name="test_function",
|
||||
arguments='{"arg": "value"}',
|
||||
)
|
||||
messages = [{"role": "assistant", "reasoning": "I need to call this tool"}]
|
||||
|
||||
result = _maybe_combine_reasoning_and_tool_call(item, messages)
|
||||
|
||||
assert result is not None
|
||||
assert result["role"] == "assistant"
|
||||
assert result["reasoning"] == "I need to call this tool"
|
||||
assert "tool_calls" in result
|
||||
assert len(result["tool_calls"]) == 1
|
||||
assert result["tool_calls"][0]["id"] == "call_123"
|
||||
assert result["tool_calls"][0]["function"]["name"] == "test_function"
|
||||
assert result["tool_calls"][0]["function"]["arguments"] == '{"arg": "value"}'
|
||||
assert result["tool_calls"][0]["type"] == "function"
|
||||
|
||||
def test_returns_none_for_non_function_tool_call_type(self):
|
||||
"""Test that non-ResponseFunctionToolCall items return None."""
|
||||
# Pass a dict instead of ResponseFunctionToolCall
|
||||
item = {"type": "message", "content": "hello"}
|
||||
messages = [{"role": "assistant", "reasoning": "some reasoning"}]
|
||||
|
||||
result = _maybe_combine_reasoning_and_tool_call(item, messages)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_when_id_is_empty_string(self):
|
||||
"""Test that empty string id returns None (falsy check)."""
|
||||
item = ResponseFunctionToolCall(
|
||||
type="function_call",
|
||||
id="", # Empty string is falsy
|
||||
call_id="call_123",
|
||||
name="test_function",
|
||||
arguments="{}",
|
||||
)
|
||||
messages = [{"role": "assistant", "reasoning": "some reasoning"}]
|
||||
|
||||
result = _maybe_combine_reasoning_and_tool_call(item, messages)
|
||||
|
||||
assert result is None
|
||||
@@ -1,223 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Embedding shape validation in multimodal APIs.
|
||||
|
||||
Tests verify that embeddings with correct ndim but incorrect hidden_size
|
||||
are rejected before they can cause crashes during model inference.
|
||||
|
||||
Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems
|
||||
classes, not by MediaIO classes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.parse import (
|
||||
AudioEmbeddingItems,
|
||||
ImageEmbeddingItems,
|
||||
MultiModalDataParser,
|
||||
VideoEmbeddingItems,
|
||||
)
|
||||
|
||||
|
||||
class TestMultiModalParserShapeValidation:
|
||||
"""Test hidden_size validation in MultiModalDataParser."""
|
||||
|
||||
def test_image_embeddings_correct_hidden_size_accepted(self):
|
||||
"""Baseline: Image embeddings with correct hidden_size should work."""
|
||||
expected_hidden_size = 768
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
valid_embeds = torch.randn(2, 100, expected_hidden_size)
|
||||
|
||||
result = parser.parse_mm_data({"image": valid_embeds})
|
||||
|
||||
assert "image" in result
|
||||
assert isinstance(result["image"], ImageEmbeddingItems)
|
||||
assert result["image"].get_count() == 2
|
||||
|
||||
def test_image_embeddings_wrong_hidden_size_rejected(self):
|
||||
"""Security: Image embeddings with wrong hidden_size should be rejected."""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 4096
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parser.parse_mm_data({"image": invalid_embeds})
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "image" in error_msg
|
||||
assert "hidden dimension mismatch" in error_msg
|
||||
|
||||
def test_audio_embeddings_wrong_hidden_size_rejected(self):
|
||||
"""Security: Audio embeddings with wrong hidden_size should be rejected."""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 2048
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parser.parse_mm_data({"audio": invalid_embeds})
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "audio" in error_msg
|
||||
assert "hidden dimension mismatch" in error_msg
|
||||
|
||||
def test_video_embeddings_wrong_hidden_size_rejected(self):
|
||||
"""Security: Video embeddings with wrong hidden_size should be rejected."""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 512
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parser.parse_mm_data({"video": invalid_embeds})
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "video" in error_msg
|
||||
assert "hidden dimension mismatch" in error_msg
|
||||
|
||||
def test_list_of_embeddings_validates_each(self):
|
||||
"""Security: Each embedding in list should be validated."""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 1024
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
# List with second tensor having wrong hidden_size
|
||||
invalid_embeds = [
|
||||
torch.randn(100, expected_hidden_size),
|
||||
torch.randn(100, wrong_hidden_size),
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
parser.parse_mm_data({"image": invalid_embeds})
|
||||
|
||||
# Should identify which embedding failed
|
||||
assert "[1]" in str(exc_info.value)
|
||||
|
||||
def test_validation_disabled_allows_any_size(self):
|
||||
"""When validation disabled (legacy), any hidden_size allowed."""
|
||||
parser = MultiModalDataParser(expected_hidden_size=None)
|
||||
|
||||
any_hidden_size = 12345
|
||||
embeds = torch.randn(2, 100, any_hidden_size)
|
||||
|
||||
# Should not raise
|
||||
result = parser.parse_mm_data({"image": embeds})
|
||||
assert "image" in result
|
||||
assert isinstance(result["image"], ImageEmbeddingItems)
|
||||
|
||||
|
||||
class TestEmbeddingItemsDirectValidation:
|
||||
"""Direct tests for EmbeddingItems hidden_size validation."""
|
||||
|
||||
def test_image_embedding_items_validates_batched_tensor(self):
|
||||
"""Test validation for batched (3D) image embeddings."""
|
||||
expected = 768
|
||||
wrong = 1024
|
||||
|
||||
# Valid
|
||||
valid = torch.randn(2, 100, expected)
|
||||
items = ImageEmbeddingItems(valid, expected_hidden_size=expected)
|
||||
assert items.get_count() == 2
|
||||
|
||||
# Invalid
|
||||
invalid = torch.randn(2, 100, wrong)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ImageEmbeddingItems(invalid, expected_hidden_size=expected)
|
||||
|
||||
assert str(wrong) in str(exc_info.value)
|
||||
assert str(expected) in str(exc_info.value)
|
||||
|
||||
def test_image_embedding_items_validates_list_of_tensors(self):
|
||||
"""Test validation for list of 2D image embeddings."""
|
||||
expected = 768
|
||||
wrong = 512
|
||||
|
||||
# Valid list
|
||||
valid_list = [torch.randn(100, expected), torch.randn(50, expected)]
|
||||
items = ImageEmbeddingItems(valid_list, expected_hidden_size=expected)
|
||||
assert items.get_count() == 2
|
||||
|
||||
# Invalid list
|
||||
invalid_list = [torch.randn(100, expected), torch.randn(50, wrong)]
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ImageEmbeddingItems(invalid_list, expected_hidden_size=expected)
|
||||
|
||||
assert "[1]" in str(exc_info.value)
|
||||
|
||||
def test_audio_embedding_items_validates(self):
|
||||
"""Test validation for audio embeddings."""
|
||||
expected = 768
|
||||
wrong = 256
|
||||
|
||||
invalid = torch.randn(2, 100, wrong)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AudioEmbeddingItems(invalid, expected_hidden_size=expected)
|
||||
|
||||
assert "audio" in str(exc_info.value).lower()
|
||||
|
||||
def test_video_embedding_items_validates(self):
|
||||
"""Test validation for video embeddings."""
|
||||
expected = 768
|
||||
wrong = 384
|
||||
|
||||
invalid = torch.randn(2, 100, wrong)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
VideoEmbeddingItems(invalid, expected_hidden_size=expected)
|
||||
|
||||
assert "video" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestShapeValidationIntegration:
|
||||
"""Integration tests verifying attack scenarios are blocked."""
|
||||
|
||||
def test_attack_scenario_multimodal_image(self):
|
||||
"""
|
||||
Simulate attack through Chat API with image embeddings.
|
||||
|
||||
Verifies validation occurs in multimodal parser path.
|
||||
"""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 4096
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
attack_tensor = torch.randn(1, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
parser.parse_mm_data({"image": attack_tensor})
|
||||
|
||||
def test_attack_scenario_multimodal_audio(self):
|
||||
"""
|
||||
Simulate attack through Chat API with audio embeddings.
|
||||
|
||||
Verifies validation occurs in multimodal parser path.
|
||||
"""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 2048
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
attack_tensor = torch.randn(1, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
parser.parse_mm_data({"audio": attack_tensor})
|
||||
|
||||
def test_attack_scenario_multimodal_video(self):
|
||||
"""
|
||||
Simulate attack through Chat API with video embeddings.
|
||||
|
||||
Verifies validation occurs in multimodal parser path.
|
||||
"""
|
||||
expected_hidden_size = 768
|
||||
wrong_hidden_size = 1024
|
||||
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
|
||||
|
||||
attack_tensor = torch.randn(1, 100, wrong_hidden_size)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
parser.parse_mm_data({"video": attack_tensor})
|
||||
@@ -1,193 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""E2E tests for render endpoints via `vllm launch` (GPU-less serving)."""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ...utils import RemoteLaunchRenderServer
|
||||
|
||||
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args: list[str] = []
|
||||
with RemoteLaunchRenderServer(MODEL_NAME, args, max_wait_seconds=120) as srv:
|
||||
yield srv
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with httpx.AsyncClient(
|
||||
base_url=server.url_for(""), timeout=30.0
|
||||
) as http_client:
|
||||
yield http_client
|
||||
|
||||
|
||||
# -- Chat Completion Render --
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_render_basic(client):
|
||||
response = await client.post(
|
||||
"/v1/chat/completions/render",
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Response should be a GenerateRequest dict
|
||||
assert isinstance(data, dict)
|
||||
assert "token_ids" in data
|
||||
assert isinstance(data["token_ids"], list)
|
||||
assert len(data["token_ids"]) > 0
|
||||
assert all(isinstance(t, int) for t in data["token_ids"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_render_multi_turn(client):
|
||||
response = await client.post(
|
||||
"/v1/chat/completions/render",
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert isinstance(data, dict)
|
||||
assert "token_ids" in data
|
||||
assert isinstance(data["token_ids"], list)
|
||||
assert len(data["token_ids"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_render_invalid_model(client):
|
||||
response = await client.post(
|
||||
"/v1/chat/completions/render",
|
||||
json={
|
||||
"model": "nonexistent-model",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "error" in response.json()
|
||||
|
||||
|
||||
# -- Completion Render --
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_render_basic(client):
|
||||
response = await client.post(
|
||||
"/v1/completions/render",
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"prompt": "Once upon a time",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert isinstance(data, list)
|
||||
assert len(data) > 0
|
||||
|
||||
first_prompt = data[0]
|
||||
assert "token_ids" in first_prompt
|
||||
assert "sampling_params" in first_prompt
|
||||
assert "model" in first_prompt
|
||||
assert "request_id" in first_prompt
|
||||
assert isinstance(first_prompt["token_ids"], list)
|
||||
assert len(first_prompt["token_ids"]) > 0
|
||||
assert first_prompt["request_id"].startswith("cmpl-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_render_multiple_prompts(client):
|
||||
response = await client.post(
|
||||
"/v1/completions/render",
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"prompt": ["Hello world", "Goodbye world"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
|
||||
for prompt in data:
|
||||
assert "token_ids" in prompt
|
||||
assert "sampling_params" in prompt
|
||||
assert "model" in prompt
|
||||
assert "request_id" in prompt
|
||||
assert len(prompt["token_ids"]) > 0
|
||||
assert prompt["request_id"].startswith("cmpl-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_render_invalid_model(client):
|
||||
response = await client.post(
|
||||
"/v1/completions/render",
|
||||
json={
|
||||
"model": "nonexistent-model",
|
||||
"prompt": "Hello",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "error" in response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_render_is_fast(client):
|
||||
"""Render should complete quickly since there is no inference."""
|
||||
import time
|
||||
|
||||
start = time.perf_counter()
|
||||
response = await client.post(
|
||||
"/v1/completions/render",
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"prompt": "Tell me a very long story about " * 10,
|
||||
},
|
||||
)
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
assert response.status_code == 200
|
||||
assert elapsed < 2.0
|
||||
|
||||
|
||||
# -- Health & Models --
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint(client):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_endpoint(client):
|
||||
response = await client.get("/v1/models")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
model_ids = [m["id"] for m in data["data"]]
|
||||
assert MODEL_NAME in model_ids
|
||||
@@ -1,370 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import shutil
|
||||
from contextlib import suppress
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
|
||||
|
||||
BADREQUEST_CASES = [
|
||||
(
|
||||
"test_rank",
|
||||
{"r": 1024},
|
||||
"is greater than max_lora_rank",
|
||||
),
|
||||
("test_dora", {"use_dora": True}, "does not yet support DoRA"),
|
||||
(
|
||||
"test_modules_to_save",
|
||||
{"modules_to_save": ["lm_head"]},
|
||||
"only supports modules_to_save being None",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[True])
|
||||
def server_with_lora_modules_json(request, qwen3_lora_files):
|
||||
# Define the json format LoRA module configurations
|
||||
lora_module_1 = {
|
||||
"name": "qwen3-lora",
|
||||
"path": qwen3_lora_files,
|
||||
"base_model_name": MODEL_NAME,
|
||||
}
|
||||
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--enforce-eager",
|
||||
# lora config below
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
json.dumps(lora_module_1),
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
"--max-num-seqs",
|
||||
"64",
|
||||
]
|
||||
|
||||
# Enable the /v1/load_lora_adapter endpoint
|
||||
envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server_with_lora_modules_json):
|
||||
async with server_with_lora_modules_json.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_lora_lineage(client: openai.AsyncOpenAI, qwen3_lora_files):
|
||||
models = await client.models.list()
|
||||
models = models.data
|
||||
served_model = models[0]
|
||||
lora_models = models[1:]
|
||||
assert served_model.id == MODEL_NAME
|
||||
assert served_model.root == MODEL_NAME
|
||||
assert served_model.parent is None
|
||||
assert all(lora_model.root == qwen3_lora_files for lora_model in lora_models)
|
||||
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
|
||||
assert lora_models[0].id == "qwen3-lora"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, qwen3_lora_files):
|
||||
response = await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "qwen3-lora-3", "lora_path": qwen3_lora_files},
|
||||
)
|
||||
# Ensure adapter loads before querying /models
|
||||
assert "success" in response
|
||||
|
||||
models = await client.models.list()
|
||||
models = models.data
|
||||
dynamic_lora_model = models[-1]
|
||||
assert dynamic_lora_model.root == qwen3_lora_files
|
||||
assert dynamic_lora_model.parent == MODEL_NAME
|
||||
assert dynamic_lora_model.id == "qwen3-lora-3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_lora_adapter_with_same_name_replaces_inplace(
|
||||
client: openai.AsyncOpenAI, qwen3_meowing_lora_files, qwen3_woofing_lora_files
|
||||
):
|
||||
"""Test that loading a LoRA adapter with the same name replaces it inplace."""
|
||||
adapter_name = "replaceable-adapter"
|
||||
messages = [
|
||||
{"content": "Follow the instructions to make animal noises", "role": "system"},
|
||||
{"content": "Make your favorite animal noise.", "role": "user"},
|
||||
]
|
||||
|
||||
# Load LoRA that makes model meow
|
||||
response = await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": adapter_name, "lora_path": qwen3_meowing_lora_files},
|
||||
)
|
||||
assert "success" in response.lower()
|
||||
|
||||
completion = await client.chat.completions.create(
|
||||
model=adapter_name,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
)
|
||||
assert "Meow Meow Meow" in completion.choices[0].message.content
|
||||
|
||||
# Load LoRA that makes model woof
|
||||
response = await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": adapter_name,
|
||||
"lora_path": qwen3_woofing_lora_files,
|
||||
"load_inplace": True,
|
||||
},
|
||||
)
|
||||
assert "success" in response.lower()
|
||||
|
||||
completion = await client.chat.completions.create(
|
||||
model=adapter_name,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
)
|
||||
assert "Woof Woof Woof" in completion.choices[0].message.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_lora_adapter_with_load_inplace_false_errors(
|
||||
client: openai.AsyncOpenAI, qwen3_meowing_lora_files
|
||||
):
|
||||
"""Test that load_inplace=False returns an error when adapter already exists."""
|
||||
adapter_name = "test-load-inplace-false"
|
||||
|
||||
# Load LoRA adapter first time (should succeed)
|
||||
response = await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": adapter_name, "lora_path": qwen3_meowing_lora_files},
|
||||
)
|
||||
assert "success" in response.lower()
|
||||
|
||||
# Try to load the same adapter again with load_inplace=False (should fail)
|
||||
with pytest.raises(openai.BadRequestError) as exc_info:
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={
|
||||
"lora_name": adapter_name,
|
||||
"lora_path": qwen3_meowing_lora_files,
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the error message
|
||||
assert "already been loaded" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI):
|
||||
with pytest.raises(openai.NotFoundError):
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "notfound", "lora_path": "/not/an/adapter"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, tmp_path):
|
||||
invalid_files = tmp_path / "invalid_files"
|
||||
invalid_files.mkdir()
|
||||
(invalid_files / "adapter_config.json").write_text("this is not json")
|
||||
|
||||
with pytest.raises(openai.InternalServerError):
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "invalid-json", "lora_path": str(invalid_files)},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("test_name,config_change,expected_error", BADREQUEST_CASES)
|
||||
async def test_dynamic_lora_badrequests(
|
||||
client: openai.AsyncOpenAI,
|
||||
tmp_path,
|
||||
qwen3_lora_files,
|
||||
test_name: str,
|
||||
config_change: dict,
|
||||
expected_error: str,
|
||||
):
|
||||
# Create test directory
|
||||
test_dir = tmp_path / test_name
|
||||
|
||||
# Copy adapter files
|
||||
shutil.copytree(qwen3_lora_files, test_dir)
|
||||
|
||||
# Load and modify configuration
|
||||
config_path = test_dir / "adapter_config.json"
|
||||
with open(config_path) as f:
|
||||
adapter_config = json.load(f)
|
||||
# Apply configuration changes
|
||||
adapter_config.update(config_change)
|
||||
|
||||
# Save modified configuration
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(adapter_config, f)
|
||||
|
||||
# Test loading the adapter
|
||||
with pytest.raises(openai.InternalServerError, match=expected_error):
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": test_name, "lora_path": str(test_dir)},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_lora_adapters(
|
||||
client: openai.AsyncOpenAI, tmp_path, qwen3_lora_files
|
||||
):
|
||||
"""Validate that many loras can be dynamically registered and inferenced
|
||||
with concurrently"""
|
||||
|
||||
# This test file configures the server with --max-cpu-loras=2 and this test
|
||||
# will concurrently load 10 adapters, so it should flex the LRU cache
|
||||
async def load_and_run_adapter(adapter_name: str):
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": adapter_name, "lora_path": str(qwen3_lora_files)},
|
||||
)
|
||||
for _ in range(3):
|
||||
await client.completions.create(
|
||||
model=adapter_name,
|
||||
prompt=["Hello there", "Foo bar bazz buzz"],
|
||||
max_tokens=5,
|
||||
)
|
||||
|
||||
lora_tasks = []
|
||||
for i in range(10):
|
||||
lora_tasks.append(asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
|
||||
|
||||
results, _ = await asyncio.wait(lora_tasks)
|
||||
|
||||
for r in results:
|
||||
assert not isinstance(r, Exception), f"Got exception {r}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loading_invalid_adapters_does_not_break_others(
|
||||
client: openai.AsyncOpenAI, tmp_path, qwen3_lora_files
|
||||
):
|
||||
invalid_files = tmp_path / "invalid_files"
|
||||
invalid_files.mkdir()
|
||||
(invalid_files / "adapter_config.json").write_text("this is not json")
|
||||
|
||||
stop_good_requests_event = asyncio.Event()
|
||||
|
||||
async def run_good_requests(client):
|
||||
# Run chat completions requests until event set
|
||||
|
||||
results = []
|
||||
|
||||
while not stop_good_requests_event.is_set():
|
||||
try:
|
||||
batch = await client.completions.create(
|
||||
model="qwen3-lora",
|
||||
prompt=["Hello there", "Foo bar bazz buzz"],
|
||||
max_tokens=5,
|
||||
)
|
||||
results.append(batch)
|
||||
except Exception as e:
|
||||
results.append(e)
|
||||
|
||||
return results
|
||||
|
||||
# Create task to run good requests
|
||||
good_task = asyncio.create_task(run_good_requests(client))
|
||||
|
||||
# Run a bunch of bad adapter loads
|
||||
for _ in range(25):
|
||||
with suppress(openai.NotFoundError):
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "notfound", "lora_path": "/not/an/adapter"},
|
||||
)
|
||||
for _ in range(25):
|
||||
with suppress(openai.InternalServerError):
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "invalid", "lora_path": str(invalid_files)},
|
||||
)
|
||||
|
||||
# Ensure all the running requests with lora adapters succeeded
|
||||
stop_good_requests_event.set()
|
||||
results = await good_task
|
||||
for r in results:
|
||||
assert not isinstance(r, Exception), f"Got exception {r}"
|
||||
|
||||
# Ensure we can load another adapter and run it
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": "valid", "lora_path": qwen3_lora_files},
|
||||
)
|
||||
await client.completions.create(
|
||||
model="valid",
|
||||
prompt=["Hello there", "Foo bar bazz buzz"],
|
||||
max_tokens=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_beam_search_with_lora_adapters(
|
||||
client: openai.AsyncOpenAI,
|
||||
tmp_path,
|
||||
qwen3_lora_files,
|
||||
):
|
||||
"""Validate that async beam search can be used with lora."""
|
||||
|
||||
async def load_and_run_adapter(adapter_name: str):
|
||||
await client.post(
|
||||
"load_lora_adapter",
|
||||
cast_to=str,
|
||||
body={"lora_name": adapter_name, "lora_path": str(qwen3_lora_files)},
|
||||
)
|
||||
for _ in range(3):
|
||||
await client.completions.create(
|
||||
model=adapter_name,
|
||||
prompt=["Hello there", "Foo bar bazz buzz"],
|
||||
max_tokens=5,
|
||||
extra_body=dict(use_beam_search=True),
|
||||
)
|
||||
|
||||
lora_tasks = []
|
||||
for i in range(3):
|
||||
lora_tasks.append(asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))
|
||||
|
||||
results, _ = await asyncio.wait(lora_tasks)
|
||||
|
||||
for r in results:
|
||||
assert not isinstance(r, Exception), f"Got exception {r}"
|
||||
@@ -1,133 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.serve.lora.protocol import (
|
||||
LoadLoRAAdapterRequest,
|
||||
UnloadLoRAAdapterRequest,
|
||||
)
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
|
||||
LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added successfully."
|
||||
LORA_UNLOADING_SUCCESS_MESSAGE = (
|
||||
"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
)
|
||||
|
||||
|
||||
async def _async_serving_models_init() -> OpenAIServingModels:
|
||||
mock_engine_client = MagicMock(spec=EngineClient)
|
||||
# Set the max_model_len attribute to avoid missing attribute
|
||||
mock_model_config = MagicMock(spec=ModelConfig)
|
||||
mock_model_config.max_model_len = 2048
|
||||
mock_engine_client.model_config = mock_model_config
|
||||
mock_engine_client.input_processor = MagicMock()
|
||||
mock_engine_client.io_processor = MagicMock()
|
||||
mock_engine_client.renderer = MagicMock()
|
||||
|
||||
serving_models = OpenAIServingModels(
|
||||
engine_client=mock_engine_client,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
lora_modules=None,
|
||||
)
|
||||
await serving_models.init_static_loras()
|
||||
|
||||
return serving_models
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_model_name():
|
||||
serving_models = await _async_serving_models_init()
|
||||
assert serving_models.model_name(None) == MODEL_NAME
|
||||
request = LoRARequest(
|
||||
lora_name="adapter", lora_path="/path/to/adapter2", lora_int_id=1
|
||||
)
|
||||
assert serving_models.model_name(request) == request.lora_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_lora_adapter_success():
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoRAAdapterRequest(lora_name="adapter", lora_path="/path/to/adapter2")
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter")
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
assert "adapter" in serving_models.lora_requests
|
||||
assert serving_models.lora_requests["adapter"].lora_name == "adapter"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_lora_adapter_missing_fields():
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.type == "InvalidUserInput"
|
||||
assert response.error.code == HTTPStatus.BAD_REQUEST
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_lora_adapter_duplicate():
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoRAAdapterRequest(
|
||||
lora_name="adapter1", lora_path="/path/to/adapter1"
|
||||
)
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
|
||||
request = LoadLoRAAdapterRequest(
|
||||
lora_name="adapter1", lora_path="/path/to/adapter1"
|
||||
)
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.type == "InvalidUserInput"
|
||||
assert response.error.code == HTTPStatus.BAD_REQUEST
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unload_lora_adapter_success():
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = LoadLoRAAdapterRequest(
|
||||
lora_name="adapter1", lora_path="/path/to/adapter1"
|
||||
)
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
|
||||
request = UnloadLoRAAdapterRequest(lora_name="adapter1")
|
||||
response = await serving_models.unload_lora_adapter(request)
|
||||
assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(lora_name="adapter1")
|
||||
assert len(serving_models.lora_requests) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unload_lora_adapter_missing_fields():
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
|
||||
response = await serving_models.unload_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.type == "InvalidUserInput"
|
||||
assert response.error.code == HTTPStatus.BAD_REQUEST
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unload_lora_adapter_not_found():
|
||||
serving_models = await _async_serving_models_init()
|
||||
request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
|
||||
response = await serving_models.unload_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.type == "NotFoundError"
|
||||
assert response.error.code == HTTPStatus.NOT_FOUND
|
||||
@@ -1,347 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.v1.engine.detokenizer import check_stop_strings
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
GEN_ENDPOINT = "/inference/v1/generate"
|
||||
|
||||
|
||||
def get_vocab_size(model_name):
|
||||
config = ModelConfig(
|
||||
model=model_name,
|
||||
seed=0,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
return config.get_vocab_size()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tokenizer():
|
||||
return AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def messages():
|
||||
return [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "How many countries are in the EU?"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(request):
|
||||
args = [
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"1024",
|
||||
"--enforce-eager",
|
||||
# On ROCm (e.g. MI355X/gfx950), bf16 GEMM results can differ by
|
||||
# 1 ULP when the batch dimension (M) changes, because different M
|
||||
# values cause the Tensile backend to select different tile
|
||||
# configurations with different fp32 accumulation orders. With
|
||||
# prefix caching, cache-miss prefills compute all tokens in one
|
||||
# pass (large M) while cache-hit requests compute only the
|
||||
# uncached suffix (small M), seeding a divergence that amplifies
|
||||
# through the residual stream and flips argmax tokens.
|
||||
# See: https://github.com/vllm-project/vllm/issues/33123
|
||||
#
|
||||
# Either disable prefix caching entirely, or enable it with
|
||||
# --deterministic-prefix-caching which forces cache-miss prefills
|
||||
# to split at block boundaries so the suffix GEMM shape is always
|
||||
# identical regardless of cache state.
|
||||
#
|
||||
# Option A: disable prefix caching
|
||||
"--no-enable-prefix-caching",
|
||||
#
|
||||
# Option B: deterministic prefix caching
|
||||
# "--enable-prefix-caching",
|
||||
# "--deterministic-prefix-caching",
|
||||
]
|
||||
|
||||
extra_args = getattr(request, "param", None)
|
||||
if extra_args is not None:
|
||||
args = args + (
|
||||
list(extra_args)
|
||||
if isinstance(extra_args, (list, tuple))
|
||||
else [str(extra_args)]
|
||||
)
|
||||
|
||||
envs = os.environ.copy()
|
||||
# See: https://github.com/vllm-project/vllm/pull/33493#issuecomment-3888060787
|
||||
envs["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server: RemoteOpenAIServer):
|
||||
transport = httpx.AsyncHTTPTransport(uds=server.uds) if server.uds else None
|
||||
headers = {"Authorization": f"Bearer {server.DUMMY_API_KEY}"}
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport,
|
||||
base_url=server.url_root,
|
||||
timeout=600,
|
||||
headers=headers,
|
||||
) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_endpoint(client):
|
||||
payload = {
|
||||
"model": MODEL_NAME,
|
||||
"token_ids": [1, 2, 3],
|
||||
"sampling_params": {"max_tokens": 5},
|
||||
"stream": False,
|
||||
}
|
||||
resp = await client.post(GEN_ENDPOINT, json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
assert "choices" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("logprobs_value", [0, 1, 5])
|
||||
async def test_generate_logprobs(client, logprobs_value):
|
||||
payload = {
|
||||
"model": MODEL_NAME,
|
||||
"token_ids": [1, 2, 3],
|
||||
"sampling_params": {
|
||||
"max_tokens": 5,
|
||||
"temperature": 0.0,
|
||||
"logprobs": logprobs_value,
|
||||
},
|
||||
"stream": False,
|
||||
}
|
||||
resp = await client.post(GEN_ENDPOINT, json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
choice = data["choices"][0]
|
||||
assert choice["logprobs"] is not None
|
||||
logprobs_content = choice["logprobs"]["content"]
|
||||
assert len(logprobs_content) == len(choice["token_ids"])
|
||||
for entry in logprobs_content:
|
||||
assert "logprob" in entry
|
||||
assert len(entry["top_logprobs"]) >= 1
|
||||
assert len(entry["top_logprobs"]) == max(logprobs_value, 1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_response_as_chat_completions(client, tokenizer, messages):
|
||||
token_ids = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False, # default with Qwen3
|
||||
return_dict=True, # default with Transformers v5
|
||||
).input_ids
|
||||
|
||||
for ignore_eos in [True, False]:
|
||||
payload = {
|
||||
"model": MODEL_NAME,
|
||||
"token_ids": token_ids,
|
||||
"sampling_params": {
|
||||
"max_tokens": 24,
|
||||
"temperature": 0.0,
|
||||
# NOTE coordinator will set this to skip detokenization
|
||||
"detokenize": False,
|
||||
"ignore_eos": ignore_eos,
|
||||
},
|
||||
"stream": False,
|
||||
}
|
||||
generate_resp = await client.post(GEN_ENDPOINT, json=payload)
|
||||
generate_data = generate_resp.json()
|
||||
gen_token_ids = generate_data["choices"][0]["token_ids"]
|
||||
generate_res = tokenizer.decode(gen_token_ids, skip_special_tokens=True)
|
||||
|
||||
payload = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"max_tokens": 24,
|
||||
"temperature": 0.0,
|
||||
"stream": False,
|
||||
"ignore_eos": ignore_eos,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
}
|
||||
completions_resp = await client.post("/v1/chat/completions", json=payload)
|
||||
completions_data = completions_resp.json()
|
||||
completions_res = completions_data["choices"][0]["message"]["content"]
|
||||
|
||||
if ignore_eos:
|
||||
# When ignoring EOS, only compare up to the first EOS token
|
||||
# Post-EOS generation is undefined and may differ
|
||||
eos_tokens = {
|
||||
tokenizer.eos_token_id,
|
||||
*getattr_iter(
|
||||
tokenizer,
|
||||
[
|
||||
"extra_special_tokens_ids", # Transformers v5
|
||||
"additional_special_tokens_ids", # Transformers v4
|
||||
],
|
||||
[],
|
||||
),
|
||||
}
|
||||
# Find first EOS in generated tokens
|
||||
eos_pos = None
|
||||
for i, tid in enumerate(gen_token_ids):
|
||||
if tid in eos_tokens:
|
||||
eos_pos = i
|
||||
break
|
||||
if eos_pos is not None:
|
||||
gen_token_ids_truncated = gen_token_ids[:eos_pos]
|
||||
generate_res = tokenizer.decode(
|
||||
gen_token_ids_truncated, skip_special_tokens=True
|
||||
)
|
||||
# Truncate completions_res to same length for comparison
|
||||
completions_res = completions_res[: len(generate_res)]
|
||||
|
||||
assert generate_res == completions_res
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_string_workflow(client, tokenizer, messages):
|
||||
token_ids = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False, # default with Qwen3
|
||||
return_dict=True, # default with Transformers v5
|
||||
).input_ids
|
||||
payload = {
|
||||
"model": MODEL_NAME,
|
||||
"token_ids": token_ids,
|
||||
"sampling_params": {
|
||||
"max_tokens": 24,
|
||||
"temperature": 0.0,
|
||||
"detokenize": False,
|
||||
# stop strings are only supported when detokenize is True.
|
||||
"stop": ["27 member"],
|
||||
},
|
||||
# TODO stream test is much more interesting
|
||||
"stream": False,
|
||||
}
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
generate_resp = await client.post(GEN_ENDPOINT, json=payload)
|
||||
generate_resp.raise_for_status()
|
||||
|
||||
payload["sampling_params"]["stop"] = None
|
||||
generate_resp = await client.post(
|
||||
GEN_ENDPOINT, json=payload, headers={"X-Request-Id": "42"}
|
||||
)
|
||||
generate_data = generate_resp.json()
|
||||
generate_res = tokenizer.decode(
|
||||
generate_data["choices"][0]["token_ids"], skip_special_tokens=True
|
||||
)
|
||||
|
||||
# NOTE This is under the responsibility of the coordinator
|
||||
# stop_checker = StopChecker(
|
||||
# max_model_len=1024, get_tokenizer_for_seq=lambda _: tokenizer
|
||||
# )
|
||||
stop_str, truncate_to = check_stop_strings(
|
||||
generate_res, len(generate_res), ["27 member"], False
|
||||
)
|
||||
assert stop_str == "27 member"
|
||||
# abort request that hit stop string (requires tokens-only mode)
|
||||
# res = await client.post("/abort_requests", json={"request_ids": ["generate-tokens-42"]}) # noqa: E501
|
||||
# res.raise_for_status()
|
||||
generate_res = generate_res[:truncate_to]
|
||||
|
||||
# Get stop_str response from chat completions
|
||||
payload = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"max_tokens": 24,
|
||||
"temperature": 0.0,
|
||||
"stream": False,
|
||||
"stop": ["27 member"],
|
||||
"chat_template_kwargs": dict(enable_thinking=False),
|
||||
}
|
||||
completions_resp = await client.post("/v1/chat/completions", json=payload)
|
||||
completions_data = completions_resp.json()
|
||||
completions_res = completions_data["choices"][0]["message"]["content"]
|
||||
assert generate_res == completions_res
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"server",
|
||||
[
|
||||
[
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
"Alice=charent/self_cognition_Alice",
|
||||
"Bob=charent/self_cognition_Bob",
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
]
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_generate_with_lora_adapter(client, tokenizer, messages):
|
||||
# Verify adapters are listed
|
||||
models_resp = await client.get("/v1/models")
|
||||
models_resp.raise_for_status()
|
||||
models = {m["id"] for m in models_resp.json().get("data", [])}
|
||||
assert {"Alice", "Bob"}.issubset(models)
|
||||
|
||||
# Generate using a LoRA adapter by specifying its name as the model
|
||||
payload = {
|
||||
"model": "Alice",
|
||||
"token_ids": [1, 2, 3],
|
||||
"sampling_params": {"max_tokens": 5},
|
||||
"stream": False,
|
||||
}
|
||||
resp = await client.post(GEN_ENDPOINT, json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
assert "choices" in data
|
||||
|
||||
token_ids = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False, # default with Qwen3
|
||||
return_dict=True, # default with Transformers v5
|
||||
).input_ids
|
||||
payload = {
|
||||
"model": "Alice",
|
||||
"token_ids": token_ids,
|
||||
"sampling_params": {
|
||||
"max_tokens": 24,
|
||||
"temperature": 0.0,
|
||||
"detokenize": False,
|
||||
},
|
||||
"stream": False,
|
||||
}
|
||||
generate_resp = await client.post(GEN_ENDPOINT, json=payload)
|
||||
generate_data = generate_resp.json()
|
||||
generate_res = tokenizer.decode(
|
||||
generate_data["choices"][0]["token_ids"], skip_special_tokens=True
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": "Alice",
|
||||
"messages": messages,
|
||||
"max_tokens": 24,
|
||||
"temperature": 0.0,
|
||||
"stream": False,
|
||||
"chat_template_kwargs": dict(enable_thinking=False),
|
||||
}
|
||||
completions_resp = await client.post("/v1/chat/completions", json=payload)
|
||||
completions_data = completions_resp.json()
|
||||
completions_res = completions_data["choices"][0]["message"]["content"]
|
||||
|
||||
assert generate_res == completions_res
|
||||
Reference in New Issue
Block a user