[Refactor] Relocate chat completion and anthropic tests (#36919)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
1022
tests/entrypoints/openai/chat_completion/test_chat.py
Normal file
1022
tests/entrypoints/openai/chat_completion/test_chat.py
Normal file
File diff suppressed because it is too large
Load Diff
131
tests/entrypoints/openai/chat_completion/test_chat_echo.py
Normal file
131
tests/entrypoints/openai/chat_completion/test_chat_echo.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
|
||||
|
||||
|
||||
def get_vocab_size(model_name):
    """Return the vocabulary size reported by vLLM's ModelConfig for *model_name*."""
    model_config = ModelConfig(model=model_name, seed=0, dtype="float16")
    return model_config.get_vocab_size()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Start one vLLM OpenAI-compatible server shared by every test in this module."""
    launch_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--enforce-eager",
        "--max-model-len",
        "4080",
        # very large cap so prompt_logprobs / top_logprobs of -1 can be exercised
        "--max-logprobs",  # test prompt_logprobs equal to -1
        "151936",
    ]
    with RemoteOpenAIServer(MODEL_NAME, launch_args) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(server):
    """Yield an AsyncOpenAI client bound to the module-scoped server."""
    async with server.get_async_client() as async_client:
        yield async_client
|
||||
|
||||
|
||||
class TestCase(NamedTuple):
    """Parameters for one echo scenario: target model and whether to echo."""

    model_name: str
    echo: bool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(model_name=MODEL_NAME, echo=True),
        TestCase(model_name=MODEL_NAME, echo=False),
    ],
)
async def test_chat_session_with_echo_and_continue_final_message(
    client: openai.AsyncOpenAI, test_case: TestCase
):
    """continue_final_message must (not) repeat the partial assistant turn per `echo`."""
    saying: str = "Here is a common saying about apple. An apple a day, keeps"
    # Ask the server to continue the unfinished assistant message; `echo`
    # controls whether that prefix is reproduced in the returned content.
    chat_completion = await client.chat.completions.create(
        model=test_case.model_name,
        messages=[
            {"role": "user", "content": "tell me a common saying"},
            {"role": "assistant", "content": saying},
        ],
        extra_body={
            "echo": test_case.echo,
            "continue_final_message": True,
            "add_generation_prompt": False,
        },
    )

    assert chat_completion.id is not None
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "stop"

    message = choice.message
    if test_case.echo:
        assert message.content is not None and saying in message.content
    else:
        assert message.content is not None and saying not in message.content
    assert message.role == "assistant"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_logprobs(client: openai.AsyncOpenAI):
    """prompt_logprobs=-1 must return a non-empty logprob list for the prompt."""
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Beijing is the capital of which country?"},
    ]

    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=conversation,
        extra_body={"prompt_logprobs": -1},
    )

    assert completion.prompt_logprobs is not None
    assert len(completion.prompt_logprobs) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_top_logprobs(client: openai.AsyncOpenAI):
    """top_logprobs=-1 must return one logprob entry per vocabulary token."""
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Beijing is the capital of which country?"},
    ]

    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=conversation,
        max_tokens=1,
        extra_body={
            "top_logprobs": -1,
            "logprobs": "true",
        },
    )
    logprobs = completion.choices[0].logprobs
    assert logprobs is not None
    assert logprobs.content is not None
    assert len(logprobs.content) > 0
    # -1 means "all of them": the first generated token should carry exactly
    # vocab-size candidate logprobs.
    assert len(logprobs.content[0].top_logprobs) == get_vocab_size(MODEL_NAME)
|
||||
398
tests/entrypoints/openai/chat_completion/test_chat_error.py
Normal file
398
tests/entrypoints/openai/chat_completion/test_chat_error.py
Normal file
@@ -0,0 +1,398 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.engine.protocol import GenerationError
|
||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
MODEL_NAME_SHORT = "gpt2"
|
||||
BASE_MODEL_PATHS = [
|
||||
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
|
||||
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class MockHFConfig:
    """Minimal stand-in for a HuggingFace model config object."""

    model_type: str = "any"
|
||||
|
||||
|
||||
@dataclass
class MockModelConfig:
    """Lightweight stand-in for vLLM's ModelConfig used by the serving tests.

    NOTE(review): only the *annotated* names below are dataclass fields; the
    un-annotated ones (``task``, ``model``, …) are plain class attributes
    shared by all instances — presumably intentional for test fixtures, but
    worth confirming before reuse.
    """

    task = "generate"
    runner_type = "generate"
    model = MODEL_NAME
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    hf_text_config = MockHFConfig()
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    # Per-modality IO kwargs; default_factory avoids a shared mutable default.
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False
    is_encoder_decoder: bool = False
    is_multimodal_model: bool = False

    def get_diff_sampling_param(self):
        """Return the sampling-param overrides, or an empty dict when unset."""
        return self.diff_sampling_param or {}
|
||||
|
||||
|
||||
@dataclass
class MockParallelConfig:
    """Stand-in for vLLM's ParallelConfig; only the API process rank is consumed."""

    _api_process_rank: int = 0
|
||||
|
||||
|
||||
@dataclass
class MockVllmConfig:
    """Bundles the two mock configs that the renderer factory expects."""

    # The mock model configuration (see MockModelConfig above in this file).
    model_config: MockModelConfig
    # The mock parallel configuration carrying the API process rank.
    parallel_config: MockParallelConfig
|
||||
|
||||
|
||||
def _build_renderer(model_config: MockModelConfig):
    """Construct an HfRenderer for *model_config* via the tokenizer registry."""
    _, tokenizer_name, _, tokenizer_kwargs = tokenizer_args_from_config(model_config)

    vllm_config = MockVllmConfig(model_config, parallel_config=MockParallelConfig())
    return HfRenderer.from_config(
        vllm_config,
        tokenizer_kwargs={**tokenizer_kwargs, "tokenizer_name": tokenizer_name},
    )
|
||||
|
||||
|
||||
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
    """Wire an OpenAIServingChat around *engine* with chat preprocessing mocked out.

    The returned handler uses the real OpenAIServingModels / OpenAIServingRender
    plumbing, but ``_preprocess_chat`` is replaced by an AsyncMock so the tests
    never touch tokenization or chat templates.
    """
    models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    serving_render = OpenAIServingRender(
        model_config=engine.model_config,
        renderer=engine.renderer,
        io_processor=engine.io_processor,
        model_registry=models.registry,
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )
    serving_chat = OpenAIServingChat(
        engine,
        models,
        response_role="assistant",
        openai_serving_render=serving_render,
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )

    async def _fake_preprocess_chat(*args, **kwargs):
        # return conversation, engine_prompts
        return (
            [{"role": "user", "content": "Test"}],
            [{"prompt_token_ids": [1, 2, 3]}],
        )

    # Bypass real preprocessing so the tests exercise only the response path.
    serving_chat.openai_serving_render._preprocess_chat = AsyncMock(
        side_effect=_fake_preprocess_chat
    )
    return serving_chat
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_error_non_stream():
    """test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
    # Mock engine whose generate() yields a single finished output that
    # carries finish_reason="error".
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = _build_renderer(mock_engine.model_config)

    serving_chat = _build_serving_chat(mock_engine)

    # An empty completion whose only notable property is the error finish reason.
    completion_output = CompletionOutput(
        index=0,
        text="",
        token_ids=[],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=False,
    )

    # Non-streaming path: the error finish reason must surface as a
    # GenerationError (mapped to HTTP 500 by the API layer).
    with pytest.raises(GenerationError):
        await serving_chat.create_chat_completion(request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_error_stream():
    """test finish_reason='error' returns 500 InternalServerError (streaming)"""
    # Mock engine that streams one normal chunk followed by one error chunk.
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()
    mock_engine.renderer = _build_renderer(mock_engine.model_config)

    serving_chat = _build_serving_chat(mock_engine)

    # First chunk: ordinary in-progress output (finish_reason=None).
    completion_output_1 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason=None,
    )

    request_output_1 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_1],
        finished=False,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    # Second chunk: finished with finish_reason="error".
    completion_output_2 = CompletionOutput(
        index=0,
        text="Hello",
        token_ids=[100],
        cumulative_logprob=None,
        logprobs=None,
        finish_reason="error",
    )

    request_output_2 = RequestOutput(
        request_id="test-id",
        prompt="Test prompt",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion_output_2],
        finished=True,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
    )

    async def mock_generate(*args, **kwargs):
        yield request_output_1
        yield request_output_2

    mock_engine.generate = MagicMock(side_effect=mock_generate)

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Test prompt"}],
        max_tokens=10,
        stream=True,
    )

    # Streaming path returns an async generator of SSE strings.
    response = await serving_chat.create_chat_completion(request)

    chunks = []
    async for chunk in response:
        chunks.append(chunk)

    # The error must be reported in-band and the stream still terminated
    # with the SSE [DONE] sentinel.
    assert len(chunks) >= 2
    assert any("Internal server error" in chunk for chunk in chunks), (
        f"Expected error message in chunks: {chunks}"
    )
    assert chunks[-1] == "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "image_content",
    [
        [{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}],
        [{"image_url": {"url": "https://example.com/image.jpg"}}],
    ],
)
def test_system_message_warns_on_image(image_content):
    """Test that system messages with image content trigger a warning."""
    with patch(
        "vllm.entrypoints.openai.chat_completion.protocol.logger"
    ) as mock_logger:
        ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "system", "content": image_content}],
        )

    mock_logger.warning_once.assert_called()
    warning_text = str(mock_logger.warning_once.call_args)
    assert "System messages should only contain text" in warning_text
    assert "image_url" in warning_text
|
||||
|
||||
|
||||
def test_system_message_accepts_text():
    """Test that system messages can contain text content."""
    # Construction must succeed without raising a validation error.
    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "system", "content": "You are a helpful assistant."}],
    )
    assert request.messages[0]["role"] == "system"
|
||||
|
||||
|
||||
def test_system_message_accepts_text_array():
    """Test that system messages can contain an array with text content."""
    # Construction must succeed without raising a validation error.
    content_parts = [{"type": "text", "text": "You are a helpful assistant."}]
    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "system", "content": content_parts}],
    )
    assert request.messages[0]["role"] == "system"
|
||||
|
||||
|
||||
def test_user_message_accepts_image():
    """Test that user messages can still contain image content."""
    # Multi-part user content (text + image) must validate without raising.
    user_content = [
        {"type": "text", "text": "What's in this image?"},
        {
            "type": "image_url",
            "image_url": {"url": "https://example.com/image.jpg"},
        },
    ]
    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": user_content}],
    )
    assert request.messages[0]["role"] == "user"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "audio_content",
    [
        [
            {
                "type": "input_audio",
                "input_audio": {"data": "base64data", "format": "wav"},
            }
        ],
        [{"input_audio": {"data": "base64data", "format": "wav"}}],
    ],
)
def test_system_message_warns_on_audio(audio_content):
    """Test that system messages with audio content trigger a warning."""
    with patch(
        "vllm.entrypoints.openai.chat_completion.protocol.logger"
    ) as mock_logger:
        ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "system", "content": audio_content}],
        )

    mock_logger.warning_once.assert_called()
    warning_text = str(mock_logger.warning_once.call_args)
    assert "System messages should only contain text" in warning_text
    assert "input_audio" in warning_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "video_content",
    [
        [{"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}}],
        [{"video_url": {"url": "https://example.com/video.mp4"}}],
    ],
)
def test_system_message_warns_on_video(video_content):
    """Test that system messages with video content trigger a warning."""
    with patch(
        "vllm.entrypoints.openai.chat_completion.protocol.logger"
    ) as mock_logger:
        ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "system", "content": video_content}],
        )

    mock_logger.warning_once.assert_called()
    warning_text = str(mock_logger.warning_once.call_args)
    assert "System messages should only contain text" in warning_text
    assert "video_url" in warning_text
|
||||
|
||||
|
||||
def test_json_schema_response_format_missing_schema():
    """When response_format type is 'json_schema' but the json_schema field
    is not provided, request construction should raise a validation error
    so the API returns 400 instead of 500."""
    with pytest.raises(Exception, match="json_schema.*must be provided"):
        ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "hello"}],
            response_format={"type": "json_schema"},
        )
|
||||
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
|
||||
|
||||
def get_vocab_size(model_name):
    """Look up the model's vocabulary size via vLLM's ModelConfig."""
    return ModelConfig(model=model_name, seed=0, dtype="bfloat16").get_vocab_size()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Start a small bfloat16 vLLM server shared by this module's tests."""
    launch_args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "1024",
        "--enforce-eager",
    ]
    with RemoteOpenAIServer(MODEL_NAME, launch_args) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(server):
    """Yield an AsyncOpenAI client connected to the module-scoped server."""
    async with server.get_async_client() as async_client:
        yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_logit_bias_valid(client):
    """Test that valid logit_bias values are accepted in chat completions."""
    # The highest valid token id is vocab_size - 1.
    last_token_id = get_vocab_size(MODEL_NAME) - 1

    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "Testing valid logit bias"}],
        max_tokens=5,
        logit_bias={str(last_token_id): 1.0},
    )

    assert completion.choices[0].message.content is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_logit_bias_invalid(client):
    """Test that invalid logit_bias values are rejected in chat completions."""
    vocab_size = get_vocab_size(MODEL_NAME)
    out_of_range_id = vocab_size + 1  # one past the last valid token id

    with pytest.raises(openai.BadRequestError) as excinfo:
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": "Testing invalid logit bias"}],
            max_tokens=5,
            logit_bias={str(out_of_range_id): 1.0},
        )

    error = excinfo.value
    message_text = str(error)

    # The rejection should be a 400 and mention both the offending id and
    # the vocabulary bound.
    assert error.status_code == 400
    assert str(out_of_range_id) in message_text
    assert str(vocab_size) in message_text
