[Refactor] Relocate chat completion and anthropic tests (#36919)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
398	tests/entrypoints/openai/chat_completion/test_chat_error.py	Normal file
@@ -0,0 +1,398 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass, field
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import GenerationError
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM

MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
    BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]


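# Minimal stand-ins for vLLM's config objects so the serving layer can be
# exercised without loading a real model or engine.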
@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    task = "generate"
    runner_type = "generate"
    model = MODEL_NAME
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    hf_text_config = MockHFConfig()
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False
    is_encoder_decoder: bool = False
    is_multimodal_model: bool = False

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


@dataclass
class MockParallelConfig:
    _api_process_rank: int = 0


@dataclass
class MockVllmConfig:
    model_config: MockModelConfig
    parallel_config: MockParallelConfig


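# Construct an HfRenderer for the mock model config, wiring the tokenizer
# args resolved from the config into the renderer.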
def _build_renderer(model_config: MockModelConfig):
    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

    return HfRenderer.from_config(
        MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )


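# Assemble an OpenAIServingChat on top of the (mocked) engine, with chat
# preprocessing stubbed out to return a fixed conversation and token-id
# prompt, so requests don't depend on a chat template.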
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    serving_render = OpenAIServingRender(
        model_config=engine.model_config,
        renderer=engine.renderer,
        io_processor=engine.io_processor,
        model_registry=models.registry,
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )
    serving_chat = OpenAIServingChat(
        engine,
        models,
        response_role="assistant",
        openai_serving_render=serving_render,
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )

    async def _fake_preprocess_chat(*args, **kwargs):
        # return conversation, engine_prompts
        return (
            [{"role": "user", "content": "Test"}],
            [{"prompt_token_ids": [1, 2, 3]}],
        )

    serving_chat.openai_serving_render._preprocess_chat = AsyncMock(
        side_effect=_fake_preprocess_chat
    )
    return serving_chat


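# finish_reason="error" handling: the engine reports a generation failure and
# the serving layer must not return it as a normal completion.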
@pytest.mark.asyncio
async def test_chat_error_non_stream():
    """A finish_reason of 'error' should surface as a 500 InternalServerError
    (non-streaming): the serving layer raises GenerationError."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = _build_renderer(mock_engine.model_config)

    serving_chat = _build_serving_chat(mock_engine)

    completion_output = CompletionOutput(
        index=0,
        text="",
        token_ids=[],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=False,
    )

    with pytest.raises(GenerationError):
        await serving_chat.create_chat_completion(request)


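# Streaming variant: the error arrives mid-stream, after a normal chunk has
# already been sent, so it must be reported in-band before [DONE].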
@pytest.mark.asyncio
async def test_chat_error_stream():
    """A finish_reason of 'error' should surface as a 500 InternalServerError
    (streaming): an error payload is emitted in the stream before [DONE]."""
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = _build_renderer(mock_engine.model_config)

    serving_chat = _build_serving_chat(mock_engine)

    completion_output_1 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason=None,
    )

    request_output_1 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_1],
        finished=False,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    completion_output_2 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output_2 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_2],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output_1
        yield request_output_2

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=True,
    )

    response = await serving_chat.create_chat_completion(request)

    chunks = []
    async for chunk in response:
        chunks.append(chunk)

    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"


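# System messages are text-only: multimodal parts (with or without an
# explicit "type" key) should log a warning rather than fail the request.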
@pytest.mark.parametrize(
    "image_content",
    [
        [{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}],
        [{"image_url": {"url": "https://example.com/image.jpg"}}],
    ],
)
def test_system_message_warns_on_image(image_content):
    """Test that system messages with image content trigger a warning."""
    with patch(
        "vllm.entrypoints.openai.chat_completion.protocol.logger"
    ) as mock_logger:
        ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "system",
                    "content": image_content,
                }
            ],
        )

    mock_logger.warning_once.assert_called()
    call_args = str(mock_logger.warning_once.call_args)
    assert "System messages should only contain text" in call_args
    assert "image_url" in call_args


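# Plain-text system content (string or text-part array) is accepted silently.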
def test_system_message_accepts_text():
    """Test that system messages can contain text content."""
    # Should not raise an exception
    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
        ],
    )
    assert request.messages[0]["role"] == "system"


def test_system_message_accepts_text_array():
    """Test that system messages can contain an array with text content."""
    # Should not raise an exception
    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}],
            },
        ],
    )
    assert request.messages[0]["role"] == "system"


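# The warning is scoped to system messages; user messages may still carry
# image parts.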
def test_user_message_accepts_image():
    """Test that user messages can still contain image content."""
    # Should not raise an exception
    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/image.jpg"},
                    },
                ],
            },
        ],
    )
    assert request.messages[0]["role"] == "user"


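# The same text-only warning applies to audio and video parts in system
# messages.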
@pytest.mark.parametrize(
    "audio_content",
    [
        [
            {
                "type": "input_audio",
                "input_audio": {"data": "base64data", "format": "wav"},
            }
        ],
        [{"input_audio": {"data": "base64data", "format": "wav"}}],
    ],
)
def test_system_message_warns_on_audio(audio_content):
    """Test that system messages with audio content trigger a warning."""
    with patch(
        "vllm.entrypoints.openai.chat_completion.protocol.logger"
    ) as mock_logger:
        ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "system",
                    "content": audio_content,
                }
            ],
        )

    mock_logger.warning_once.assert_called()
    call_args = str(mock_logger.warning_once.call_args)
    assert "System messages should only contain text" in call_args
    assert "input_audio" in call_args


@pytest.mark.parametrize(
    "video_content",
    [
        [{"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}}],
        [{"video_url": {"url": "https://example.com/video.mp4"}}],
    ],
)
def test_system_message_warns_on_video(video_content):
    """Test that system messages with video content trigger a warning."""
    with patch(
        "vllm.entrypoints.openai.chat_completion.protocol.logger"
    ) as mock_logger:
        ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "system",
                    "content": video_content,
                }
            ],
        )

    mock_logger.warning_once.assert_called()
    call_args = str(mock_logger.warning_once.call_args)
    assert "System messages should only contain text" in call_args
    assert "video_url" in call_args


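# response_format validation: a missing json_schema payload should be caught
# at request-parsing time.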
def test_json_schema_response_format_missing_schema():
    """When response_format type is 'json_schema' but the json_schema field
    is not provided, request construction should raise a validation error
    so the API returns 400 instead of 500."""
    with pytest.raises(Exception, match="json_schema.*must be provided"):
        ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "hello"}],
            response_format={"type": "json_schema"},
        )