373 lines
11 KiB
Python
373 lines
11 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from dataclasses import dataclass, field
|
|
from http import HTTPStatus
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from vllm.config.multimodal import MultiModalConfig
|
|
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
|
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
|
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
|
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
|
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
|
from vllm.outputs import CompletionOutput, RequestOutput
|
|
from vllm.renderers.hf import HfRenderer
|
|
from vllm.tokenizers.registry import tokenizer_args_from_config
|
|
from vllm.v1.engine.async_llm import AsyncLLM
|
|
|
|
# Model identifiers used throughout these tests: the full name and a short alias.
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"

# Register both names as served models; each one is also its own model path.
BASE_MODEL_PATHS = [
    BaseModelPath(name=alias, model_path=alias)
    for alias in (MODEL_NAME, MODEL_NAME_SHORT)
]
|
|
|
|
|
|
@dataclass
class MockHFConfig:
    """Minimal stand-in for a model's ``hf_config`` object.

    Only ``model_type`` is provided; it defaults to ``"any"``.
    """

    model_type: str = field(default="any")
|
|
|
|
|
|
@dataclass
class MockModelConfig:
    """Lightweight stand-in for the engine's model config in these tests.

    Names declared WITHOUT a type annotation are plain class attributes, not
    dataclass fields: they are not constructor arguments, and the objects bound
    to them (e.g. ``hf_config``, ``hf_text_config``, ``multimodal_config``) are
    single instances shared by every ``MockModelConfig``. Annotated names are
    dataclass fields and can be overridden through ``__init__``.
    """

    # ---- class attributes (shared across instances, not __init__ params) ----
    task = "generate"
    runner_type = "generate"
    model = MODEL_NAME
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    hf_text_config = MockHFConfig()
    encoder_config = None
    skip_tokenizer_init = False

    # ---- dataclass fields (constructor parameters, in this order) ----
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    generation_config: str = "auto"
    # default_factory gives each instance its own dict.
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    is_encoder_decoder: bool = False
    is_multimodal_model: bool = False

    def get_diff_sampling_param(self):
        """Return ``diff_sampling_param``, or a new empty dict when it is falsy."""
        if self.diff_sampling_param:
            return self.diff_sampling_param
        return {}
|
|
|
|
|
|
@dataclass
class MockVllmConfig:
    """Minimal wrapper that only exposes ``model_config``.

    Passed as the config argument to ``HfRenderer.from_config`` by the renderer
    builder in this module.
    """

    model_config: MockModelConfig
|
|
|
|
|
|
def _build_renderer(model_config: MockModelConfig):
    """Create an ``HfRenderer`` for ``model_config``.

    Of the four values returned by ``tokenizer_args_from_config``, only the
    second (tokenizer name) and fourth (extra kwargs) are used. The name is
    merged into a copy of the kwargs under the ``"tokenizer_name"`` key.
    """
    _, name_of_tokenizer, _, extra_kwargs = tokenizer_args_from_config(model_config)

    renderer_kwargs = dict(extra_kwargs)
    renderer_kwargs["tokenizer_name"] = name_of_tokenizer

    wrapped_config = MockVllmConfig(model_config)
    return HfRenderer.from_config(wrapped_config, tokenizer_kwargs=renderer_kwargs)
