Add validation to reject non-text content in system messages (#34072)
Signed-off-by: Varun Chawla <varun_6april@hotmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -233,3 +233,140 @@ async def test_chat_error_stream():
|
||||
f"Expected error message in chunks: {chunks}"
|
||||
)
|
||||
assert chunks[-1] == "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"image_content",
|
||||
[
|
||||
[{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}],
|
||||
[{"image_url": {"url": "https://example.com/image.jpg"}}],
|
||||
],
|
||||
)
|
||||
def test_system_message_warns_on_image(image_content):
|
||||
"""Test that system messages with image content trigger a warning."""
|
||||
with patch(
|
||||
"vllm.entrypoints.openai.chat_completion.protocol.logger"
|
||||
) as mock_logger:
|
||||
ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": image_content,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
mock_logger.warning_once.assert_called()
|
||||
call_args = str(mock_logger.warning_once.call_args)
|
||||
assert "System messages should only contain text" in call_args
|
||||
assert "image_url" in call_args
|
||||
|
||||
|
||||
def test_system_message_accepts_text():
|
||||
"""Test that system messages can contain text content."""
|
||||
# Should not raise an exception
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
],
|
||||
)
|
||||
assert request.messages[0]["role"] == "system"
|
||||
|
||||
|
||||
def test_system_message_accepts_text_array():
|
||||
"""Test that system messages can contain an array with text content."""
|
||||
# Should not raise an exception
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a helpful assistant."}],
|
||||
},
|
||||
],
|
||||
)
|
||||
assert request.messages[0]["role"] == "system"
|
||||
|
||||
|
||||
def test_user_message_accepts_image():
|
||||
"""Test that user messages can still contain image content."""
|
||||
# Should not raise an exception
|
||||
request = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "https://example.com/image.jpg"},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
assert request.messages[0]["role"] == "user"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"audio_content",
|
||||
[
|
||||
[
|
||||
{
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": "base64data", "format": "wav"},
|
||||
}
|
||||
],
|
||||
[{"input_audio": {"data": "base64data", "format": "wav"}}],
|
||||
],
|
||||
)
|
||||
def test_system_message_warns_on_audio(audio_content):
|
||||
"""Test that system messages with audio content trigger a warning."""
|
||||
with patch(
|
||||
"vllm.entrypoints.openai.chat_completion.protocol.logger"
|
||||
) as mock_logger:
|
||||
ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": audio_content,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
mock_logger.warning_once.assert_called()
|
||||
call_args = str(mock_logger.warning_once.call_args)
|
||||
assert "System messages should only contain text" in call_args
|
||||
assert "input_audio" in call_args
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"video_content",
|
||||
[
|
||||
[{"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}}],
|
||||
[{"video_url": {"url": "https://example.com/video.mp4"}}],
|
||||
],
|
||||
)
|
||||
def test_system_message_warns_on_video(video_content):
|
||||
"""Test that system messages with video content trigger a warning."""
|
||||
with patch(
|
||||
"vllm.entrypoints.openai.chat_completion.protocol.logger"
|
||||
) as mock_logger:
|
||||
ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": video_content,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
mock_logger.warning_once.assert_called()
|
||||
call_args = str(mock_logger.warning_once.call_args)
|
||||
assert "System messages should only contain text" in call_args
|
||||
assert "video_url" in call_args
|
||||
|
||||
@@ -674,3 +674,52 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
"Parameter 'cache_salt' must be a non-empty string if provided."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_system_message_content_type(cls, data):
|
||||
"""Warn if system messages contain non-text content.
|
||||
|
||||
According to OpenAI API spec, system messages can only be of type
|
||||
'text'. We log a warning instead of rejecting to avoid breaking
|
||||
users who intentionally send multimodal system messages.
|
||||
See: https://platform.openai.com/docs/api-reference/chat/create#chat_create-messages-system_message
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
messages = data.get("messages", [])
|
||||
for msg in messages:
|
||||
# Check if this is a system message
|
||||
if isinstance(msg, dict) and msg.get("role") == "system":
|
||||
content = msg.get("content")
|
||||
|
||||
# If content is a list (multimodal format)
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part_type = part.get("type")
|
||||
# Infer type when 'type' field is not explicit
|
||||
if part_type is None:
|
||||
if "image_url" in part or "image_pil" in part:
|
||||
part_type = "image_url"
|
||||
elif "image_embeds" in part:
|
||||
part_type = "image_embeds"
|
||||
elif "audio_url" in part:
|
||||
part_type = "audio_url"
|
||||
elif "input_audio" in part:
|
||||
part_type = "input_audio"
|
||||
elif "audio_embeds" in part:
|
||||
part_type = "audio_embeds"
|
||||
elif "video_url" in part:
|
||||
part_type = "video_url"
|
||||
|
||||
# Warn about non-text content in system messages
|
||||
if part_type and part_type != "text":
|
||||
logger.warning_once(
|
||||
"System messages should only contain text "
|
||||
"content according to the OpenAI API spec. "
|
||||
"Found content type: '%s'.",
|
||||
part_type,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user