Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00 (committed by GitHub)
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
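The pattern repeated throughout the hunks below is mechanical: yapf's continuation-alignment call style gives way to ruff's Black-compatible layout (hanging indent plus magic trailing comma, one argument per line), and isort's packed parenthesized imports are reflowed to one name per line. A minimal before/after sketch of both patterns, lifted from the first hunks in this diff — shown out of package context, so the relative import is illustrative only:

# Before: yapf aligns continuation lines under the opening parenthesis;
# isort packs several names per wrapped import line.
from .utils import (MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL, ServerConfig,
                    ensure_system_prompt)


def pytest_addoption(parser):
    parser.addoption("--extended",
                     action="store_true",
                     default=False,
                     help="invoke extended tests requiring large GPUs")


# After: ruff format (Black-compatible) uses a hanging indent, and the
# trailing comma forces one argument / one imported name per line.
from .utils import (
    MESSAGES_WITHOUT_TOOLS,
    WEATHER_TOOL,
    ServerConfig,
    ensure_system_prompt,
)


def pytest_addoption(parser):
    parser.addoption(
        "--extended",
        action="store_true",
        default=False,
        help="invoke extended tests requiring large GPUs",
    )

Presumably the same commit also swaps the repository's yapf and isort configuration for a [tool.ruff] section in pyproject.toml; that file's diff is not included in this excerpt.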


@@ -13,13 +13,13 @@ from .utils import ARGS, CONFIGS, ServerConfig
# select models to test based on command line arguments
def pytest_addoption(parser):
parser.addoption("--models",
nargs="+",
help="Specify one or more models to test")
parser.addoption("--extended",
action="store_true",
default=False,
help="invoke extended tests requiring large GPUs")
parser.addoption("--models", nargs="+", help="Specify one or more models to test")
parser.addoption(
"--extended",
action="store_true",
default=False,
help="invoke extended tests requiring large GPUs",
)
# for each server config, download the model and return the config
@@ -29,8 +29,10 @@ def server_config(request):
models = request.config.getoption("--models")
config_keys_to_test = [
key for key in CONFIGS if (models is None or key in models) and (
extended or not CONFIGS[key].get("extended", False))
key
for key in CONFIGS
if (models is None or key in models)
and (extended or not CONFIGS[key].get("extended", False))
]
config_key = request.param
@@ -40,8 +42,9 @@ def server_config(request):
config = CONFIGS[config_key]
if current_platform.is_rocm() and not config.get("supports_rocm", True):
pytest.skip("The {} model can't be tested on the ROCm platform".format(
config["model"]))
pytest.skip(
"The {} model can't be tested on the ROCm platform".format(config["model"])
)
# download model and tokenizer using transformers
snapshot_download(config["model"])
@@ -53,8 +56,9 @@ def server_config(request):
def server(request, server_config: ServerConfig):
model = server_config["model"]
args_for_model = server_config["arguments"]
with RemoteOpenAIServer(model, ARGS + args_for_model,
max_wait_seconds=480) as server:
with RemoteOpenAIServer(
model, ARGS + args_for_model, max_wait_seconds=480
) as server:
yield server


@@ -17,8 +17,9 @@ def server_config(request):
config = CONFIGS[request.param]
if current_platform.is_rocm() and not config.get("supports_rocm", True):
pytest.skip("The {} model can't be tested on the ROCm platform".format(
config["model"]))
pytest.skip(
"The {} model can't be tested on the ROCm platform".format(config["model"])
)
# download model and tokenizer using transformers
snapshot_download(config["model"])
@@ -30,8 +31,9 @@ def server_config(request):
def server(request, server_config: ServerConfig):
model = server_config["model"]
args_for_model = server_config["arguments"]
with RemoteOpenAIServer(model, ARGS + args_for_model,
max_wait_seconds=480) as server:
with RemoteOpenAIServer(
model, ARGS + args_for_model, max_wait_seconds=480
) as server:
yield server


@@ -19,12 +19,12 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
model=model_name,
tools=[WEATHER_TOOL],
tool_choice=WEATHER_TOOL,
logprobs=False)
logprobs=False,
)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls" # "stop" or "length"
assert choice.message.role == "assistant"
assert choice.message.tool_calls is None \
or len(choice.message.tool_calls) == 1
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1
assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral


@@ -18,17 +18,16 @@ ARGS: list[str] = ["--max-model-len", "1024"]
CONFIGS: dict[str, ServerConfig] = {
"mistral": {
"model":
"mistralai/Mistral-7B-Instruct-v0.3",
"model": "mistralai/Mistral-7B-Instruct-v0.3",
"arguments": [
"--tokenizer-mode", "mistral",
"--ignore-patterns=\"consolidated.safetensors\""
"--tokenizer-mode",
"mistral",
'--ignore-patterns="consolidated.safetensors"',
],
"system_prompt":
"You are a helpful assistant with access to tools. If a tool"
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
"to the user's question - just respond to it normally.",
},
}


@@ -8,68 +8,56 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest
def test_chat_completion_request_with_no_tools():
# tools key is not present
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
'facebook/opt-125m',
})
assert request.tool_choice == 'none'
request = ChatCompletionRequest.model_validate(
{
"messages": [{"role": "user", "content": "Hello"}],
"model": "facebook/opt-125m",
}
)
assert request.tool_choice == "none"
# tools key is None
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
'facebook/opt-125m',
'tools':
None
})
assert request.tool_choice == 'none'
request = ChatCompletionRequest.model_validate(
{
"messages": [{"role": "user", "content": "Hello"}],
"model": "facebook/opt-125m",
"tools": None,
}
)
assert request.tool_choice == "none"
# tools key present but empty
request = ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
'facebook/opt-125m',
'tools': []
})
assert request.tool_choice == 'none'
request = ChatCompletionRequest.model_validate(
{
"messages": [{"role": "user", "content": "Hello"}],
"model": "facebook/opt-125m",
"tools": [],
}
)
assert request.tool_choice == "none"
@pytest.mark.parametrize('tool_choice', ['auto', 'required'])
@pytest.mark.parametrize("tool_choice", ["auto", "required"])
def test_chat_completion_request_with_tool_choice_but_no_tools(tool_choice):
with pytest.raises(ValueError,
match="When using `tool_choice`, `tools` must be set."):
ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
'facebook/opt-125m',
'tool_choice':
tool_choice
})
with pytest.raises(
ValueError, match="When using `tool_choice`, `tools` must be set."
):
ChatCompletionRequest.model_validate(
{
"messages": [{"role": "user", "content": "Hello"}],
"model": "facebook/opt-125m",
"tool_choice": tool_choice,
}
)
with pytest.raises(ValueError,
match="When using `tool_choice`, `tools` must be set."):
ChatCompletionRequest.model_validate({
'messages': [{
'role': 'user',
'content': 'Hello'
}],
'model':
'facebook/opt-125m',
'tool_choice':
tool_choice,
'tools':
None
})
with pytest.raises(
ValueError, match="When using `tool_choice`, `tools` must be set."
):
ChatCompletionRequest.model_validate(
{
"messages": [{"role": "user", "content": "Hello"}],
"model": "facebook/opt-125m",
"tool_choice": tool_choice,
"tools": None,
}
)


@@ -4,16 +4,21 @@
import openai
import pytest
from .utils import (MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL, ServerConfig,
ensure_system_prompt)
from .utils import (
MESSAGES_WITHOUT_TOOLS,
WEATHER_TOOL,
ServerConfig,
ensure_system_prompt,
)
# test: make sure chat completions without tools provided work even when tools
# are enabled. This makes sure tool call chat templates work, AND that the tool
# parser stream processing doesn't change the output of the model.
@pytest.mark.asyncio
async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
server_config: ServerConfig):
async def test_chat_completion_without_tools(
client: openai.AsyncOpenAI, server_config: ServerConfig
):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
@@ -21,7 +26,8 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
temperature=0,
max_completion_tokens=150,
model=model_name,
logprobs=False)
logprobs=False,
)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
output_text = chat_completion.choices[0].message.content
@@ -32,8 +38,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
assert stop_reason != "tool_calls"
# check to make sure no tool calls were returned
assert (choice.message.tool_calls is None
or len(choice.message.tool_calls) == 0)
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
# make the same request, streaming
stream = await client.chat.completions.create(
@@ -55,7 +60,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
# make sure the role is assistant
if delta.role:
assert not role_sent
assert delta.role == 'assistant'
assert delta.role == "assistant"
role_sent = True
if delta.content:
@@ -80,8 +85,9 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
# tools, to make sure we can still get normal chat completion responses
# and that they won't be parsed as tools
@pytest.mark.asyncio
async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
server_config: ServerConfig):
async def test_chat_completion_with_tools(
client: openai.AsyncOpenAI, server_config: ServerConfig
):
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
@@ -90,19 +96,19 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
max_completion_tokens=150,
model=model_name,
tools=[WEATHER_TOOL],
logprobs=False)
logprobs=False,
)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
output_text = chat_completion.choices[0].message.content
# check to make sure we got text
assert output_text is not None
assert stop_reason != 'tool_calls'
assert stop_reason != "tool_calls"
assert len(output_text) > 0
# check to make sure no tool calls were returned
assert (choice.message.tool_calls is None
or len(choice.message.tool_calls) == 0)
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
# make the same request, streaming
stream = await client.chat.completions.create(
@@ -125,7 +131,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
# make sure the role is assistant
if delta.role:
assert delta.role == 'assistant'
assert delta.role == "assistant"
role_sent = True
if delta.content:
@@ -142,6 +148,6 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
assert role_sent
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == stop_reason
assert chunk.choices[0].finish_reason != 'tool_calls'
assert chunk.choices[0].finish_reason != "tool_calls"
assert len(chunks)
assert "".join(chunks) == output_text


@@ -21,23 +21,28 @@ def parser(deepseekv31_tokenizer):
def test_extract_tool_calls_with_tool(parser):
model_output = (
"normal text" + "<tool▁calls▁begin>" +
"<tool▁call▁begin>foo<tool▁sep>{\"x\":1}<tool▁call▁end>" +
"<tool▁calls▁end>")
"normal text"
+ "<tool▁calls▁begin>"
+ '<tool▁call▁begin>foo<tool▁sep>{"x":1}<tool▁call▁end>'
+ "<tool▁calls▁end>"
)
result = parser.extract_tool_calls(model_output, None)
assert result.tools_called
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "foo"
assert result.tool_calls[0].function.arguments == "{\"x\":1}"
assert result.tool_calls[0].function.arguments == '{"x":1}'
assert result.content == "normal text"
def test_extract_tool_calls_with_multiple_tools(parser):
model_output = (
"some prefix text" + "<tool▁calls▁begin>" +
"<tool▁call▁begin>foo<tool▁sep>{\"x\":1}<tool▁call▁end>" +
"<tool▁call▁begin>bar<tool▁sep>{\"y\":2}<tool▁call▁end>" +
"<tool▁calls▁end>" + " some suffix text")
"some prefix text"
+ "<tool▁calls▁begin>"
+ '<tool▁call▁begin>foo<tool▁sep>{"x":1}<tool▁call▁end>'
+ '<tool▁call▁begin>bar<tool▁sep>{"y":2}<tool▁call▁end>'
+ "<tool▁calls▁end>"
+ " some suffix text"
)
result = parser.extract_tool_calls(model_output, None)
@@ -45,10 +50,10 @@ def test_extract_tool_calls_with_multiple_tools(parser):
assert len(result.tool_calls) == 2
assert result.tool_calls[0].function.name == "foo"
assert result.tool_calls[0].function.arguments == "{\"x\":1}"
assert result.tool_calls[0].function.arguments == '{"x":1}'
assert result.tool_calls[1].function.name == "bar"
assert result.tool_calls[1].function.arguments == "{\"y\":2}"
assert result.tool_calls[1].function.arguments == '{"y":2}'
# prefix is content
assert result.content == "some prefix text"


@@ -27,12 +27,14 @@ def glm4_moe_tool_parser(glm4_moe_tokenizer):
return Glm4MoeModelToolParser(glm4_moe_tokenizer)
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
def assert_tool_calls(
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
for actual_tool_call, expected_tool_call in zip(
actual_tool_calls, expected_tool_calls
):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 0
@@ -47,7 +49,8 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
model_output = "This is a test"
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@@ -73,14 +76,18 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
<arg_value>fahrenheit</arg_value>
</tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
None,
),
@@ -102,22 +109,30 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
<arg_value>fahrenheit</arg_value>
</tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
)),
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}),
)),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}
),
)
),
],
None,
),
@@ -131,14 +146,18 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
<arg_value>celsius</arg_value>
</tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Seattle",
"state": "WA",
"unit": "celsius",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Seattle",
"state": "WA",
"unit": "celsius",
}
),
)
)
],
"I'll help you check the weather.",
),
@@ -152,37 +171,51 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
<arg_value>celsius</arg_value>
</tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "New York",
"state": "NY",
"unit": "celsius",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "New York",
"state": "NY",
"unit": "celsius",
}
),
)
)
],
None,
),
("""I will help you get the weather.<tool_call>get_weather
(
"""I will help you get the weather.<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>""", [
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"city": "Beijing",
"date": "2025-08-01",
}),
))
], "I will help you get the weather."),
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"city": "Beijing",
"date": "2025-08-01",
}
),
)
)
],
"I will help you get the weather.",
),
],
)
def test_extract_tool_calls(glm4_moe_tool_parser, model_output,
expected_tool_calls, expected_content):
def test_extract_tool_calls(
glm4_moe_tool_parser, model_output, expected_tool_calls, expected_content
):
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
@@ -202,7 +235,8 @@ I will help you get the weather.
</tool_call>"""
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert len(extracted_tool_calls.tool_calls) == 1
@@ -224,7 +258,8 @@ def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser):
</tool_call>"""
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
# Should handle malformed XML gracefully
# The parser should either extract what it can or return no tool calls
@@ -239,12 +274,12 @@ def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser):
</tool_call>"""
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert len(extracted_tool_calls.tool_calls) == 1
assert extracted_tool_calls.tool_calls[
0].function.name == "get_current_time"
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_time"
# Empty arguments should result in empty JSON object
assert extracted_tool_calls.tool_calls[0].function.arguments == "{}"
@@ -270,7 +305,8 @@ meaningwhile, I will also check the weather in Shanghai.
</tool_call>"""
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert len(extracted_tool_calls.tool_calls) == 2
@@ -321,8 +357,7 @@ def test_streaming_basic_functionality(glm4_moe_tool_parser):
# The result behavior depends on the streaming state
# This test mainly ensures no exceptions are thrown
assert result is None or hasattr(result, 'tool_calls') or hasattr(
result, 'content')
assert result is None or hasattr(result, "tool_calls") or hasattr(result, "content")
def test_streaming_no_tool_calls(glm4_moe_tool_parser):
@@ -341,7 +376,7 @@ def test_streaming_no_tool_calls(glm4_moe_tool_parser):
# Should return the delta text as content
assert result is not None
assert hasattr(result, 'content')
assert hasattr(result, "content")
assert result.content == " without any tool calls."
@@ -367,7 +402,7 @@ def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser):
# Should return content when no tool call tokens are detected
assert result is not None
assert hasattr(result, 'content')
assert hasattr(result, "content")
assert result.content == "get the weather.<tool_call>"
@@ -383,7 +418,8 @@ def test_extract_tool_calls_special_characters(glm4_moe_tool_parser):
</tool_call>"""
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert len(extracted_tool_calls.tool_calls) == 1
@@ -404,7 +440,8 @@ def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
<arg_value>2025-08-01</arg_value>"""
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
# Incomplete tool calls should not be extracted
assert not extracted_tool_calls.tools_called


@@ -9,8 +9,7 @@ import partial_json_parser
import pytest
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import JambaToolParser
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
@@ -30,12 +29,14 @@ def jamba_tool_parser(jamba_tokenizer):
return JambaToolParser(jamba_tokenizer)
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
def assert_tool_calls(
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
for actual_tool_call, expected_tool_call in zip(
actual_tool_calls, expected_tool_calls
):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 16
@@ -44,10 +45,9 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
def stream_delta_message_generator(
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer,
model_output: str) -> Generator[DeltaMessage, None, None]:
all_token_ids = jamba_tokenizer.encode(model_output,
add_special_tokens=False)
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, model_output: str
) -> Generator[DeltaMessage, None, None]:
all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False)
previous_text = ""
previous_tokens = None
@@ -56,18 +56,19 @@ def stream_delta_message_generator(
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]
current_token_ids = all_token_ids[: i + 1]
(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=jamba_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
detokenize_incrementally(
tokenizer=jamba_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
)
current_text = previous_text + delta_text
@@ -84,8 +85,9 @@ def stream_delta_message_generator(
yield delta_message
previous_text = current_text
previous_tokens = previous_tokens + new_tokens if previous_tokens\
else new_tokens
previous_tokens = (
previous_tokens + new_tokens if previous_tokens else new_tokens
)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
@@ -93,7 +95,8 @@ def stream_delta_message_generator(
def test_extract_tool_calls_no_tools(jamba_tool_parser):
model_output = "This is a test"
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@@ -108,54 +111,63 @@ def test_extract_tool_calls_no_tools(jamba_tool_parser):
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
""" <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
)
],
None),
None,
),
(
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
""" Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
)
],
" Sure! let me call the tool for you."),
" Sure! let me call the tool for you.",
),
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
""" <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
}))),
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit"
})))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
),
)
),
],
None)
None,
),
],
)
def test_extract_tool_calls(jamba_tool_parser, model_output,
expected_tool_calls, expected_content):
def test_extract_tool_calls(
jamba_tool_parser, model_output, expected_tool_calls, expected_content
):
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
@@ -172,63 +184,75 @@ def test_extract_tool_calls(jamba_tool_parser, model_output,
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
('''This is a test''', [], '''This is a test'''),
("""This is a test""", [], """This is a test"""),
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
""" <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
)
],
" "),
" ",
),
(
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
""" Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
)
],
" Sure! let me call the tool for you."),
" Sure! let me call the tool for you.",
),
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
""" <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
}))),
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit"
})))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
),
)
),
],
" ")
" ",
),
],
)
def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
model_output, expected_tool_calls,
expected_content):
other_content: str = ''
def test_extract_tool_calls_streaming(
jamba_tool_parser,
jamba_tokenizer,
model_output,
expected_tool_calls,
expected_content,
):
other_content: str = ""
function_names: list[str] = []
function_args_strs: list[str] = []
tool_call_idx: int = -1
tool_call_ids: list[Optional[str]] = []
for delta_message in stream_delta_message_generator(
jamba_tool_parser, jamba_tokenizer, model_output):
jamba_tool_parser, jamba_tokenizer, model_output
):
# role should never be streamed from tool parser
assert not delta_message.role
@@ -264,18 +288,22 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
# make sure they're a string and then add them to the list
assert isinstance(tool_call.function.arguments, str)
function_args_strs[
tool_call.index] += tool_call.function.arguments
function_args_strs[tool_call.index] += tool_call.function.arguments
assert other_content == expected_content
actual_tool_calls = [
ToolCall(id=tool_call_id,
function=FunctionCall(
name=function_name,
arguments=partial_json_parser.ensure_json(
function_args_str, Allow.OBJ | Allow.STR)))
ToolCall(
id=tool_call_id,
function=FunctionCall(
name=function_name,
arguments=partial_json_parser.ensure_json(
function_args_str, Allow.OBJ | Allow.STR
),
),
)
for tool_call_id, function_name, function_args_str in zip(
tool_call_ids, function_names, function_args_strs)
tool_call_ids, function_names, function_args_strs
)
]
assert_tool_calls(actual_tool_calls, expected_tool_calls)


@@ -26,27 +26,31 @@ def kimi_k2_tool_parser(kimi_k2_tokenizer):
return KimiK2ToolParser(kimi_k2_tokenizer)
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
def assert_tool_calls(
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
for actual_tool_call, expected_tool_call in zip(
actual_tool_calls, expected_tool_calls
):
assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function
# assert tool call id format
assert actual_tool_call.id.startswith("functions.")
assert actual_tool_call.id.split(':')[-1].isdigit()
assert actual_tool_call.id.split('.')[1].split(
':')[0] == expected_tool_call.function.name
assert actual_tool_call.id.split(":")[-1].isdigit()
assert (
actual_tool_call.id.split(".")[1].split(":")[0]
== expected_tool_call.function.name
)
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
model_output = "This is a test"
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@@ -63,14 +67,18 @@ def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
"""I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
[
ToolCall(id='functions.get_weather:0',
function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"city": "Beijing",
}, ),
),
type='function')
ToolCall(
id="functions.get_weather:0",
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"city": "Beijing",
},
),
),
type="function",
)
],
"I'll help you check the weather. ",
),
@@ -79,31 +87,41 @@ functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""",
[
ToolCall(id='functions.get_weather:0',
function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"city": "Beijing",
}, ),
),
type='function'),
ToolCall(id='functions.get_weather:1',
function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"city": "Shanghai",
}, ),
),
type='function')
ToolCall(
id="functions.get_weather:0",
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"city": "Beijing",
},
),
),
type="function",
),
ToolCall(
id="functions.get_weather:1",
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"city": "Shanghai",
},
),
),
type="function",
),
],
"I'll help you check the weather. ",
),
],
)
def test_extract_tool_calls(kimi_k2_tool_parser, model_output,
expected_tool_calls, expected_content):
def test_extract_tool_calls(
kimi_k2_tool_parser, model_output, expected_tool_calls, expected_content
):
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
@@ -118,15 +136,14 @@ functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
# Should extract only the valid JSON tool calls
assert len(extracted_tool_calls.tool_calls) == 2
assert extracted_tool_calls.tool_calls[
0].function.name == "invalid_get_weather"
assert extracted_tool_calls.tool_calls[
1].function.name == "valid_get_weather"
assert extracted_tool_calls.tool_calls[0].function.name == "invalid_get_weather"
assert extracted_tool_calls.tool_calls[1].function.name == "valid_get_weather"
def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser):
@@ -136,13 +153,13 @@ functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"}
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
# Should extract only the valid JSON tool calls
assert len(extracted_tool_calls.tool_calls) == 1
assert extracted_tool_calls.tool_calls[
0].function.name == "valid_get_weather"
assert extracted_tool_calls.tool_calls[0].function.name == "valid_get_weather"
def test_streaming_basic_functionality(kimi_k2_tool_parser):
@@ -170,8 +187,7 @@ functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_
# The result might be None or contain tool call information
# This depends on the internal state management
if result is not None and hasattr(result,
'tool_calls') and result.tool_calls:
if result is not None and hasattr(result, "tool_calls") and result.tool_calls:
assert len(result.tool_calls) >= 0
@@ -191,5 +207,5 @@ def test_streaming_no_tool_calls(kimi_k2_tool_parser):
# Should return the delta text as content
assert result is not None
assert hasattr(result, 'content')
assert hasattr(result, "content")
assert result.content == " without any tool calls."

File diff suppressed because it is too large.


@@ -4,9 +4,15 @@
import json
import pytest
from openai_harmony import (Conversation, DeveloperContent,
HarmonyEncodingName, Message, Role, SystemContent,
load_harmony_encoding)
from openai_harmony import (
Conversation,
DeveloperContent,
HarmonyEncodingName,
Message,
Role,
SystemContent,
load_harmony_encoding,
)
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser
@@ -37,8 +43,9 @@ def assert_tool_calls(
):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
for actual_tool_call, expected_tool_call in zip(
actual_tool_calls, expected_tool_calls
):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 16 # Default from protocol.py
assert actual_tool_call.type == "function"
@@ -46,20 +53,25 @@ def assert_tool_calls(
def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
convo = Conversation.from_messages([
Message.from_role_and_content(
Role.SYSTEM,
SystemContent.new(),
),
Message.from_role_and_content(
Role.DEVELOPER,
DeveloperContent.new().with_instructions("Talk like a pirate!")),
Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
Message.from_role_and_content(Role.ASSISTANT,
"This is a test").with_channel("final")
])
convo = Conversation.from_messages(
[
Message.from_role_and_content(
Role.SYSTEM,
SystemContent.new(),
),
Message.from_role_and_content(
Role.DEVELOPER,
DeveloperContent.new().with_instructions("Talk like a pirate!"),
),
Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
Message.from_role_and_content(
Role.ASSISTANT, "This is a test"
).with_channel("final"),
]
)
token_ids = harmony_encoding.render_conversation_for_completion(
convo, Role.ASSISTANT)
convo, Role.ASSISTANT
)
extracted_info = openai_tool_parser.extract_tool_calls(
"",
request=None,
@@ -70,26 +82,32 @@ def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
assert extracted_info.content == "This is a test"
@pytest.mark.parametrize("tool_args", [
'{"location": "Tokyo"}',
'{\n"location": "Tokyo"\n}',
])
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding,
tool_args):
convo = Conversation.from_messages([
Message.from_role_and_content(Role.USER,
"What is the weather in Tokyo?"),
Message.from_role_and_content(
Role.ASSISTANT,
'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT,
tool_args).with_channel("commentary").with_recipient(
"functions.get_current_weather").with_content_type("json"),
])
@pytest.mark.parametrize(
"tool_args",
[
'{"location": "Tokyo"}',
'{\n"location": "Tokyo"\n}',
],
)
def test_extract_tool_calls_single_tool(
openai_tool_parser, harmony_encoding, tool_args
):
convo = Conversation.from_messages(
[
Message.from_role_and_content(Role.USER, "What is the weather in Tokyo?"),
Message.from_role_and_content(
Role.ASSISTANT,
'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501
).with_channel("analysis"),
Message.from_role_and_content(Role.ASSISTANT, tool_args)
.with_channel("commentary")
.with_recipient("functions.get_current_weather")
.with_content_type("json"),
]
)
token_ids = harmony_encoding.render_conversation_for_completion(
convo, Role.ASSISTANT)
convo, Role.ASSISTANT
)
extracted_info = openai_tool_parser.extract_tool_calls(
"",
@@ -98,10 +116,12 @@ def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding,
)
assert extracted_info.tools_called
expected_tool_calls = [
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({"location": "Tokyo"}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({"location": "Tokyo"}),
)
)
]
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
assert extracted_info.content is None
@@ -111,33 +131,39 @@ def test_extract_tool_calls_multiple_tools(
openai_tool_parser,
harmony_encoding,
):
convo = Conversation.from_messages([
Message.from_role_and_content(
Role.USER, "What is the weather in Tokyo based on where I'm at?"),
Message.from_role_and_content(
Role.ASSISTANT,
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT,
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
"functions.get_current_weather").with_content_type("json"),
Message.from_role_and_content(
Role.ASSISTANT,
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
"functions.get_user_location").with_content_type("json"),
Message.from_role_and_content(
Role.ASSISTANT, '{"location": "Tokyo"}').with_channel(
"commentary").with_recipient("functions.no_content_type"),
Message.from_role_and_content(Role.ASSISTANT, "foo").with_channel(
"commentary").with_recipient("functions.not_json_no_content_type"),
Message.from_role_and_content(
Role.ASSISTANT, '{}').with_channel("commentary").with_recipient(
"functions.empty_args").with_content_type("json"),
Message.from_role_and_content(
Role.ASSISTANT, '').with_channel("commentary").with_recipient(
"functions.no_args").with_content_type("json"),
])
convo = Conversation.from_messages(
[
Message.from_role_and_content(
Role.USER, "What is the weather in Tokyo based on where I'm at?"
),
Message.from_role_and_content(
Role.ASSISTANT,
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
).with_channel("analysis"),
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
.with_channel("commentary")
.with_recipient("functions.get_current_weather")
.with_content_type("json"),
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
.with_channel("commentary")
.with_recipient("functions.get_user_location")
.with_content_type("json"),
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
.with_channel("commentary")
.with_recipient("functions.no_content_type"),
Message.from_role_and_content(Role.ASSISTANT, "foo")
.with_channel("commentary")
.with_recipient("functions.not_json_no_content_type"),
Message.from_role_and_content(Role.ASSISTANT, "{}")
.with_channel("commentary")
.with_recipient("functions.empty_args")
.with_content_type("json"),
Message.from_role_and_content(Role.ASSISTANT, "")
.with_channel("commentary")
.with_recipient("functions.no_args")
.with_content_type("json"),
]
)
token_ids = harmony_encoding.render_conversation_for_completion(
convo,
Role.ASSISTANT,
@@ -150,30 +176,42 @@ def test_extract_tool_calls_multiple_tools(
)
assert extracted_info.tools_called
expected_tool_calls = [
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({"location": "Tokyo"}),
)),
ToolCall(function=FunctionCall(
name="get_user_location",
arguments=json.dumps({"location": "Tokyo"}),
)),
ToolCall(function=FunctionCall(
name="no_content_type",
arguments=json.dumps({"location": "Tokyo"}),
)),
ToolCall(function=FunctionCall(
name="not_json_no_content_type",
arguments="foo",
)),
ToolCall(function=FunctionCall(
name="empty_args",
arguments=json.dumps({}),
)),
ToolCall(function=FunctionCall(
name="no_args",
arguments="",
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({"location": "Tokyo"}),
)
),
ToolCall(
function=FunctionCall(
name="get_user_location",
arguments=json.dumps({"location": "Tokyo"}),
)
),
ToolCall(
function=FunctionCall(
name="no_content_type",
arguments=json.dumps({"location": "Tokyo"}),
)
),
ToolCall(
function=FunctionCall(
name="not_json_no_content_type",
arguments="foo",
)
),
ToolCall(
function=FunctionCall(
name="empty_args",
arguments=json.dumps({}),
)
),
ToolCall(
function=FunctionCall(
name="no_args",
arguments="",
)
),
]
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
assert extracted_info.content is None
@@ -184,20 +222,24 @@ def test_extract_tool_calls_with_content(
harmony_encoding,
):
final_content = "This tool call will get the weather."
convo = Conversation.from_messages([
Message.from_role_and_content(
Role.USER, "What is the weather in Tokyo based on where I'm at?"),
Message.from_role_and_content(
Role.ASSISTANT,
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT,
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
"functions.get_current_weather").with_content_type("json"),
Message.from_role_and_content(Role.ASSISTANT,
final_content).with_channel("final"),
])
convo = Conversation.from_messages(
[
Message.from_role_and_content(
Role.USER, "What is the weather in Tokyo based on where I'm at?"
),
Message.from_role_and_content(
Role.ASSISTANT,
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
).with_channel("analysis"),
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
.with_channel("commentary")
.with_recipient("functions.get_current_weather")
.with_content_type("json"),
Message.from_role_and_content(Role.ASSISTANT, final_content).with_channel(
"final"
),
]
)
token_ids = harmony_encoding.render_conversation_for_completion(
convo,
Role.ASSISTANT,
@@ -210,10 +252,12 @@ def test_extract_tool_calls_with_content(
)
assert extracted_info.tools_called
expected_tool_calls = [
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({"location": "Tokyo"}),
)),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({"location": "Tokyo"}),
)
),
]
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
assert extracted_info.content == final_content


@@ -7,9 +7,13 @@ from typing import Optional
import openai
import pytest
from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
WEATHER_TOOL, ServerConfig)
from .utils import (
MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
SEARCH_TOOL,
WEATHER_TOOL,
ServerConfig,
)
# test: getting the model to generate parallel tool calls (streaming/not)
@@ -17,12 +21,15 @@ from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
# may be added in the future. e.g. llama 3.1 models are not designed to support
# parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
server_config: ServerConfig):
async def test_parallel_tool_calls(
client: openai.AsyncOpenAI, server_config: ServerConfig
):
if not server_config.get("supports_parallel", True):
pytest.skip("The {} model doesn't support parallel tool calls".format(
server_config["model"]))
pytest.skip(
"The {} model doesn't support parallel tool calls".format(
server_config["model"]
)
)
models = await client.models.list()
model_name: str = models.data[0].id
@@ -32,7 +39,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
max_completion_tokens=200,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
logprobs=False,
)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
@@ -69,7 +77,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
max_completion_tokens=200,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
stream=True,
)
role_name: Optional[str] = None
finish_reason_count: int = 0
@@ -80,24 +89,22 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
tool_call_id_count: int = 0
async for chunk in stream:
# if there's a finish reason make sure it's tools
if chunk.choices[0].finish_reason:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == 'tool_calls'
assert chunk.choices[0].finish_reason == "tool_calls"
# if a role is being streamed make sure it wasn't already set to
# something else
if chunk.choices[0].delta.role:
assert not role_name or role_name == 'assistant'
role_name = 'assistant'
assert not role_name or role_name == "assistant"
role_name = "assistant"
# if a tool call is streamed make sure there's exactly one
# (based on the request parameters
streamed_tool_calls = chunk.choices[0].delta.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
# make sure only one diff is present - correct even for parallel
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
@@ -110,8 +117,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
# if a tool call ID is streamed, make sure one hasn't been already
if tool_call.id:
tool_call_id_count += 1
assert (isinstance(tool_call.id, str)
and (len(tool_call.id) >= 9))
assert isinstance(tool_call.id, str) and (len(tool_call.id) >= 9)
# if parts of the function start being streamed
if tool_call.function:
@@ -125,32 +131,32 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
# make sure they're a string and then add them to the list
assert isinstance(tool_call.function.arguments, str)
tool_call_args[
tool_call.index] += tool_call.function.arguments
tool_call_args[tool_call.index] += tool_call.function.arguments
assert finish_reason_count == 1
assert role_name == 'assistant'
assert role_name == "assistant"
assert (len(non_streamed_tool_calls) == len(tool_call_names) ==
len(tool_call_args))
assert len(non_streamed_tool_calls) == len(tool_call_names) == len(tool_call_args)
for i in range(2):
assert non_streamed_tool_calls[i].function.name == tool_call_names[i]
streamed_args = json.loads(tool_call_args[i])
non_streamed_args = json.loads(
non_streamed_tool_calls[i].function.arguments)
non_streamed_args = json.loads(non_streamed_tool_calls[i].function.arguments)
assert streamed_args == non_streamed_args
# test: providing parallel tool calls back to the model to get a response
# (streaming/not)
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
server_config: ServerConfig):
async def test_parallel_tool_calls_with_results(
client: openai.AsyncOpenAI, server_config: ServerConfig
):
if not server_config.get("supports_parallel", True):
pytest.skip("The {} model doesn't support parallel tool calls".format(
server_config["model"]))
pytest.skip(
"The {} model doesn't support parallel tool calls".format(
server_config["model"]
)
)
models = await client.models.list()
model_name: str = models.data[0].id
@@ -160,14 +166,14 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
max_completion_tokens=200,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
logprobs=False,
)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls" # "stop" or "length"
assert choice.message.role == "assistant"
assert choice.message.tool_calls is None \
or len(choice.message.tool_calls) == 0
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
assert choice.message.content is not None
assert "98" in choice.message.content # Dallas temp in tool response
assert "78" in choice.message.content # Orlando temp in tool response
@@ -179,7 +185,8 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
stream=True,
)
chunks: list[str] = []
finish_reason_count = 0


@@ -7,14 +7,17 @@ from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaMessage,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
Qwen3CoderToolParser)
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import (
Qwen3XMLToolParser)
Qwen3CoderToolParser,
)
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
@@ -39,8 +42,7 @@ def qwen3_xml_tool_parser(qwen3_tokenizer):
@pytest.fixture(params=["original", "xml"])
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser,
request):
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request):
"""Parameterized fixture that provides both parser types for testing"""
if request.param == "original":
return qwen3_tool_parser
@@ -51,76 +53,63 @@ def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser,
@pytest.fixture
def sample_tools():
return [
ChatCompletionToolsParam(type="function",
function={
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
},
"state": {
"type": "string",
"description":
"The state code"
},
"unit": {
"type": "string",
"enum":
["fahrenheit", "celsius"]
}
},
"required": ["city", "state"]
}
}),
ChatCompletionToolsParam(type="function",
function={
"name": "calculate_area",
"description":
"Calculate area of a shape",
"parameters": {
"type": "object",
"properties": {
"shape": {
"type": "string"
},
"dimensions": {
"type": "object"
},
"precision": {
"type": "integer"
}
}
}
})
ChatCompletionToolsParam(
type="function",
function={
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "The city name"},
"state": {"type": "string", "description": "The state code"},
"unit": {"type": "string", "enum": ["fahrenheit", "celsius"]},
},
"required": ["city", "state"],
},
},
),
ChatCompletionToolsParam(
type="function",
function={
"name": "calculate_area",
"description": "Calculate area of a shape",
"parameters": {
"type": "object",
"properties": {
"shape": {"type": "string"},
"dimensions": {"type": "object"},
"precision": {"type": "integer"},
},
},
},
),
]
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
def assert_tool_calls(
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
for actual_tool_call, expected_tool_call in zip(
actual_tool_calls, expected_tool_calls
):
# Qwen3 parser doesn't generate IDs during extraction
assert actual_tool_call.type == "function"
assert (
actual_tool_call.function.name == expected_tool_call.function.name)
assert (json.loads(actual_tool_call.function.arguments) == json.loads(
expected_tool_call.function.arguments))
assert actual_tool_call.function.name == expected_tool_call.function.name
assert json.loads(actual_tool_call.function.arguments) == json.loads(
expected_tool_call.function.arguments
)
def stream_delta_message_generator(
qwen3_tool_parser,
qwen3_tokenizer: AnyTokenizer,
model_output: str,
request: Optional[ChatCompletionRequest] = None
request: Optional[ChatCompletionRequest] = None,
) -> Generator[DeltaMessage, None, None]:
all_token_ids = qwen3_tokenizer.encode(model_output,
add_special_tokens=False)
all_token_ids = qwen3_tokenizer.encode(model_output, add_special_tokens=False)
previous_text = ""
previous_tokens = None
@@ -129,18 +118,19 @@ def stream_delta_message_generator(
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]
current_token_ids = all_token_ids[: i + 1]
(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=qwen3_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
detokenize_incrementally(
tokenizer=qwen3_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
)
current_text = previous_text + delta_text
@@ -157,8 +147,9 @@ def stream_delta_message_generator(
yield delta_message
previous_text = current_text
previous_tokens = (previous_tokens +
new_tokens if previous_tokens else new_tokens)
previous_tokens = (
previous_tokens + new_tokens if previous_tokens else new_tokens
)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
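# A compact restatement of the loop above (an illustrative sketch, not part
# of the test file): detokenize_incrementally returns the newly materialized
# tokens, the decoded text delta, and updated offsets that must be threaded
# back into the next call.
def iter_text_deltas(tokenizer, token_ids):
    previous_tokens, prefix_offset, read_offset = None, 0, 0
    for i in range(len(token_ids)):
        new_tokens, delta_text, prefix_offset, read_offset = (
            detokenize_incrementally(
                tokenizer=tokenizer,
                all_input_ids=token_ids[: i + 1],
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            )
        )
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        yield delta_text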
@@ -166,7 +157,8 @@ def stream_delta_message_generator(
def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized):
model_output = "This is a test response without any tool calls"
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@@ -182,7 +174,8 @@ def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized):
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
('''<tool_call>
(
"""<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
@@ -194,16 +187,21 @@ TX
fahrenheit
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
], None),
('''Sure! Let me check the weather for you.<tool_call>
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
)
],
None,
),
(
"""Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
@@ -215,16 +213,21 @@ TX
fahrenheit
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
], "Sure! Let me check the weather for you."),
('''<tool_call>
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
)
],
"Sure! Let me check the weather for you.",
),
(
"""<tool_call>
<function=calculate_area>
<parameter=shape>
rectangle
@@ -237,18 +240,25 @@ rectangle
2
</parameter>
</function>
</tool_call>''', [
ToolCall(function=FunctionCall(name="calculate_area",
arguments=json.dumps({
"shape": "rectangle",
"dimensions": {
"width": 10,
"height": 20
},
"precision": 2
})))
], None),
('''<tool_call>
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="calculate_area",
arguments=json.dumps(
{
"shape": "rectangle",
"dimensions": {"width": 10, "height": 20},
"precision": 2,
}
),
)
)
],
None,
),
(
"""<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
@@ -273,23 +283,29 @@ FL
fahrenheit
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
}))),
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit"
})))
], None),
('''Let me calculate that area for you.<tool_call>
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
),
)
),
],
None,
),
(
"""Let me calculate that area for you.<tool_call>
<function=calculate_area>
<parameter=shape>
circle
@@ -301,26 +317,36 @@ circle
3
</parameter>
</function>
</tool_call>''', [
ToolCall(function=FunctionCall(name="calculate_area",
arguments=json.dumps({
"shape": "circle",
"dimensions": {
"radius": 15.5
},
"precision": 3
})))
], "Let me calculate that area for you."),
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="calculate_area",
arguments=json.dumps(
{
"shape": "circle",
"dimensions": {"radius": 15.5},
"precision": 3,
}
),
)
)
],
"Let me calculate that area for you.",
),
],
)
def test_extract_tool_calls(qwen3_tool_parser_parametrized, sample_tools,
model_output, expected_tool_calls,
expected_content):
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
def test_extract_tool_calls(
qwen3_tool_parser_parametrized,
sample_tools,
model_output,
expected_tool_calls,
expected_content,
):
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
model_output, request=request
)
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
@@ -328,60 +354,51 @@ def test_extract_tool_calls(qwen3_tool_parser_parametrized, sample_tools,
assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser_parametrized,
sample_tools):
def test_extract_tool_calls_fallback_no_tags(
qwen3_tool_parser_parametrized, sample_tools
):
"""Test fallback parsing when XML tags are missing"""
model_output = '''<function=get_current_weather>
model_output = """<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>'''
</function>"""
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
model_output, request=request
)
assert extracted_tool_calls.tools_called
assert len(extracted_tool_calls.tool_calls) == 1
assert (extracted_tool_calls.tool_calls[0].function.name ==
"get_current_weather")
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
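# As a hedged extension of the name check, the decoded arguments for the
# fallback path would look like this (assuming the same whitespace stripping
# as the tagged <tool_call> path exercised above):
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args == {"city": "Dallas", "state": "TX"}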
def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized):
"""Test parameter type conversion based on tool schema"""
tools = [
ChatCompletionToolsParam(type="function",
function={
"name": "test_types",
"parameters": {
"type": "object",
"properties": {
"int_param": {
"type": "integer"
},
"float_param": {
"type": "float"
},
"bool_param": {
"type": "boolean"
},
"str_param": {
"type": "string"
},
"obj_param": {
"type": "object"
}
}
}
})
ChatCompletionToolsParam(
type="function",
function={
"name": "test_types",
"parameters": {
"type": "object",
"properties": {
"int_param": {"type": "integer"},
"float_param": {"type": "float"},
"bool_param": {"type": "boolean"},
"str_param": {"type": "string"},
"obj_param": {"type": "object"},
},
},
},
)
]
model_output = '''<tool_call>
model_output = """<tool_call>
<function=test_types>
<parameter=int_param>
42
@@ -399,11 +416,12 @@ hello world
{"key": "value"}
</parameter>
</function>
</tool_call>'''
</tool_call>"""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
model_output, request=request
)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["int_param"] == 42
@@ -425,7 +443,8 @@ hello world
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("This is a test without tools", [], "This is a test without tools"),
('''<tool_call>
(
"""<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
@@ -437,16 +456,21 @@ TX
fahrenheit
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
], None),
('''Sure! Let me check the weather for you.<tool_call>
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
)
],
None,
),
(
"""Sure! Let me check the weather for you.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
@@ -458,16 +482,21 @@ TX
fahrenheit
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
], "Sure! Let me check the weather for you."),
('''<tool_call>
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
)
],
"Sure! Let me check the weather for you.",
),
(
"""<tool_call>
<function=calculate_area>
<parameter=shape>
rectangle
@@ -480,18 +509,25 @@ rectangle
2
</parameter>
</function>
</tool_call>''', [
ToolCall(function=FunctionCall(name="calculate_area",
arguments=json.dumps({
"shape": "rectangle",
"dimensions": {
"width": 10,
"height": 20
},
"precision": 2
})))
], None),
('''<tool_call>
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="calculate_area",
arguments=json.dumps(
{
"shape": "rectangle",
"dimensions": {"width": 10, "height": 20},
"precision": 2,
}
),
)
)
],
None,
),
(
"""<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
@@ -516,24 +552,30 @@ FL
celsius
</parameter>
</function>
</tool_call>''', [
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
}))),
ToolCall(
function=FunctionCall(name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "celsius"
})))
], None),
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
),
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "Orlando", "state": "FL", "unit": "celsius"}
),
)
),
],
None,
),
# Added tool_with_typed_params test case
('''Let me calculate that area for you.<tool_call>
(
"""Let me calculate that area for you.<tool_call>
<function=calculate_area>
<parameter=shape>
circle
@@ -545,33 +587,42 @@ circle
3
</parameter>
</function>
</tool_call>''', [
ToolCall(function=FunctionCall(name="calculate_area",
arguments=json.dumps({
"shape": "circle",
"dimensions": {
"radius": 15.5
},
"precision": 3
})))
], "Let me calculate that area for you."),
</tool_call>""",
[
ToolCall(
function=FunctionCall(
name="calculate_area",
arguments=json.dumps(
{
"shape": "circle",
"dimensions": {"radius": 15.5},
"precision": 3,
}
),
)
)
],
"Let me calculate that area for you.",
),
],
)
def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
qwen3_tokenizer, sample_tools,
model_output, expected_tool_calls,
expected_content):
def test_extract_tool_calls_streaming(
qwen3_tool_parser_parametrized,
qwen3_tokenizer,
sample_tools,
model_output,
expected_tool_calls,
expected_content,
):
"""Test incremental streaming behavior including typed parameters"""
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
other_content = ''
other_content = ""
tool_states = {} # Track state per tool index
for delta_message in stream_delta_message_generator(
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
request):
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
):
# role should never be streamed from tool parser
assert not delta_message.role
@@ -588,7 +639,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
"id": None,
"name": None,
"arguments": "",
"type": None
"type": None,
}
# First chunk should have id, name, and type
@@ -607,8 +658,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
if tool_call.function.arguments is not None:
# Accumulate arguments incrementally
tool_states[idx][
"arguments"] += tool_call.function.arguments
tool_states[idx]["arguments"] += tool_call.function.arguments
# Verify final content
assert other_content == (expected_content or "") # Handle None case
@@ -632,10 +682,11 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
def test_extract_tool_calls_missing_closing_parameter_tag(
qwen3_tool_parser_parametrized, sample_tools):
qwen3_tool_parser_parametrized, sample_tools
):
"""Test handling of missing closing </parameter> tag"""
# Using get_current_weather from sample_tools but with malformed XML
model_output = '''Let me check the weather for you:
model_output = """Let me check the weather for you:
<tool_call>
<function=get_current_weather>
<parameter=city>
@@ -647,21 +698,19 @@ TX
fahrenheit
</parameter>
</function>
</tool_call>'''
</tool_call>"""
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
model_output, request=request
)
# The parser should handle the malformed XML gracefully
assert extracted_tool_calls.tools_called
assert len(extracted_tool_calls.tool_calls) == 1
# Verify the function name is correct
assert extracted_tool_calls.tool_calls[
0].function.name == "get_current_weather"
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
# Verify the arguments are parsed despite the missing closing tag
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
@@ -675,10 +724,11 @@ fahrenheit
def test_extract_tool_calls_streaming_missing_closing_tag(
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools):
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
):
"""Test streaming with missing closing </parameter> tag"""
# Using get_current_weather from sample_tools but with malformed XML
model_output = '''Let me check the weather for you:
model_output = """Let me check the weather for you:
<tool_call>
<function=get_current_weather>
<parameter=city>
@@ -690,19 +740,16 @@ TX
fahrenheit
</parameter>
</function>
</tool_call>'''
</tool_call>"""
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
other_content = ''
other_content = ""
tool_states = {}
for delta_message in stream_delta_message_generator(
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
request):
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
):
if delta_message.content:
other_content += delta_message.content
@@ -715,7 +762,7 @@ fahrenheit
"id": None,
"name": None,
"arguments": "",
"type": None
"type": None,
}
if tool_call.id:
@@ -730,8 +777,7 @@ fahrenheit
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
tool_states[idx][
"arguments"] += tool_call.function.arguments
tool_states[idx]["arguments"] += tool_call.function.arguments
# Verify content was streamed
assert "Let me check the weather for you:" in other_content
@@ -752,9 +798,10 @@ fahrenheit
def test_extract_tool_calls_streaming_incremental(
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools):
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
):
"""Test that streaming is truly incremental"""
model_output = '''I'll check the weather.<tool_call>
model_output = """I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
@@ -763,16 +810,14 @@ Dallas
TX
</parameter>
</function>
</tool_call>'''
</tool_call>"""
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
chunks = []
for delta_message in stream_delta_message_generator(
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
request):
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
):
chunks.append(delta_message)
# Should have multiple chunks
@@ -787,7 +832,7 @@ TX
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].id:
header_found = True
assert (chunk.tool_calls[0].function.name == "get_current_weather")
assert chunk.tool_calls[0].function.name == "get_current_weather"
assert chunk.tool_calls[0].type == "function"
# Empty initially
assert chunk.tool_calls[0].function.arguments == ""
@@ -811,46 +856,40 @@ TX
def test_extract_tool_calls_complex_type_with_single_quote(
qwen3_tool_parser_parametrized):
qwen3_tool_parser_parametrized,
):
"""Test parameter type conversion based on tool schema"""
tools = [
ChatCompletionToolsParam(type="function",
function={
"name": "test_types",
"parameters": {
"type": "object",
"properties": {
"int_param": {
"type": "integer"
},
"float_param": {
"type": "float"
},
"bool_param": {
"type": "boolean"
},
"str_param": {
"type": "string"
},
"obj_param": {
"type": "object"
}
}
}
})
ChatCompletionToolsParam(
type="function",
function={
"name": "test_types",
"parameters": {
"type": "object",
"properties": {
"int_param": {"type": "integer"},
"float_param": {"type": "float"},
"bool_param": {"type": "boolean"},
"str_param": {"type": "string"},
"obj_param": {"type": "object"},
},
},
},
)
]
model_output = '''<tool_call>
model_output = """<tool_call>
<function=test_types>
<parameter=obj_param>
{'key': 'value'}
</parameter>
</function>
</tool_call>'''
</tool_call>"""
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
model_output, request=request)
model_output, request=request
)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["obj_param"] == {"key": "value"}

View File

@@ -8,10 +8,13 @@ from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaMessage,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
@@ -45,51 +48,56 @@ def sample_tools():
"properties": {
"location": {
"type": "string",
"description":
"City and country e.g. Bogotá, Colombia"
"description": "City and country e.g. Bogotá, Colombia",
},
"unit": {
"type": "string",
"description": "this is the unit of temperature"
}
"description": "this is the unit of temperature",
},
},
"required": ["location"],
"additionalProperties": False
"additionalProperties": False,
},
"returns": {
"type": "object",
"properties": {
"temperature": {
"type": "number",
"description": "temperature in celsius"
"description": "temperature in celsius",
}
},
"required": ["temperature"],
"additionalProperties": False
"additionalProperties": False,
},
"strict": True
}),
"strict": True,
},
),
]
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
def assert_tool_calls(
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
for actual_tool_call, expected_tool_call in zip(
actual_tool_calls, expected_tool_calls
):
# Seed-OSS tool calls do not generate an id
assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function
assert actual_tool_call.function.name == expected_tool_call.function.name
assert actual_tool_call.function.arguments == expected_tool_call.function.arguments
assert (
actual_tool_call.function.arguments == expected_tool_call.function.arguments
)
def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
model_output = "This is a test response without any tool calls"
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
@@ -104,17 +112,24 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("""<seed:tool_call>\n<function=get_weather>\n"""
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
], None),
(
"""<seed:tool_call>\n<function=get_weather>\n"""
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
},
),
),
type="function",
)
],
None,
),
(
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
@@ -131,13 +146,17 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
"""\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
ToolCall(
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
},
),
),
type="function",
)
],
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
@@ -169,15 +188,18 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
"unit": "celsius",
}, ),
),
type='function')
ToolCall(
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
"unit": "celsius",
},
),
),
type="function",
)
],
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
@@ -196,13 +218,17 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
),
],
)
def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output,
expected_tool_calls, expected_content):
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
def test_extract_tool_calls(
seed_oss_tool_parser,
sample_tools,
model_output,
expected_tool_calls,
expected_content,
):
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
model_output, request=request) # type: ignore[arg-type]
model_output, request=request
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
@@ -225,7 +251,7 @@ def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
# Should return the delta text as content
assert result is not None
assert hasattr(result, 'content')
assert hasattr(result, "content")
assert result.content == " without any tool calls."
@@ -233,10 +259,9 @@ def stream_delta_message_generator(
seed_oss_tool_parser: SeedOssToolParser,
seed_oss_tokenizer: AnyTokenizer,
model_output: str,
request: Optional[ChatCompletionRequest] = None
request: Optional[ChatCompletionRequest] = None,
) -> Generator[DeltaMessage, None, None]:
all_token_ids = seed_oss_tokenizer.encode(model_output,
add_special_tokens=False)
all_token_ids = seed_oss_tokenizer.encode(model_output, add_special_tokens=False)
previous_text = ""
previous_tokens = None
@@ -245,18 +270,19 @@ def stream_delta_message_generator(
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]
current_token_ids = all_token_ids[: i + 1]
(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=seed_oss_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
detokenize_incrementally(
tokenizer=seed_oss_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
)
current_text = previous_text + delta_text
@@ -273,8 +299,9 @@ def stream_delta_message_generator(
yield delta_message
previous_text = current_text
previous_tokens = (previous_tokens +
new_tokens if previous_tokens else new_tokens)
previous_tokens = (
previous_tokens + new_tokens if previous_tokens else new_tokens
)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
@@ -287,22 +314,27 @@ def stream_delta_message_generator(
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
"""<seed:tool_call>\n<function=get_weather>\n"""
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
],
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
),
(
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
"""<seed:tool_call>\n<function=get_weather>\n"""
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
},
),
),
type="function",
)
],
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""",
),
(
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
@@ -319,13 +351,17 @@ def stream_delta_message_generator(
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
"""\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
ToolCall(
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
},
),
),
type="function",
)
],
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
@@ -357,15 +393,18 @@ def stream_delta_message_generator(
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
"unit": "celsius",
}, ),
),
type='function')
ToolCall(
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
"unit": "celsius",
},
),
),
type="function",
)
],
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
@@ -384,19 +423,23 @@ def stream_delta_message_generator(
),
],
)
def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
sample_tools, model_output, expected_tool_calls,
expected_content):
def test_streaming_tool_calls(
seed_oss_tool_parser,
seed_oss_tokenizer,
sample_tools,
model_output,
expected_tool_calls,
expected_content,
):
"""Test incremental streaming behavior"""
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
other_content = ''
other_content = ""
tool_states = {} # Track state per tool index
for delta_message in stream_delta_message_generator(
seed_oss_tool_parser, seed_oss_tokenizer, model_output, request):
seed_oss_tool_parser, seed_oss_tokenizer, model_output, request
):
# role should never be streamed from tool parser
assert not delta_message.role
@@ -413,7 +456,7 @@ def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
"id": None,
"name": None,
"arguments": "",
"type": None
"type": None,
}
# First chunk should have id, name, and type
@@ -432,8 +475,7 @@ def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
if tool_call.function.arguments is not None:
# Accumulate arguments incrementally
tool_states[idx][
"arguments"] += tool_call.function.arguments
tool_states[idx]["arguments"] += tool_call.function.arguments
# Verify final content
assert other_content == expected_content
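# The per-index bookkeeping above generalizes to a small helper that rebuilds
# complete tool calls from streamed deltas (a sketch under the same
# assumptions the test makes: id/name/type arrive once, arguments accumulate):
def accumulate_tool_calls(delta_messages):
    states = {}
    for dm in delta_messages:
        for tc in dm.tool_calls or []:
            state = states.setdefault(
                tc.index, {"id": None, "name": None, "arguments": ""}
            )
            if tc.id:
                state["id"] = tc.id
            if tc.function and tc.function.name:
                state["name"] = tc.function.name
            if tc.function and tc.function.arguments:
                state["arguments"] += tc.function.arguments
    return [states[i] for i in sorted(states)]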

View File

@@ -7,8 +7,12 @@ from typing import Optional
import openai
import pytest
from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE,
SEARCH_TOOL, WEATHER_TOOL)
from .utils import (
MESSAGES_ASKING_FOR_TOOLS,
MESSAGES_WITH_TOOL_RESPONSE,
SEARCH_TOOL,
WEATHER_TOOL,
)
# test: request a chat completion that should return tool calls, so we know they
@@ -23,17 +27,18 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
logprobs=False,
)
choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason
tool_calls = chat_completion.choices[0].message.tool_calls
# make sure a tool call is present
assert choice.message.role == 'assistant'
assert choice.message.role == "assistant"
assert tool_calls is not None
assert len(tool_calls) == 1
assert tool_calls[0].type == 'function'
assert tool_calls[0].type == "function"
assert tool_calls[0].function is not None
assert isinstance(tool_calls[0].id, str)
assert len(tool_calls[0].id) >= 9
@@ -54,7 +59,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
assert stop_reason == "tool_calls"
function_name: Optional[str] = None
function_args_str: str = ''
function_args_str: str = ""
tool_call_id: Optional[str] = None
role_name: Optional[str] = None
finish_reason_count: int = 0
@@ -67,20 +72,21 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
max_completion_tokens=100,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
stream=True,
)
async for chunk in stream:
assert chunk.choices[0].index == 0
if chunk.choices[0].finish_reason:
finish_reason_count += 1
assert chunk.choices[0].finish_reason == 'tool_calls'
assert chunk.choices[0].finish_reason == "tool_calls"
# if a role is being streamed make sure it wasn't already set to
# something else
if chunk.choices[0].delta.role:
assert not role_name or role_name == 'assistant'
role_name = 'assistant'
assert not role_name or role_name == "assistant"
role_name = "assistant"
# if a tool call is streamed make sure there's exactly one
# (based on the request parameters)
@@ -108,7 +114,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
function_args_str += tool_call.function.arguments
assert finish_reason_count == 1
assert role_name == 'assistant'
assert role_name == "assistant"
assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9)
# validate the name and arguments
@@ -148,14 +154,14 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False)
logprobs=False,
)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls" # "stop" or "length"
assert choice.message.role == "assistant"
assert choice.message.tool_calls is None \
or len(choice.message.tool_calls) == 0
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
assert choice.message.content is not None
assert "98" in choice.message.content # the temperature from the response
@@ -166,7 +172,8 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
stream=True)
stream=True,
)
chunks: list[str] = []
finish_reason_count = 0

View File

@@ -8,8 +8,10 @@ import pytest
import regex as re
from pydantic import TypeAdapter
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
pytestmark = pytest.mark.cpu_test
@@ -24,18 +26,16 @@ EXAMPLE_TOOLS = [
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for"
"type": "string",
"description": "The city to find the weather for"
", e.g. 'San Francisco'",
},
},
"required": ["city"],
"additionalProperties": False
"additionalProperties": False,
},
},
"strict": True
"strict": True,
},
{
"type": "function",
@@ -46,35 +46,33 @@ EXAMPLE_TOOLS = [
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to get the forecast for, e.g. 'New York'",
"type": "string",
"description": "The city to get the forecast for, e.g. 'New York'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
"type": "integer",
"description": "Number of days to get the forecast for (1-7)",
},
},
"required": ["city", "days"],
"additionalProperties": False
"additionalProperties": False,
},
},
"strict": True
"strict": True,
},
]
def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
should_match: bool):
def _compile_and_check(
tools: list[ChatCompletionToolsParam], sample_output, should_match: bool
):
self = MagicMock(tool_choice="required", tools=tools)
schema = ChatCompletionRequest._get_json_schema_from_tool(self)
assert isinstance(schema, dict)
# use build_regex_from_schema (as JSONLogitsProcessor does to create its Guide)
from outlines_core.json_schema import build_regex_from_schema
regex = build_regex_from_schema(json.dumps(schema))
compiled = re.compile(regex)
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
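# A self-contained taste of the mechanism the helper above relies on:
# build_regex_from_schema compiles a JSON schema into a regular expression
# over valid JSON strings (toy schema for illustration; the tests below use
# the real tool schemas):
from outlines_core.json_schema import build_regex_from_schema

toy_schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
    "additionalProperties": False,
}
toy_regex = re.compile(build_regex_from_schema(json.dumps(toy_schema)))
assert toy_regex.fullmatch(json.dumps({"city": "Vienna"})) is not None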
@@ -83,65 +81,31 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
VALID_TOOL_OUTPUTS = [
([{
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}], True),
([{
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Berlin"
}
}], True),
([{
"name": "get_forecast",
"parameters": {
"city": "Vienna",
"days": 7
}
}], True),
([{
"name": "get_forecast",
"parameters": {
"city": "Vienna",
"days": 7
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}], True),
([{
"name": "get_forecast",
"parameters": {
"city": "Vienna",
"days": 7
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}, {
"name": "get_forecast",
"parameters": {
"city": "Berlin",
"days": 7
}
}, {
"name": "get_current_weather",
"parameters": {
"city": "Berlin"
}
}], True),
([{"name": "get_current_weather", "parameters": {"city": "Vienna"}}], True),
(
[
{"name": "get_current_weather", "parameters": {"city": "Vienna"}},
{"name": "get_current_weather", "parameters": {"city": "Berlin"}},
],
True,
),
([{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}], True),
(
[
{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
{"name": "get_current_weather", "parameters": {"city": "Vienna"}},
],
True,
),
(
[
{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
{"name": "get_current_weather", "parameters": {"city": "Vienna"}},
{"name": "get_forecast", "parameters": {"city": "Berlin", "days": 7}},
{"name": "get_current_weather", "parameters": {"city": "Berlin"}},
],
True,
),
]
VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
@@ -149,92 +113,100 @@ VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
@pytest.mark.parametrize(
"sample_output, should_match",
VALID_TOOL_OUTPUTS + [
VALID_TOOL_OUTPUTS
+ [
(None, False),
([], False), # empty list cannot be generated
({}, False), # empty object cannot be generated
([{}], False), # list with empty object cannot be generated
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather"
}],
False),
[
{ # function without required parameters cannot be generated
"name": "get_current_weather"
}
],
False,
),
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": {}
}],
False),
[
{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": {},
}
],
False,
),
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": None
}],
False),
[
{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": None,
}
],
False,
),
(
{ # a tool call not wrapped in a list cannot be generated
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
"parameters": {"city": "Vienna"},
},
False),
False,
),
(
[{ # tool call with extra parameters cannot be generated
"name": "get_current_weather",
"parameters": {
"city": "Vienna",
"extra": "value"
[
{ # tool call with extra parameters cannot be generated
"name": "get_current_weather",
"parameters": {"city": "Vienna", "extra": "value"},
}
}],
False),
],
False,
),
(
[{ # tool call where parameters are first cannot be generated
"parameters": {
"city": "Vienna"
},
"name": "get_current_weather"
}],
False),
(
[{ # tool call without all required parameters cannot be generated
"name": "get_forecast",
"parameters": {
"city": "Vienna"
[
{ # tool call where parameters are first cannot be generated
"parameters": {"city": "Vienna"},
"name": "get_current_weather",
}
}],
False),
],
False,
),
(
[
{ # tool call without all required parameters cannot be generated
"name": "get_forecast",
"parameters": {"city": "Vienna"},
}
],
False,
),
( # tool call with incorrect name/parameters cannot be generated
[{
"name": "get_weather",
"parameters": {
"city": "Vienna",
"days": 7
}
}], False),
[{"name": "get_weather", "parameters": {"city": "Vienna", "days": 7}}],
False,
),
( # tool call with both valid and empty function cannot be generated
[{
"name": "get_current_weather",
"parameters": {
"city": "Vienna"
}
}, {}], False),
])
[{"name": "get_current_weather", "parameters": {"city": "Vienna"}}, {}],
False,
),
],
)
def test_structured_outputs_json(sample_output, should_match):
_compile_and_check(tools=TypeAdapter(
list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS),
sample_output=sample_output,
should_match=should_match)
_compile_and_check(
tools=TypeAdapter(list[ChatCompletionToolsParam]).validate_python(
EXAMPLE_TOOLS
),
sample_output=sample_output,
should_match=should_match,
)
def update_parameters_none(
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
def update_parameters_none(tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
tool.function.parameters = None
return tool
def update_parameters_empty_dict(
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
tool: ChatCompletionToolsParam,
) -> ChatCompletionToolsParam:
tool.function.parameters = {}
return tool
@@ -247,48 +219,60 @@ def update_parameters_empty_dict(
({}, False), # empty object cannot be generated
([{}], False), # list with empty object cannot be generated
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather"
}],
False),
(
[{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": None
}],
False),
(
[{ # function with extra parameters cannot be generated
"name": "get_current_weather",
"parameters": {
"extra": "value"
[
{ # function without required parameters cannot be generated
"name": "get_current_weather"
}
}],
False),
],
False,
),
(
[{ # only function with empty parameters object is valid
"name": "get_current_weather",
"parameters": {}
}],
True),
])
[
{ # function without required parameters cannot be generated
"name": "get_current_weather",
"parameters": None,
}
],
False,
),
(
[
{ # function with extra parameters cannot be generated
"name": "get_current_weather",
"parameters": {"extra": "value"},
}
],
False,
),
(
[
{ # only function with empty parameters object is valid
"name": "get_current_weather",
"parameters": {},
}
],
True,
),
],
)
@pytest.mark.parametrize(
"update_parameters",
[update_parameters_none, update_parameters_empty_dict])
def test_structured_outputs_json_without_parameters(sample_output,
should_match,
update_parameters):
"update_parameters", [update_parameters_none, update_parameters_empty_dict]
)
def test_structured_outputs_json_without_parameters(
sample_output, should_match, update_parameters
):
updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
tools = TypeAdapter(
list[ChatCompletionToolsParam]).validate_python(updated_tools)
tools = TypeAdapter(list[ChatCompletionToolsParam]).validate_python(updated_tools)
tools = list(map(update_parameters, tools))
assert all([
tool.function.parameters is None or tool.function.parameters == {}
for tool in tools
])
_compile_and_check(tools=tools,
sample_output=sample_output,
should_match=should_match)
assert all(
[
tool.function.parameters is None or tool.function.parameters == {}
for tool in tools
]
)
_compile_and_check(
tools=tools, sample_output=sample_output, should_match=should_match
)
@pytest.mark.parametrize("output", VALID_TOOLS)
@@ -306,7 +290,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
function_name_returned = False
messages = []
for i in range(0, len(output_json), delta_len):
delta_text = output_json[i:i + delta_len]
delta_text = output_json[i : i + delta_len]
current_text = previous_text + delta_text
delta_message, function_name_returned = (
@@ -315,7 +299,9 @@ def test_streaming_output_valid(output, empty_params, delta_len):
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned))
function_name_returned=function_name_returned,
)
)
if delta_message:
messages.append(delta_message)
@@ -329,10 +315,12 @@ def test_streaming_output_valid(output, empty_params, delta_len):
if len(combined_messages) > 1:
combined_messages += "},"
combined_messages += '{"name": "' + \
message.tool_calls[0].function.name + \
'", "parameters": ' + \
message.tool_calls[0].function.arguments
combined_messages += (
'{"name": "'
+ message.tool_calls[0].function.name
+ '", "parameters": '
+ message.tool_calls[0].function.arguments
)
else:
combined_messages += message.tool_calls[0].function.arguments
combined_messages += "}]"

View File

@@ -7,9 +7,12 @@ from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
@@ -30,12 +33,14 @@ def xlam_tool_parser(xlam_tokenizer):
return xLAMToolParser(xlam_tokenizer)
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
def assert_tool_calls(
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
for actual_tool_call, expected_tool_call in zip(
actual_tool_calls, expected_tool_calls
):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 16
@@ -49,8 +54,7 @@ def stream_delta_message_generator(
model_output: str,
request: Optional[ChatCompletionRequest] = None,
) -> Generator[DeltaMessage, None, None]:
all_token_ids = xlam_tokenizer.encode(model_output,
add_special_tokens=False)
all_token_ids = xlam_tokenizer.encode(model_output, add_special_tokens=False)
previous_text = ""
previous_tokens = None
@@ -59,18 +63,19 @@ def stream_delta_message_generator(
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]
current_token_ids = all_token_ids[: i + 1]
(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = (detokenize_incrementally(
tokenizer=xlam_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
))
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
detokenize_incrementally(
tokenizer=xlam_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
)
current_text = previous_text + delta_text
@@ -87,8 +92,9 @@ def stream_delta_message_generator(
yield delta_message
previous_text = current_text
previous_tokens = (previous_tokens +
new_tokens if previous_tokens else new_tokens)
previous_tokens = (
previous_tokens + new_tokens if previous_tokens else new_tokens
)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
@@ -96,7 +102,8 @@ def stream_delta_message_generator(
def test_extract_tool_calls_no_tools(xlam_tool_parser):
model_output = "This is a test"
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@@ -115,87 +122,113 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
(
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
)),
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}),
)),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}
),
)
),
],
None,
),
(
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
"<think>I'll help you with that.</think>",
),
(
"""I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
"I'll help you with that.",
),
(
"""I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
"I'll check the weather for you.",
),
(
"""I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
"I'll help you check the weather.",
),
],
)
def test_extract_tool_calls(xlam_tool_parser, model_output,
expected_tool_calls, expected_content):
def test_extract_tool_calls(
xlam_tool_parser, model_output, expected_tool_calls, expected_content
):
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
@@ -210,25 +243,30 @@ def test_extract_tool_calls(xlam_tool_parser, model_output,
(
"""[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Seattle",
"state": "WA",
"unit": "celsius",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Seattle",
"state": "WA",
"unit": "celsius",
}
),
)
)
],
None,
),
],
)
def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output,
expected_tool_calls,
expected_content):
def test_extract_tool_calls_list_structure(
xlam_tool_parser, model_output, expected_tool_calls, expected_content
):
"""Test extraction of tool calls when the model outputs a list-structured tool call.""" # noqa: E501
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
@@ -239,20 +277,25 @@ def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output,
# Test for preprocess_model_output method
def test_preprocess_model_output(xlam_tool_parser):
# Test with list structure
model_output = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
model_output = (
"""[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
)
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
model_output)
model_output
)
assert content is None
assert potential_tool_calls == model_output
# Test with thinking tag
model_output = """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
model_output)
model_output
)
assert content == "<think>I'll help you with that.</think>"
assert (
potential_tool_calls ==
'[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]')
potential_tool_calls
== '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]'
)
# Test with JSON code block
model_output = """I'll help you with that.
@@ -260,14 +303,16 @@ def test_preprocess_model_output(xlam_tool_parser):
[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]
```"""
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
model_output)
model_output
)
assert content == "I'll help you with that."
assert "get_current_weather" in potential_tool_calls
# Test with no tool calls
model_output = """I'll help you with that."""
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
model_output)
model_output
)
assert content == model_output
assert potential_tool_calls is None
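# A minimal stand-in for the splitting behavior exercised above (illustrative
# only, covering just these four cases; not the actual xLAMToolParser
# implementation): separate leading prose or thinking from the JSON payload,
# unwrapping ```json fences when present.
import re

def split_content_and_tool_calls(text: str):
    fenced = re.search(r"```json\s*(\[.*?\])\s*```", text, re.DOTALL)
    if fenced:
        return text[: fenced.start()].strip() or None, fenced.group(1)
    bracket = text.find("[")
    if bracket == -1:
        return text, None
    return (text[:bracket] or None), text[bracket:]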
@@ -281,7 +326,9 @@ def test_streaming_with_list_structure(xlam_tool_parser):
xlam_tool_parser.current_tool_id = -1
# Simulate receiving a message with list structure
current_text = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
current_text = (
"""[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
)
# First call to set up the tool
xlam_tool_parser.extract_tool_calls_streaming(
@@ -295,8 +342,7 @@ def test_streaming_with_list_structure(xlam_tool_parser):
)
# Make sure the tool is set up correctly
assert (xlam_tool_parser.current_tool_id
>= 0), "Tool index should be initialized"
assert xlam_tool_parser.current_tool_id >= 0, "Tool index should be initialized"
# Manually set up the state for sending the tool name
xlam_tool_parser.current_tools_sent = [False]
@@ -332,78 +378,102 @@ def test_streaming_with_list_structure(xlam_tool_parser):
(
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
)),
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}),
)),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}
),
)
),
],
"",
),
(
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
"<think>I'll help you with that.</think>",
),
(
"""```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
"",
),
(
"""[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
"",
),
(
"""I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}
),
)
)
],
"I can help with that.",
),
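# Note that the expected "arguments" strings in these cases rely on
# json.dumps' default separators (", " and ": "), so the assertions compare
# exact text. A quick check of that assumption:
import json

assert (
    json.dumps({"city": "Dallas", "state": "TX", "unit": "fahrenheit"})
    == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}'
)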
@@ -421,7 +491,8 @@ def test_extract_tool_calls_streaming_incremental(
chunks = []
for delta_message in stream_delta_message_generator(
xlam_tool_parser, xlam_tokenizer, model_output, request):
xlam_tool_parser, xlam_tokenizer, model_output, request
):
chunks.append(delta_message)
# Should have multiple chunks
@@ -433,8 +504,9 @@ def test_extract_tool_calls_streaming_incremental(
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].id:
header_found = True
assert (chunk.tool_calls[0].function.name ==
expected_first_tool.function.name)
assert (
chunk.tool_calls[0].function.name == expected_first_tool.function.name
)
assert chunk.tool_calls[0].type == "function"
# Arguments may be empty initially or None
if chunk.tool_calls[0].function.arguments is not None:
@@ -446,11 +518,13 @@ def test_extract_tool_calls_streaming_incremental(
# Should have chunks with incremental arguments
arg_chunks = []
for chunk in chunks:
if (chunk.tool_calls and chunk.tool_calls[0].function.arguments
and chunk.tool_calls[0].function.arguments != ""
and chunk.tool_calls[0].index ==
0 # Only collect arguments from the first tool call
):
if (
chunk.tool_calls
and chunk.tool_calls[0].function.arguments
and chunk.tool_calls[0].function.arguments != ""
and chunk.tool_calls[0].index
== 0 # Only collect arguments from the first tool call
):
arg_chunks.append(chunk.tool_calls[0].function.arguments)
# Arguments should be streamed incrementally
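# One way to make "streamed incrementally" concrete: the concatenation of all
# argument deltas for a given tool call must reassemble into valid JSON.
# A sketch of that check, assuming arg_chunks was collected as above:
import json

assert len(arg_chunks) > 1, "arguments should arrive in more than one delta"
reassembled = json.loads("".join(arg_chunks))  # JSONDecodeError if deltas are broken
assert isinstance(reassembled, dict)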


@@ -4,8 +4,7 @@
from copy import deepcopy
from typing import Any, Optional
from openai.types.chat import (ChatCompletionMessageParam,
ChatCompletionToolParam)
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from typing_extensions import TypedDict
from tests.utils import VLLM_PATH
@@ -20,8 +19,9 @@ class ServerConfig(TypedDict, total=False):
extended: Optional[bool] # tests do not run in CI automatically
def patch_system_prompt(messages: list[dict[str, Any]],
system_prompt: str) -> list[dict[str, Any]]:
def patch_system_prompt(
messages: list[dict[str, Any]], system_prompt: str
) -> list[dict[str, Any]]:
new_messages = deepcopy(messages)
if new_messages[0]["role"] == "system":
new_messages[0]["content"] = system_prompt
@@ -30,8 +30,9 @@ def patch_system_prompt(messages: list[dict[str, Any]],
return new_messages
def ensure_system_prompt(messages: list[dict[str, Any]],
config: ServerConfig) -> list[dict[str, Any]]:
def ensure_system_prompt(
messages: list[dict[str, Any]], config: ServerConfig
) -> list[dict[str, Any]]:
prompt = config.get("system_prompt")
if prompt:
return patch_system_prompt(messages, prompt)
@@ -42,92 +43,102 @@ def ensure_system_prompt(messages: list[dict[str, Any]],
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
ARGS: list[str] = [
"--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs",
"256"
"--enable-auto-tool-choice",
"--max-model-len",
"1024",
"--max-num-seqs",
"256",
]
CONFIGS: dict[str, ServerConfig] = {
"hermes": {
"model":
"NousResearch/Hermes-3-Llama-3.1-8B",
"model": "NousResearch/Hermes-3-Llama-3.1-8B",
"arguments": [
"--enforce-eager", "--no-enable-prefix-caching",
"--tool-call-parser", "hermes", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"hermes",
"--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja"),
],
"system_prompt":
"You are a helpful assistant with access to tools. If a tool"
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
"to the user's question - just respond to it normally.",
},
"llama": {
"model":
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"arguments": [
"--enforce-eager", "--no-enable-prefix-caching",
"--tool-call-parser", "llama3_json", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_llama3.1_json.jinja")
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"llama3_json",
"--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_llama3.1_json.jinja"),
],
"supports_parallel":
False,
"supports_parallel": False,
},
"llama3.2": {
"model":
"meta-llama/Llama-3.2-3B-Instruct",
"model": "meta-llama/Llama-3.2-3B-Instruct",
"arguments": [
"--enforce-eager", "--no-enable-prefix-caching",
"--tool-call-parser", "llama3_json", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_llama3.2_json.jinja")
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"llama3_json",
"--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_llama3.2_json.jinja"),
],
"supports_parallel":
False,
"supports_parallel": False,
},
"llama4": {
"model":
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"model": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"arguments": [
"--enforce-eager", "--no-enable-prefix-caching",
"--tool-call-parser", "llama4_pythonic", "--chat-template",
str(VLLM_PATH /
"examples/tool_chat_template_llama4_pythonic.jinja"), "-tp",
"4"
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"llama4_pythonic",
"--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_llama4_pythonic.jinja"),
"-tp",
"4",
],
"supports_parallel":
False,
"extended":
True
"supports_parallel": False,
"extended": True,
},
"llama4_json": {
"model":
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
"model": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"arguments": [
"--enforce-eager", "--no-enable-prefix-caching", "-tp", "4",
"--distributed-executor-backend", "mp", "--tool-call-parser",
"llama4_json", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja")
"--enforce-eager",
"--no-enable-prefix-caching",
"-tp",
"4",
"--distributed-executor-backend",
"mp",
"--tool-call-parser",
"llama4_json",
"--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja"),
],
"supports_parallel":
True,
"extended":
True
"supports_parallel": True,
"extended": True,
},
"mistral": {
"model":
"mistralai/Mistral-7B-Instruct-v0.3",
"model": "mistralai/Mistral-7B-Instruct-v0.3",
"arguments": [
"--enforce-eager", "--no-enable-prefix-caching",
"--tool-call-parser", "mistral", "--chat-template",
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"mistral",
"--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"),
"--ignore-patterns=\"consolidated.safetensors\""
'--ignore-patterns="consolidated.safetensors"',
],
"system_prompt":
"You are a helpful assistant with access to tools. If a tool"
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
"to the user's question - just respond to it normally.",
},
# V1 Test: Passing locally but failing in CI. This runs the
# V0 Engine because of CPU offloading. Need to debug why.
@@ -146,49 +157,50 @@ CONFIGS: dict[str, ServerConfig] = {
# False,
# },
"granite-3.0-8b": {
"model":
"ibm-granite/granite-3.0-8b-instruct",
"model": "ibm-granite/granite-3.0-8b-instruct",
"arguments": [
"--enforce-eager", "--no-enable-prefix-caching",
"--tool-call-parser", "granite", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_granite.jinja")
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"granite",
"--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_granite.jinja"),
],
},
"granite-3.1-8b": {
"model":
"ibm-granite/granite-3.1-8b-instruct",
"model": "ibm-granite/granite-3.1-8b-instruct",
"arguments": [
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"granite",
],
"supports_parallel":
True,
"supports_parallel": True,
},
"internlm": {
"model":
"internlm/internlm2_5-7b-chat",
"model": "internlm/internlm2_5-7b-chat",
"arguments": [
"--enforce-eager", "--no-enable-prefix-caching",
"--tool-call-parser", "internlm", "--chat-template",
str(VLLM_PATH /
"examples/tool_chat_template_internlm2_tool.jinja"),
"--trust_remote_code"
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"internlm",
"--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_internlm2_tool.jinja"),
"--trust_remote_code",
],
"supports_parallel":
False,
"supports_parallel": False,
},
"toolACE": {
"model":
"Team-ACE/ToolACE-8B",
"model": "Team-ACE/ToolACE-8B",
"arguments": [
"--enforce-eager", "--no-enable-prefix-caching",
"--tool-call-parser", "pythonic", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja")
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"pythonic",
"--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja"),
],
"supports_parallel":
True,
"supports_parallel": True,
},
}
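# Each config stores its CLI flags as a flat argv-style list, so a flag's
# value is simply the element that follows it. A hypothetical helper (not
# part of this file) for reading one back out:
def get_flag_value(args: list[str], flag: str) -> Optional[str]:
    try:
        return args[args.index(flag) + 1]
    except (ValueError, IndexError):  # flag absent, or flag is the last element
        return None

# e.g. get_flag_value(CONFIGS["hermes"]["arguments"], "--tool-call-parser")
# returns "hermes".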
@@ -201,37 +213,31 @@ WEATHER_TOOL: ChatCompletionToolParam = {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, "
"e.g. 'San Francisco'"
"type": "string",
"description": "The city to find the weather for, "
"e.g. 'San Francisco'",
},
"state": {
"type":
"string",
"description":
"must the two-letter abbreviation for the state "
"type": "string",
"description": "must the two-letter abbreviation for the state "
"that the city is in, e.g. 'CA' which would "
"mean 'California'"
"mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
}
}
}
"enum": ["celsius", "fahrenheit"],
},
},
},
},
}
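# WEATHER_TOOL["function"]["parameters"] is ordinary JSON Schema, so argument
# payloads like the ones asserted in these tests can be validated directly.
# A sketch using the third-party jsonschema package (an assumption; the test
# suite itself does not do this):
from jsonschema import validate

validate(
    instance={"city": "Dallas", "state": "TX", "unit": "fahrenheit"},
    schema=WEATHER_TOOL["function"]["parameters"],
)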
SEARCH_TOOL: ChatCompletionToolParam = {
"type": "function",
"function": {
"name":
"web_search",
"description":
"Search the internet and get a summary of the top "
"name": "web_search",
"description": "Search the internet and get a summary of the top "
"10 webpages. Should only be used if you don't know "
"the answer to a user query, and the results are likely"
"to be able to be found with a web search",
@@ -239,124 +245,98 @@ SEARCH_TOOL: ChatCompletionToolParam = {
"type": "object",
"properties": {
"search_term": {
"type":
"string",
"description":
"The term to use in the search. This should"
"type": "string",
"description": "The term to use in the search. This should"
"ideally be keywords to search for, not a"
"natural-language question"
"natural-language question",
}
},
"required": ["search_term"]
}
}
"required": ["search_term"],
},
},
}
MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"Hi! How are you?"
}, {
"role":
"assistant",
"content":
"I'm doing great! How can I assist you?"
}, {
"role":
"user",
"content":
"Can you tell me a joke please?"
}]
MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [
{"role": "user", "content": "Hi! How are you?"},
{"role": "assistant", "content": "I'm doing great! How can I assist you?"},
{"role": "user", "content": "Can you tell me a joke please?"},
]
MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas in Fahrenheit?"
}]
MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [
{"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"}
]
MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas in Fahrenheit?"
}, {
"role":
"assistant",
"tool_calls": [{
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"type": "function",
"function": {
"name":
WEATHER_TOOL["function"]["name"],
"arguments":
'{"city": "Dallas", "state": "TX", '
'"unit": "fahrenheit"}'
}
}]
}, {
"role":
"tool",
"tool_call_id":
"chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"content":
"The weather in Dallas is 98 degrees fahrenheit, with partly"
"cloudy skies and a low chance of rain."
}]
MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [
{"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"},
{
"role": "assistant",
"tool_calls": [
{
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"type": "function",
"function": {
"name": WEATHER_TOOL["function"]["name"],
"arguments": '{"city": "Dallas", "state": "TX", '
'"unit": "fahrenheit"}',
},
}
],
},
{
"role": "tool",
"tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"content": "The weather in Dallas is 98 degrees fahrenheit, with partly"
"cloudy skies and a low chance of rain.",
},
]
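# In the OpenAI chat format, every "tool" message must reference the id of a
# tool call issued by a preceding assistant turn, which is why the fixture
# repeats "chatcmpl-tool-03e6481b146e408e9523d9c956696295" verbatim. A small
# consistency check (helper name is illustrative, not part of this file):
def check_tool_ids(messages: list[dict]) -> None:
    issued = {
        call["id"]
        for msg in messages
        if msg.get("role") == "assistant"
        for call in msg.get("tool_calls", [])
    }
    for msg in messages:
        if msg.get("role") == "tool":
            assert msg["tool_call_id"] in issued, "dangling tool_call_id"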
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas and Orlando, Florida in "
"Fahrenheit?"
}]
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [
{
"role": "user",
"content": "What is the weather in Dallas, Texas and Orlando, Florida in "
"Fahrenheit?",
}
]
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{
"role":
"user",
"content":
"What is the weather in Dallas, Texas and Orlando, Florida in "
"Fahrenheit?"
}, {
"role":
"assistant",
"tool_calls": [{
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"type": "function",
"function": {
"name":
WEATHER_TOOL["function"]["name"],
"arguments":
'{"city": "Dallas", "state": "TX", '
'"unit": "fahrenheit"}'
}
}, {
"id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
"type": "function",
"function": {
"name":
WEATHER_TOOL["function"]["name"],
"arguments":
'{"city": "Orlando", "state": "Fl", '
'"unit": "fahrenheit"}'
}
}]
}, {
"role":
"tool",
"tool_call_id":
"chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"content":
"The weather in Dallas TX is 98 degrees fahrenheit with mostly "
"cloudy skies and a chance of rain in the evening."
}, {
"role":
"tool",
"tool_call_id":
"chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
"content":
"The weather in Orlando FL is 78 degrees fahrenheit with clear"
"skies."
}]
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [
{
"role": "user",
"content": "What is the weather in Dallas, Texas and Orlando, Florida in "
"Fahrenheit?",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"type": "function",
"function": {
"name": WEATHER_TOOL["function"]["name"],
"arguments": '{"city": "Dallas", "state": "TX", '
'"unit": "fahrenheit"}',
},
},
{
"id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
"type": "function",
"function": {
"name": WEATHER_TOOL["function"]["name"],
"arguments": '{"city": "Orlando", "state": "Fl", '
'"unit": "fahrenheit"}',
},
},
],
},
{
"role": "tool",
"tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
"content": "The weather in Dallas TX is 98 degrees fahrenheit with mostly "
"cloudy skies and a chance of rain in the evening.",
},
{
"role": "tool",
"tool_call_id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
"content": "The weather in Orlando FL is 78 degrees fahrenheit with clear"
"skies.",
},
]
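# The same id-pairing property holds for the parallel fixture: two tool calls
# in one assistant turn, answered by two tool messages with matching ids.
# Using the illustrative helper sketched after MESSAGES_WITH_TOOL_RESPONSE:
check_tool_ids(MESSAGES_WITH_PARALLEL_TOOL_RESPONSE)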