|
||||
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
# a reasoning and tool calling model
|
||||
MODEL_NAME = "Qwen/QwQ-32B"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Launch the reasoning model with deepseek_r1 reasoning and hermes tool parsing."""
    launch_args = [
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--reasoning-parser",
        "deepseek_r1",
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "hermes",
    ]
    with RemoteOpenAIServer(MODEL_NAME, launch_args) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(server):
    """Yield an AsyncOpenAI client for the reasoning/tool-call server."""
    async with server.get_async_client() as async_client:
        yield async_client
|
||||
|
||||
|
||||
# Single weather tool definition exercised by the reasoning + tool-call tests.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. "
                        "'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "the two-letter abbreviation for the state that "
                        "the city is in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }
]

# Conversation that should cause the model to invoke the weather tool.
MESSAGES = [
    {"role": "user", "content": "Hi! How are you doing today?"},
    {"role": "assistant", "content": "I'm doing well! How can I help you?"},
    {
        "role": "user",
        "content": "Can you tell me what the temperate will be in Dallas, "
        "in fahrenheit?",
    },
]

# Expected tool call emitted for MESSAGES above.
FUNC_NAME = "get_current_weather"
FUNC_ARGS = """{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}"""
|
||||
|
||||
|
||||
def extract_reasoning_and_calls(chunks: list):
    """Fold a list of streaming chat-completion chunks into flat results.

    Args:
        chunks: chunk objects in stream order; each must expose
            ``chunk.choices[0].delta`` with optional ``tool_calls`` /
            ``reasoning`` attributes (OpenAI streaming delta shape).

    Returns:
        A 3-tuple ``(reasoning, arguments, function_names)`` where
        ``reasoning`` is the concatenated reasoning text, and
        ``arguments[i]`` / ``function_names[i]`` collect the i-th tool call's
        accumulated argument string and function name.
    """
    reasoning = ""
    tool_call_idx = -1
    arguments = []
    function_names = []
    for chunk in chunks:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            tool_call = delta.tool_calls[0]
            # A new index marks the start of a new tool call: open fresh slots.
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                arguments.append("")
                function_names.append("")

            if tool_call.function:
                # The name usually arrives once; arguments arrive as fragments.
                if tool_call.function.name:
                    function_names[tool_call_idx] = tool_call.function.name

                if tool_call.function.arguments:
                    arguments[tool_call_idx] += tool_call.function.arguments
        else:
            # Bug fix: `hasattr` is true even when `reasoning` is None (e.g.
            # final/empty chunks), which previously raised
            # `TypeError: can only concatenate str` on `+=`.
            if hasattr(delta, "reasoning"):
                reasoning += delta.reasoning or ""
    return reasoning, arguments, function_names
|
||||
|
||||
|
||||
# test streaming
@pytest.mark.asyncio
async def test_chat_streaming_of_tool_and_reasoning(client: openai.AsyncOpenAI):
    """A streamed response must carry both reasoning deltas and the tool call."""
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )

    chunks = [chunk async for chunk in stream]

    reasoning, arguments, function_names = extract_reasoning_and_calls(chunks)
    assert len(reasoning) > 0
    assert len(function_names) > 0 and function_names[0] == FUNC_NAME
    assert len(arguments) > 0 and arguments[0] == FUNC_ARGS
|
||||
|
||||
|
||||
# test full generate
@pytest.mark.asyncio
async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI):
    """A non-streamed response must include reasoning plus the expected tool call."""
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES,
        tools=TOOLS,
        temperature=0.0,
        stream=False,
    )

    message = response.choices[0].message
    assert len(message.reasoning) > 0
    first_call = message.tool_calls[0]
    assert first_call.function.name == FUNC_NAME
    assert first_call.function.arguments == FUNC_ARGS
|
||||
@@ -0,0 +1,540 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import datetime
|
||||
import json
|
||||
|
||||
import jsonschema
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
# downloading lora to test lora requests
|
||||
from tests.utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
|
||||
|
||||
# Two weather tools (current conditions + forecast); the first also exercises
# JSON-schema $ref/$defs handling in structured tool-call outputs.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. "
                        "'Vienna'",
                        "default": "Vienna",
                    },
                    "country": {
                        "type": "string",
                        "description": "The country that the city is in, e.g. "
                        "'Austria'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                    "options": {
                        "$ref": "#/$defs/WeatherOptions",
                        "description": "Optional parameters for weather query",
                    },
                },
                "required": ["country", "unit"],
                "$defs": {
                    "WeatherOptions": {
                        "title": "WeatherOptions",
                        "type": "object",
                        "additionalProperties": False,
                        "properties": {
                            "unit": {
                                "type": "string",
                                "enum": ["celsius", "fahrenheit"],
                                "default": "celsius",
                                "description": "Temperature unit",
                                "title": "Temperature Unit",
                            },
                            "include_forecast": {
                                "type": "boolean",
                                "default": False,
                                "description": "Whether to include a 24-hour forecast",
                                "title": "Include Forecast",
                            },
                            "language": {
                                "type": "string",
                                "default": "zh-CN",
                                "description": "Language of the response",
                                "title": "Language",
                                "enum": ["zh-CN", "en-US", "ja-JP"],
                            },
                        },
                    },
                },
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_forecast",
            "description": "Get the weather forecast for a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to get the forecast for, e.g. "
                        "'Vienna'",
                        "default": "Vienna",
                    },
                    "country": {
                        "type": "string",
                        "description": "The country that the city is in, e.g. "
                        "'Austria'",
                    },
                    "days": {
                        "type": "integer",
                        "description": "Number of days to get the forecast for (1-7)",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["country", "days", "unit"],
            },
        },
    },
]

# Conversation intended to trigger both tools (weather + 5-day forecast).
messages = [
    {"role": "user", "content": "Hi! How are you doing today?"},
    {"role": "assistant", "content": "I'm doing well! How can I help you?"},
    {
        "role": "user",
        "content": "Can you tell me what the current weather is in Berlin and the "
        "forecast for the next 5 days, in fahrenheit?",
    },
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Qwen3 server with xgrammar structured outputs and hermes tool parsing."""
    launch_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "half",
        "--enable-auto-tool-choice",
        "--structured-outputs-config.backend",
        "xgrammar",
        "--tool-call-parser",
        "hermes",
        "--reasoning-parser",
        "qwen3",
        "--gpu-memory-utilization",
        "0.4",
        "--enforce-eager",
    ] + ROCM_EXTRA_ARGS

    with RemoteOpenAIServer(
        MODEL_NAME, launch_args, env_dict=ROCM_ENV_OVERRIDES
    ) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(server):
    """Yield an AsyncOpenAI client for the tool-use server."""
    async with server.get_async_client() as async_client:
        yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize(
    "tool_choice",
    [
        "auto",
        "required",
        {"type": "function", "function": {"name": "get_current_weather"}},
    ],
)
@pytest.mark.parametrize("enable_thinking", [True, False])
async def test_function_tool_use(
    client: openai.AsyncOpenAI,
    model_name: str,
    stream: bool,
    tool_choice: str | dict,
    enable_thinking: bool,
):
    """Tool calls must be produced for every tool_choice mode, with and without
    streaming, and reasoning must appear iff thinking is enabled."""
    if not stream:
        # Non-streaming test
        chat_completion = await client.chat.completions.create(
            messages=messages,
            model=model_name,
            tools=tools,
            tool_choice=tool_choice,
            extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}},
        )
        if enable_thinking:
            assert chat_completion.choices[0].message.reasoning is not None
            assert chat_completion.choices[0].message.reasoning != ""
        assert chat_completion.choices[0].message.tool_calls is not None
        assert len(chat_completion.choices[0].message.tool_calls) > 0
    else:
        # Streaming test
        output_stream = await client.chat.completions.create(
            messages=messages,
            model=model_name,
            tools=tools,
            tool_choice=tool_choice,
            stream=True,
            extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}},
        )

        # Collect tool-call deltas and (when thinking) reasoning deltas.
        output = []
        reasoning = []
        async for chunk in output_stream:
            if chunk.choices:
                if enable_thinking and getattr(
                    chunk.choices[0].delta, "reasoning", None
                ):
                    reasoning.append(chunk.choices[0].delta.reasoning)
                if chunk.choices[0].delta.tool_calls:
                    output.extend(chunk.choices[0].delta.tool_calls)

        assert len(output) > 0
        if enable_thinking:
            assert len(reasoning) > 0
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def k2_server():
    """Server variant whose HF config is overridden to look like kimi_k2."""
    launch_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "half",
        "--enable-auto-tool-choice",
        "--structured-outputs-config.backend",
        "xgrammar",
        "--tool-call-parser",
        "hermes",
        "--reasoning-parser",
        "qwen3",
        "--gpu-memory-utilization",
        "0.4",
    ] + ROCM_EXTRA_ARGS
    # hack to test kimi_k2 tool use tool_id format.
    # avoid error in is_deepseek_mla check by setting kv_lora_rank=null
    with RemoteOpenAIServer(
        MODEL_NAME,
        launch_args,
        env_dict=ROCM_ENV_OVERRIDES,
        override_hf_configs={"model_type": "kimi_k2", "kv_lora_rank": None},
    ) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def k2_client(k2_server):
    """Yield an AsyncOpenAI client for the kimi_k2-flavored server."""
    async with k2_server.get_async_client() as async_client:
        yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("tool_choice", ["required"])