|
|
|
|
|
|
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    """Build an ``OpenAIServingChat`` around ``engine`` with preprocessing stubbed.

    ``_preprocess_chat`` is replaced by an ``AsyncMock`` that ignores its
    arguments and returns a fixed ``(conversation, engine_prompts)`` pair, so
    tests exercise only the generation / response-building path.
    """

    async def _stub_preprocess(*_args, **_kwargs):
        # Fresh lists on every call so one request cannot mutate another's data.
        conversation = [{"role": "user", "content": "Test"}]
        engine_prompts = [{"prompt_token_ids": [1, 2, 3]}]
        return conversation, engine_prompts

    serving_models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    chat = OpenAIServingChat(
        engine,
        serving_models,
        response_role="assistant",
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )
    chat._preprocess_chat = AsyncMock(side_effect=_stub_preprocess)
    return chat
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_error_non_stream():
    """Non-streaming: an output finishing with reason 'error' yields a 500 ErrorResponse."""
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()
    engine.renderer = _build_renderer(engine.model_config)

    serving_chat = _build_serving_chat(engine)

    # One already-finished output that reports an engine-side error.
    failed_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[
            CompletionOutput(
                index=0,
                text="",
                token_ids=[],
                cumulative_logprob=None,
                logprobs=None,
                finish_reason="error",
            )
        ],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def _fake_generate(*_args, **_kwargs):
        yield failed_output

    engine.generate = MagicMock(side_effect=_fake_generate)

    response = await serving_chat.create_chat_completion(
        ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "Test prompt"}],
            max_tokens=10,
            stream=False,
        )
    )

    assert isinstance(response, ErrorResponse)
    err = response.error
    assert err.type == "InternalServerError"
    assert err.message == "Internal server error"
    assert err.code == HTTPStatus.INTERNAL_SERVER_ERROR
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_chat_error_stream():
    """Streaming: a final output with finish_reason='error' emits an
    "Internal server error" chunk and the stream still ends with ``data: [DONE]``.
    """
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()
    engine.renderer = _build_renderer(engine.model_config)

    serving_chat = _build_serving_chat(engine)

    def _step_output(finish_reason, finished):
        # Every step carries the same prompt and generated token ("Hello" /
        # [100]); only the finish reason and completion flag vary.
        return RequestOutput(
            request_id="test-id",
            prompt="Test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(
                    index=0,
                    text="Hello",
                    token_ids=[100],
                    cumulative_logprob=None,
                    logprobs=None,
                    finish_reason=finish_reason,
                )
            ],
            finished=finished,
            metrics=None,
            lora_request=None,
            encoder_prompt=None,
            encoder_prompt_token_ids=None,
        )

    # First step streams a token normally; the second finishes with an error.
    first_output = _step_output(None, False)
    final_output = _step_output("error", True)

    async def _fake_generate(*_args, **_kwargs):
        yield first_output
        yield final_output

    engine.generate = MagicMock(side_effect=_fake_generate)

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=True,
    )

    response = await serving_chat.create_chat_completion(request)

    # If the request was rejected before streaming started, report the server's
    # error directly rather than failing with an opaque TypeError from `async for`.
    assert not isinstance(response, ErrorResponse), f"Expected a stream, got {response}"

    chunks = [chunk async for chunk in response]

    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "image_content",
    [
        [{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}],
        [{"image_url": {"url": "https://example.com/image.jpg"}}],
    ],
)
def test_system_message_warns_on_image(image_content):
    """An image part in a system message (with or without an explicit ``type``)
    causes the protocol module's logger to emit a text-only warning."""
    system_message = {"role": "system", "content": image_content}
    logger_target = "vllm.entrypoints.openai.chat_completion.protocol.logger"

    with patch(logger_target) as patched_logger:
        ChatCompletionRequest(model=MODEL_NAME, messages=[system_message])

    warn = patched_logger.warning_once
    warn.assert_called()
    rendered_call = str(warn.call_args)
    assert "System messages should only contain text" in rendered_call
    assert "image_url" in rendered_call
|
|
|
|
|
|
def test_system_message_accepts_text():
    """A system message whose content is a plain string is accepted as-is."""
    plain_system = {"role": "system", "content": "You are a helpful assistant."}

    # Construction must not raise.
    parsed = ChatCompletionRequest(model=MODEL_NAME, messages=[plain_system])

    assert parsed.messages[0]["role"] == "system"
|
|
|
|
|
|
def test_system_message_accepts_text_array():
    """A system message whose content is a list of text parts is accepted."""
    text_parts = [{"type": "text", "text": "You are a helpful assistant."}]
    system_message = {"role": "system", "content": text_parts}

    # Construction must not raise.
    parsed = ChatCompletionRequest(model=MODEL_NAME, messages=[system_message])

    assert parsed.messages[0]["role"] == "system"
|
|
|
|
|
|
def test_user_message_accepts_image():
    """User messages may still mix text and image parts (the system-message
    restriction does not apply to them)."""
    question = {"type": "text", "text": "What's in this image?"}
    picture = {
        "type": "image_url",
        "image_url": {"url": "https://example.com/image.jpg"},
    }
    user_message = {"role": "user", "content": [question, picture]}

    # Construction must not raise.
    parsed = ChatCompletionRequest(model=MODEL_NAME, messages=[user_message])

    assert parsed.messages[0]["role"] == "user"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "audio_content",
    [
        [
            {
                "type": "input_audio",
                "input_audio": {"data": "base64data", "format": "wav"},
            }
        ],
        [{"input_audio": {"data": "base64data", "format": "wav"}}],
    ],
)
def test_system_message_warns_on_audio(audio_content):
    """An audio part in a system message (with or without an explicit ``type``)
    causes the protocol module's logger to emit a text-only warning."""
    system_message = {"role": "system", "content": audio_content}
    logger_target = "vllm.entrypoints.openai.chat_completion.protocol.logger"

    with patch(logger_target) as patched_logger:
        ChatCompletionRequest(model=MODEL_NAME, messages=[system_message])

    warn = patched_logger.warning_once
    warn.assert_called()
    rendered_call = str(warn.call_args)
    assert "System messages should only contain text" in rendered_call
    assert "input_audio" in rendered_call
|
|
|
|
|
|
@pytest.mark.parametrize(
    "video_content",
    [
        [{"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}}],
        [{"video_url": {"url": "https://example.com/video.mp4"}}],
    ],
)
def test_system_message_warns_on_video(video_content):
    """A video part in a system message (with or without an explicit ``type``)
    causes the protocol module's logger to emit a text-only warning."""
    system_message = {"role": "system", "content": video_content}
    logger_target = "vllm.entrypoints.openai.chat_completion.protocol.logger"

    with patch(logger_target) as patched_logger:
        ChatCompletionRequest(model=MODEL_NAME, messages=[system_message])

    warn = patched_logger.warning_once
    warn.assert_called()
    rendered_call = str(warn.call_args)
    assert "System messages should only contain text" in rendered_call
    assert "video_url" in rendered_call
|