async def test_tool_id_kimi_k2(
    k2_client: openai.AsyncOpenAI, model_name: str, stream: bool, tool_choice: str
):
    """Tool-call ids must follow kimi_k2's `functions.<name>:<index>` format."""
    if not stream:
        # Non-streaming test
        chat_completion = await k2_client.chat.completions.create(
            messages=messages, model=model_name, tools=tools, tool_choice=tool_choice
        )
        assert chat_completion.choices[0].message.tool_calls is not None
        assert len(chat_completion.choices[0].message.tool_calls) > 0
        assert chat_completion.choices[0].message.tool_calls[0].id in [
            "functions.get_current_weather:0",
            "functions.get_forecast:1",
        ]
    else:
        # Streaming test
        output_stream = await k2_client.chat.completions.create(
            messages=messages,
            model=model_name,
            tools=tools,
            tool_choice=tool_choice,
            stream=True,
        )

        output = []
        async for chunk in output_stream:
            if chunk.choices and chunk.choices[0].delta.tool_calls:
                output.extend(chunk.choices[0].delta.tool_calls)
        # In streaming mode the id only appears on the chunk that opens a
        # tool call; later argument fragments carry id=None.
        for o in output:
            assert o.id is None or o.id in [
                "functions.get_current_weather:0",
                "functions.get_forecast:1",
            ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("arguments", ["{}", ""])
async def test_no_args_tool_call(
    client: openai.AsyncOpenAI, model_name: str, arguments: str
):
    """Round-trip a parameterless tool call, covering both '' and '{}' argument forms."""
    # Step 1: Define a tool that requires no parameters
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_time",
                "description": (
                    "Get the current date and time. Call this when the user "
                    "asks what time or date it is. No parameters needed."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {},  # No parameters
                    "required": [],  # No required fields
                },
            },
        }
    ]
    messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful assistant. Always use the available tools "
                "when relevant, and reply with a short sentence after "
                "receiving a tool result."
            ),
        },
        {"role": "user", "content": "What time is it now?"},
    ]

    # Sampling settings shared by both requests for determinism.
    shared_kwargs = dict(
        model=model_name,
        temperature=0.0,
        seed=42,
        extra_body={"chat_template_kwargs": {"enable_thinking": False}},
    )

    # Step 2: Send user message and let model decide whether to call the tool
    response = await client.chat.completions.create(
        **shared_kwargs,
        messages=messages,
        tools=tools,
        tool_choice="auto",  # Let model choose automatically
    )

    # Step 3: Check if model wants to call a tool
    assistant_message = response.choices[0].message
    if not assistant_message.tool_calls:
        # No tool called — the model must have answered directly.
        assert assistant_message.content is not None
        return

    first_call = assistant_message.tool_calls[0]
    # Step 4: Execute the tool locally (no parameters)
    if first_call.function.name == "get_current_time":
        # Test both empty string and "{}" for no-arg tool calls
        first_call.function.arguments = arguments
        messages.append(assistant_message)
        messages.append(
            {
                "role": "tool",
                "tool_call_id": first_call.id,
                "content": datetime.datetime.now().isoformat(),
            }
        )
        # Step 5: Send tool result back to model to continue conversation
        final_response = await client.chat.completions.create(
            **shared_kwargs,
            messages=messages,
            max_completion_tokens=128,
        )
        # The follow-up reply must be non-empty natural language.
        final_text = final_response.choices[0].message.content
        assert final_text is not None and final_text.strip() != ""
|
||||
@pytest.mark.asyncio
async def test_named_tool_use(
    client: openai.AsyncOpenAI,
    sample_json_schema,
):
    """Force a named tool via tool_choice and validate the JSON arguments
    against the schema, in both non-streaming and streaming modes."""
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {
            "role": "user",
            "content": (
                "Give an example JSON for an employee profile using the specified tool."
            ),
        },
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "dummy_function_name",
                "description": "This is a dummy function",
                "parameters": sample_json_schema,
            },
        }
    ]
    tool_choice = {"type": "function", "function": {"name": "dummy_function_name"}}

    # non-streaming

    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_completion_tokens=1000,
        tools=tools,
        temperature=0.0,
        tool_choice=tool_choice,
    )
    first_message = chat_completion.choices[0].message
    # A forced tool call produces no regular content.
    assert len(first_message.content) == 0
    first_args = first_message.tool_calls[0].function.arguments
    json1 = json.loads(first_args)
    jsonschema.validate(instance=json1, schema=sample_json_schema)

    messages.append({"role": "assistant", "content": first_args})
    messages.append(
        {"role": "user", "content": "Give me another one with a different name and age"}
    )

    # streaming

    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_completion_tokens=1000,
        tools=tools,
        tool_choice=tool_choice,
        temperature=0.0,
        stream=True,
    )

    argument_chunks = []
    finish_reason_count = 0
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.role:
            assert delta.role == "assistant"
        assert delta.content is None or len(delta.content) == 0
        if delta.tool_calls:
            argument_chunks.append(delta.tool_calls[0].function.arguments)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # finish reason should only return in last block
    assert finish_reason_count == 1
    json2 = json.loads("".join(argument_chunks))
    jsonschema.validate(instance=json2, schema=sample_json_schema)
    assert json1["name"] != json2["name"]
    assert json1["age"] != json2["age"]
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_inconsistent_tool_choice_and_tools(
    client: openai.AsyncOpenAI, sample_json_schema
):
    """Requests whose tool_choice does not match the supplied tools must be
    rejected with a 400 BadRequestError."""
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {
            "role": "user",
            "content": f"Give an example JSON for an employee profile that "
            f"fits this schema: {sample_json_schema}",
        },
    ]
    # Single tool definition reused by the cases below.
    defined_tools = [
        {
            "type": "function",
            "function": {
                "name": "dummy_function_name",
                "description": "This is a dummy function",
                "parameters": sample_json_schema,
            },
        }
    ]

    # Case 1: named tool_choice but no tools supplied at all.
    with pytest.raises(openai.BadRequestError):
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            max_completion_tokens=1000,
            tool_choice={
                "type": "function",
                "function": {"name": "dummy_function_name"},
            },
        )

    # Case 2: tool_choice names a function absent from the tools list.
    with pytest.raises(openai.BadRequestError):
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            max_completion_tokens=1000,
            tools=defined_tools,
            tool_choice={
                "type": "function",
                "function": {"name": "nondefined_function_name"},
            },
        )

    # Case 3: malformed (empty) tool_choice object.
    with pytest.raises(openai.BadRequestError):
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            max_completion_tokens=1000,
            tools=defined_tools,
            tool_choice={},
        )
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_max_tokens_with_tool_choice_required(client: openai.AsyncOpenAI):
    """Regression test: ``tool_choice="required"`` combined with a tiny
    ``max_completion_tokens`` budget must finish with
    ``finish_reason == "length"`` and empty ``tool_calls``/``content``
    rather than crashing the engine."""
    models = await client.models.list()
    model_name: str = models.data[0].id

    # This combination previously crashed the engine
    chat_completion = await client.chat.completions.create(
        messages=messages,
        temperature=0,
        max_completion_tokens=1,
        model=model_name,
        tools=tools,
        tool_choice="required",
    )
    # When `tool_choice="required"` and the tokens of `tools` exceed `max_tokens`,
    # both `tool_calls` and `content` should be empty.
    # This behavior should be consistent with OpenAI.
    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert len(choice.message.tool_calls) == 0
    assert choice.message.content == ""
||||
@@ -0,0 +1,126 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def chat_server_with_force_include_usage(request):
    """Chat server started with ``--enable-force-include-usage``."""
    cli_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype", "bfloat16",
        "--max-model-len", "128",
        "--enforce-eager",
        "--max-num-seqs", "4",
        "--enable-force-include-usage",
        # NOTE(review): fixed port with auto_port=False below — may collide
        # if 55857 is already taken; confirm this is intentional.
        "--port", "55857",
        "--gpu-memory-utilization", "0.2",
    ]

    with RemoteOpenAIServer("Qwen/Qwen3-0.6B", cli_args, auto_port=False) as remote_server:
        yield remote_server
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def chat_client_with_force_include_usage(chat_server_with_force_include_usage):
    """Async client for the force-include-usage chat server."""
    server = chat_server_with_force_include_usage
    async with server.get_async_client() as client:
        yield client
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_enable_force_include_usage(
    chat_client_with_force_include_usage: openai.AsyncOpenAI,
):
    """With --enable-force-include-usage, a streamed chat completion must
    report usage on chunks that carry no choices and omit it on content
    chunks, with consistent token accounting."""
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ]

    stream = await chat_client_with_force_include_usage.chat.completions.create(
        model="Qwen/Qwen3-0.6B",
        messages=messages,
        max_completion_tokens=10,
        extra_body=dict(min_tokens=10),
        temperature=0.0,
        stream=True,
    )
    last_completion_tokens = 0
    async for chunk in stream:
        if not len(chunk.choices):
            # Usage chunk: token counts must be present and consistent.
            assert chunk.usage.prompt_tokens >= 0
            assert (
                last_completion_tokens == 0
                or chunk.usage.completion_tokens > last_completion_tokens
                or (
                    not chunk.choices
                    and chunk.usage.completion_tokens == last_completion_tokens
                )
            )
            assert chunk.usage.total_tokens == (
                chunk.usage.prompt_tokens + chunk.usage.completion_tokens
            )
            # BUG FIX: track the running count so the monotonicity assertion
            # above is actually exercised — previously last_completion_tokens
            # was never updated, leaving the check vacuously true.
            last_completion_tokens = chunk.usage.completion_tokens
        else:
            # Content chunks must not carry usage.
            assert chunk.usage is None
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def transcription_server_with_force_include_usage():
    """Whisper transcription server started with ``--enable-force-include-usage``."""
    cli_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype", "bfloat16",
        "--max-num-seqs", "4",
        "--enforce-eager",
        "--enable-force-include-usage",
        "--gpu-memory-utilization", "0.2",
    ]

    with RemoteOpenAIServer("openai/whisper-large-v3-turbo", cli_args) as remote_server:
        yield remote_server
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def transcription_client_with_force_include_usage(
    transcription_server_with_force_include_usage,
):
    """Async client for the force-include-usage transcription server."""
    server = transcription_server_with_force_include_usage
    async with server.get_async_client() as client:
        yield client
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_transcription_with_enable_force_include_usage(
    transcription_client_with_force_include_usage, winning_call
):
    """Streaming transcription must attach usage only to the final,
    choice-less chunk."""
    transcription_client = transcription_client_with_force_include_usage
    stream = await transcription_client.audio.transcriptions.create(
        model="openai/whisper-large-v3-turbo",
        file=winning_call,
        language="en",
        temperature=0.0,
        stream=True,
        timeout=30,
    )

    async for chunk in stream:
        if len(chunk.choices):
            # Intermediate chunks carry no usage attribute at all.
            assert not hasattr(chunk, "usage")
        else:
            # final usage sent
            usage = chunk.usage
            assert isinstance(usage, dict)
            assert usage["prompt_tokens"] > 0
            assert usage["completion_tokens"] > 0
            assert usage["total_tokens"] > 0
||||
1932
tests/entrypoints/openai/chat_completion/test_serving_chat.py
Normal file
1932
tests/entrypoints/openai/chat_completion/test_serving_chat.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,350 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for harmony streaming delta extraction.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.stream_harmony import (
|
||||
TokenState,
|
||||
extract_harmony_streaming_delta,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockMessage:
|
||||
"""Mock message object for testing."""
|
||||
|
||||
channel: str | None = None
|
||||
recipient: str | None = None
|
||||
|
||||
|
||||
@dataclass
class MockStreamableParser:
    """Minimal StreamableParser stand-in (avoids the openai_harmony dependency);
    only exposes the parsed message history."""

    # Messages emitted so far; each instance gets its own empty list.
    messages: list[MockMessage] = field(default_factory=list)
||||
|
||||
|
||||
class TestExtractHarmonyStreamingDelta:
    """Tests for extract_harmony_streaming_delta function.

    Each test builds a list of TokenState(channel, recipient, text) entries and
    checks the (DeltaMessage, tools_streamed) pair returned by the extractor.
    """

    @pytest.mark.parametrize(
        "delta_text,expected_content",
        [
            ("Hello, world!", "Hello, world!"),
            ("", ""),
        ],
    )
    def test_final_channel_returns_content_delta(self, delta_text, expected_content):
        """Test that final channel returns a DeltaMessage with content."""
        parser = MockStreamableParser()

        # Updated to use TokenState list
        token_states = [TokenState(channel="final", recipient=None, text=delta_text)]

        delta_message, tools_streamed = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient=None,
            include_reasoning=False,
        )

        assert delta_message is not None
        assert delta_message.content == expected_content
        assert tools_streamed is False

    @pytest.mark.parametrize(
        "include_reasoning,expected_has_message",
        [
            (True, True),
            (False, False),
        ],
    )
    def test_analysis_channel_reasoning(self, include_reasoning, expected_has_message):
        """Test analysis channel respects include_reasoning flag."""
        parser = MockStreamableParser()
        text = "Let me think..."
        token_states = [TokenState(channel="analysis", recipient=None, text=text)]

        delta_message, tools_streamed = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient=None,
            include_reasoning=include_reasoning,
        )

        # With reasoning enabled the text surfaces as `reasoning`;
        # disabled, the analysis tokens are dropped entirely.
        if expected_has_message:
            assert delta_message is not None
            assert delta_message.reasoning == text
        else:
            assert delta_message is None
        assert tools_streamed is False

    @pytest.mark.parametrize("channel", ["commentary", "analysis"])
    @patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
    def test_new_tool_call(self, mock_make_tool_call_id, channel):
        """Test new tool call creation when recipient changes."""
        mock_make_tool_call_id.return_value = "call_test123"
        parser = MockStreamableParser()

        # prev_recipient=None + a "functions.*" recipient marks a new call.
        token_states = [
            TokenState(channel=channel, recipient="functions.get_weather", text="")
        ]

        delta_message, tools_streamed = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient=None,
            include_reasoning=False,
        )

        assert delta_message is not None
        assert len(delta_message.tool_calls) == 1
        tool_call = delta_message.tool_calls[0]
        # The header delta carries id/type/name; the "functions." prefix is stripped.
        assert tool_call.id == "call_test123"
        assert tool_call.type == "function"
        assert tool_call.function.name == "get_weather"
        assert tool_call.function.arguments == ""
        assert tool_call.index == 0
        assert tools_streamed is True

    @pytest.mark.parametrize("channel", ["commentary", "analysis"])
    def test_tool_call_argument_streaming(self, channel):
        """Test streaming tool call arguments (same recipient)."""
        parser = MockStreamableParser()
        args_text = '{"location": "Paris"}'

        token_states = [
            TokenState(
                channel=channel, recipient="functions.get_weather", text=args_text
            )
        ]

        delta_message, tools_streamed = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient="functions.get_weather",
            include_reasoning=False,
        )

        assert delta_message is not None
        tool_call = delta_message.tool_calls[0]
        # Continuation deltas have no id — only the argument payload.
        assert tool_call.id is None
        assert tool_call.function.arguments == args_text
        assert tool_call.index == 0
        assert tools_streamed is True

    @pytest.mark.parametrize("channel", ["commentary", "analysis"])
    def test_tool_call_empty_arguments_returns_none(self, channel):
        """Test empty delta_text with same recipient returns None."""
        parser = MockStreamableParser()

        token_states = [
            TokenState(channel=channel, recipient="functions.get_weather", text="")
        ]

        delta_message, tools_streamed = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient="functions.get_weather",
            include_reasoning=False,
        )

        # Nothing new to stream: no header (recipient unchanged) and no args.
        assert delta_message is None
        assert tools_streamed is False

    def test_tool_call_index_from_previous_messages(self):
        """Test tool call index accounts for previous function messages."""
        messages = [
            MockMessage(channel="analysis", recipient=None),  # Not counted
            MockMessage(channel="commentary", recipient="functions.tool1"),  # Counted
            MockMessage(channel="final", recipient=None),  # Not counted
        ]
        parser = MockStreamableParser(messages=messages)

        token_states = [
            TokenState(channel="commentary", recipient="functions.tool2", text="args")
        ]

        delta_message, _ = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient="functions.tool2",
            include_reasoning=False,
        )

        # One prior "functions.*" message already completed -> this call is index 1.
        assert delta_message.tool_calls[0].index == 1

    def test_returns_preambles_as_content(self):
        """Test that commentary with no recipient (preamble) is user content."""
        parser = MockStreamableParser()
        delta_text = "some text"

        token_states = [
            TokenState(channel="commentary", recipient=None, text=delta_text)
        ]

        delta_message, tools_streamed = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient=None,
            include_reasoning=True,
        )

        assert delta_message.content == delta_text
        assert tools_streamed is False

    @pytest.mark.parametrize(
        "channel,recipient",
        [
            (None, None),
            ("unknown_channel", None),
            ("commentary", "browser.search"),
        ],
    )
    def test_returns_none_for_invalid_inputs(self, channel, recipient):
        """Test that invalid channel/recipient combinations return None."""
        # "browser.search" is a non-"functions." recipient and is ignored.
        parser = MockStreamableParser()

        token_states = [
            TokenState(channel=channel, recipient=recipient, text="some text")
        ]

        delta_message, tools_streamed = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient=None,
            include_reasoning=True,
        )

        assert delta_message is None
        assert tools_streamed is False

    def test_consecutive_token_grouping(self):
        """
        Test that consecutive tokens with the same channel/recipient
        are merged into a single processing group.
        """
        parser = MockStreamableParser()
        token_states = [
            TokenState("final", None, "H"),
            TokenState("final", None, "el"),
            TokenState("final", None, "lo"),
            TokenState("final", None, ","),
            TokenState("final", None, " World"),
        ]

        delta_message, _ = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient=None,
            include_reasoning=False,
        )

        # Five per-token states collapse into one concatenated content delta.
        assert delta_message is not None
        assert delta_message.content == "Hello, World"

    @patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
    def test_complex_batch_permutation(self, mock_make_id):
        """
        Test a complex permutation: Reasoning -> Tool Call -> Content.
        This verifies that multiple distinct actions in one batch
        are all captured in the single DeltaMessage.
        """
        mock_make_id.return_value = "call_batch_test"
        parser = MockStreamableParser()

        token_states = [
            # 1. Reasoning
            TokenState("analysis", None, "Reasoning about query..."),
            # 2. Tool Calling
            TokenState("commentary", "functions.search", '{"query":'),
            TokenState("commentary", "functions.search", ' "vllm"}'),
            # 3. Final Content
            TokenState("final", None, "."),
        ]

        delta_message, tools_streamed = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient=None,
            include_reasoning=True,
        )

        assert delta_message is not None

        assert delta_message.reasoning == "Reasoning about query..."

        # We expect 2 objects for 1 logical tool call:
        # 1. The definition (id, name, type)
        # 2. The arguments payload
        assert len(delta_message.tool_calls) == 2

        header = delta_message.tool_calls[0]
        payload = delta_message.tool_calls[1]

        assert header.function.name == "search"
        assert header.id == "call_batch_test"
        assert header.index == 0

        # Both consecutive argument fragments merge into one payload delta.
        assert payload.index == 0
        assert payload.function.arguments == '{"query": "vllm"}'

        assert delta_message.content == "."
        assert tools_streamed is True

    @patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
    def test_tool_call_index_consistency_with_ongoing_call(self, mock_make_id):
        """
        Test that an ongoing tool call continuation and subsequent new calls
        maintain correct indexing when interleaved with content.
        """
        # Only the two NEW calls (tool_b, tool_c) allocate fresh ids;
        # tool_a is a continuation of prev_recipient and gets none.
        mock_make_id.side_effect = ["id_b", "id_c"]

        # One completed tool message shifts all indices up by one.
        messages = [
            MockMessage(channel="commentary", recipient="functions.previous_tool")
        ]
        parser = MockStreamableParser(messages=messages)

        token_states = [
            TokenState("commentary", "functions.tool_a", '{"key_a": "val_a"}'),
            TokenState("final", None, "Thinking..."),
            TokenState("commentary", "functions.tool_b", '{"key_b": "val_b"}'),
            TokenState("final", None, " Thinking again..."),
            TokenState("commentary", "functions.tool_c", '{"key_c": "val_c"}'),
        ]

        delta_message, _ = extract_harmony_streaming_delta(
            harmony_parser=parser,
            token_states=token_states,
            prev_recipient="functions.tool_a",
            include_reasoning=False,
        )

        assert delta_message is not None

        # tool_a continues the ongoing call at index 1 (after previous_tool).
        tool_a_deltas = [t for t in delta_message.tool_calls if t.index == 1]
        assert len(tool_a_deltas) > 0
        assert tool_a_deltas[0].id is None
        assert tool_a_deltas[0].function.arguments == '{"key_a": "val_a"}'

        # tool_b starts fresh at index 2: header (with id) + argument delta.
        tool_b_header = next(t for t in delta_message.tool_calls if t.id == "id_b")
        assert tool_b_header.index == 2
        tool_b_args = next(
            t for t in delta_message.tool_calls if t.index == 2 and t.id is None
        )
        assert tool_b_args.function.arguments == '{"key_b": "val_b"}'

        # tool_c starts fresh at index 3: header + argument delta.
        tool_c_start = next(t for t in delta_message.tool_calls if t.id == "id_c")
        assert tool_c_start.index == 3
        tool_c_args = next(
            t for t in delta_message.tool_calls if t.index == 3 and t.id is None
        )
        assert tool_c_args.function.arguments == '{"key_c": "val_c"}'

        # Interleaved "final" fragments concatenate into one content string.
        assert delta_message.content == "Thinking... Thinking again..."
||||
Reference in New Issue
Block a user