Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -6,8 +6,7 @@ import json
import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
Hermes2ProToolParser)
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from ....utils import RemoteOpenAIServer
@@ -27,61 +26,64 @@ SERVER_ARGS = [
f"{LORA_MODEL}",
]
TOOLS = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description":
"The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
TOOLS = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
"required": ["location"],
},
},
}]
}
]
PRODUCT_TOOLS = [{
"type": "function",
"function": {
"name": "get_product_info",
"description": "Get detailed information of a product based on its "
"product ID.",
"parameters": {
"type": "object",
"properties": {
"inserted": {
"type": "boolean",
"description": "inserted.",
},
"product_id": {
"type": "integer",
"description": "The product ID of the product.",
PRODUCT_TOOLS = [
{
"type": "function",
"function": {
"name": "get_product_info",
"description": "Get detailed information of a product based on its "
"product ID.",
"parameters": {
"type": "object",
"properties": {
"inserted": {
"type": "boolean",
"description": "inserted.",
},
"product_id": {
"type": "integer",
"description": "The product ID of the product.",
},
},
"required": ["product_id", "inserted"],
},
"required": ["product_id", "inserted"],
},
},
}]
}
]
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
PRODUCT_MESSAGES = [{
"role":
"user",
"content":
"Hi! Do you have any detailed information about the product id "
"7355608 and inserted true?",
}]
PRODUCT_MESSAGES = [
{
"role": "user",
"content": "Hi! Do you have any detailed information about the product id "
"7355608 and inserted true?",
}
]
@pytest.mark.asyncio
@@ -150,7 +152,8 @@ async def test_streaming_tool_call():
tool_call_chunks[index]["name"] += tool_chunk.function.name
if tool_chunk.function.arguments:
tool_call_chunks[index]["arguments"] += (
tool_chunk.function.arguments)
tool_chunk.function.arguments
)
assert len(tool_call_chunks) == 1
reconstructed_tool_call = tool_call_chunks[0]
@@ -240,7 +243,8 @@ async def test_streaming_product_tool_call():
tool_call_chunks[index]["name"] += tool_chunk.function.name
if tool_chunk.function.arguments:
tool_call_chunks[index]["arguments"] += (
tool_chunk.function.arguments)
tool_chunk.function.arguments
)
assert len(tool_call_chunks) == 1
reconstructed_tool_call = tool_call_chunks[0]
@@ -291,9 +295,7 @@ def test_hermes_parser_streaming_just_forward_text(
hermes_parser: Hermes2ProToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = (
"""This is some prior text that has nothing to do with tool calling."""
)
text = """This is some prior text that has nothing to do with tool calling."""
tokens = qwen_tokenizer.encode(text)
previous_text = ""
delta_messages = []
@@ -348,8 +350,9 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
delta_messages.append(delta)
assert delta_messages[0].tool_calls[0].function.name == "final_answer"
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
for delta in delta_messages)
tool_call_args = "".join(
delta.tool_calls[0].function.arguments or "" for delta in delta_messages
)
assert tool_call_args == '{"trigger": true}'
@@ -383,13 +386,13 @@ def test_hermes_parser_streaming(
if delta is not None:
delta_messages.append(delta)
print(delta_messages)
assert (delta_messages[0].tool_calls[0].function.name ==
"get_current_temperature")
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
for delta in delta_messages)
assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature"
tool_call_args = "".join(
delta.tool_calls[0].function.arguments or "" for delta in delta_messages
)
assert tool_call_args == (
'{"location":"San Francisco, California, United States", '
'"unit": "celsius"}')
'{"location":"San Francisco, California, United States", "unit": "celsius"}'
)
def test_hermes_parser_non_streaming_no_tool_call(

View File

@@ -8,15 +8,18 @@ from unittest.mock import MagicMock
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction, run_tool_extraction_streaming)
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
def make_tool_call(name, arguments):
return ToolCall(type="function",
function=FunctionCall(name=name,
arguments=json.dumps(arguments)))
return ToolCall(
type="function",
function=FunctionCall(name=name, arguments=json.dumps(arguments)),
)
# TODO: add reason prefix and suffix.
@@ -29,70 +32,68 @@ def make_tool_call(name, arguments):
("How can I help you today?", [], "How can I help you today?"),
# Single tool call, no content
(
"<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>", #noqa: E501
[
make_tool_call("get_weather", {
"city": "San Francisco",
"metric": "celsius"
})
],
None),
# Multiple tool calls
(
"<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>", #noqa: E501
[
make_tool_call("get_weather", {
"city": "San Francisco",
"metric": "celsius"
}),
make_tool_call(
"register_user", {
"name": "John Doe",
"age": 37,
"address": {
"city": "San Francisco",
"state": "CA"
},
"role": None,
"passed_test": True,
"aliases": ["John", "Johnny"]
})
],
None),
# Content before tool call
(
"I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>", #noqa: E501
[make_tool_call("get_weather", {"city": "Boston"})],
"I will call the tool now. "),
# Content after tool call (should be stripped)
(
"<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!", #noqa: E501
[make_tool_call("get_weather", {"city": "Seattle"})],
None),
(
"<tool_calls>[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]</tool_calls>",
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]</tool_calls>', # noqa: E501
[
make_tool_call(
"complex_tool",
{"level1": {
"level2": {
"level3": {
"value": 123
}
}
}})
"get_weather", {"city": "San Francisco", "metric": "celsius"}
)
],
None,
),
])
def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
expected_content):
# Multiple tool calls
(
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]</tool_calls>', # noqa: E501
[
make_tool_call(
"get_weather", {"city": "San Francisco", "metric": "celsius"}
),
make_tool_call(
"register_user",
{
"name": "John Doe",
"age": 37,
"address": {"city": "San Francisco", "state": "CA"},
"role": None,
"passed_test": True,
"aliases": ["John", "Johnny"],
},
),
],
None,
),
# Content before tool call
(
'I will call the tool now. <tool_calls>[{"name": "get_weather", "arguments": {"city": "Boston"}}]</tool_calls>', # noqa: E501
[make_tool_call("get_weather", {"city": "Boston"})],
"I will call the tool now. ",
),
# Content after tool call (should be stripped)
(
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "Seattle"}}]</tool_calls>\nThank you!', # noqa: E501
[make_tool_call("get_weather", {"city": "Seattle"})],
None,
),
(
'<tool_calls>[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]</tool_calls>',
[
make_tool_call(
"complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
)
],
None,
),
],
)
def test_hunyuan_a13b_tool_parser_extract(
model_output, expected_tool_calls, expected_content
):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
"hunyuan_a13b")(mock_tokenizer)
content, tool_calls = run_tool_extraction(tool_parser,
model_output,
streaming=False)
tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
mock_tokenizer
)
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=False
)
# align the random id.
for idx in range(len(tool_calls)):
@@ -102,49 +103,74 @@ def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
# Streaming test: simulate incremental output
@pytest.mark.parametrize("model_deltas,expected_tool_calls", [
([
"<tool_calls>[{\"name\": \"get_weather\", ",
"\"arguments\": {\"city\": \"San Francisco\", ",
"\"metric\": \"celsius\"}}]", "</tool_calls>"
], [
make_tool_call("get_weather", {
"city": "San Francisco",
"metric": "celsius"
})
]),
([
"<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
" {\"city\": \"Boston\"}", "}]", "</tool_calls>"
], [make_tool_call("get_weather", {"city": "Boston"})]),
([
"", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
" {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>"
], [make_tool_call("get_weather", {"city": "Boston"})]),
pytest.param([
"<tool_calls>[{\"name\": \"complex_tool\",", " \"arguments\": ",
" {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}",
"]</tool_calls>"
], [
make_tool_call("complex_tool",
{"level1": {
"level2": {
"level3": {
"value": 123
}
}
}})
@pytest.mark.parametrize(
"model_deltas,expected_tool_calls",
[
(
[
'<tool_calls>[{"name": "get_weather", ',
'"arguments": {"city": "San Francisco", ',
'"metric": "celsius"}}]',
"</tool_calls>",
],
[
make_tool_call(
"get_weather", {"city": "San Francisco", "metric": "celsius"}
)
],
),
(
[
'<tool_calls>[{"name":',
' "get_weather",',
' "arguments":',
' {"city": "Boston"}',
"}]",
"</tool_calls>",
],
[make_tool_call("get_weather", {"city": "Boston"})],
),
(
[
"",
'<tool_calls>[{"name":',
' "get_weather",',
' "arguments":',
' {"city": "Boston"}',
"}]",
"</tool_calls>",
"\n</answer>",
],
[make_tool_call("get_weather", {"city": "Boston"})],
),
pytest.param(
[
'<tool_calls>[{"name": "complex_tool",',
' "arguments": ',
' {"level1": {"level2": ',
'{"level3": {"value": 123}}}}}',
"]</tool_calls>",
],
[
make_tool_call(
"complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
)
],
marks=pytest.mark.xfail(
reason="stream parsing not support nested json yet."
),
),
],
marks=pytest.mark.xfail(
reason="stream parsing not support nested json yet.")),
])
)
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
"hunyuan_a13b")(mock_tokenizer)
tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
mock_tokenizer
)
reconstructor = run_tool_extraction_streaming(
tool_parser, model_deltas, assert_one_tool_per_delta=False)
tool_parser, model_deltas, assert_one_tool_per_delta=False
)
# align the random id.
for idx in range(len(reconstructor.tool_calls)):

View File

@@ -5,8 +5,7 @@ import pytest
from transformers import AutoTokenizer
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import (
Llama3JsonToolParser)
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
@pytest.fixture
@@ -18,8 +17,10 @@ def parser():
def test_extract_tool_calls_simple(parser):
# Test with a simple tool call
model_output = ('Here is the result: {"name": "getOpenIncidentsTool", '
'"parameters": {}} Would you like to know more?')
model_output = (
'Here is the result: {"name": "getOpenIncidentsTool", '
'"parameters": {}} Would you like to know more?'
)
result = parser.extract_tool_calls(model_output, None)
assert isinstance(result, ExtractedToolCallInformation)
@@ -34,8 +35,8 @@ def test_extract_tool_calls_simple(parser):
def test_extract_tool_calls_with_arguments(parser):
# Test with a tool call that has arguments
model_output = (
'{"name": "searchTool", "parameters": {"query": "test query", '
'"limit": 10}}')
'{"name": "searchTool", "parameters": {"query": "test query", "limit": 10}}'
)
result = parser.extract_tool_calls(model_output, None)
assert result.tools_called is True
@@ -81,7 +82,8 @@ def test_extract_tool_calls_multiple_json(parser):
model_output = (
'{"name": "searchTool", "parameters": {"query": "test1"}}; '
'{"name": "getOpenIncidentsTool", "parameters": {}}; '
'{"name": "searchTool", "parameters": {"query": "test2"}}')
'{"name": "searchTool", "parameters": {"query": "test2"}}'
)
result = parser.extract_tool_calls(model_output, None)
assert result.tools_called is True
@@ -105,7 +107,8 @@ def test_extract_tool_calls_multiple_json_with_whitespace(parser):
model_output = (
'{"name": "searchTool", "parameters": {"query": "test1"}} ; '
'{"name": "getOpenIncidentsTool", "parameters": {}} ; '
'{"name": "searchTool", "parameters": {"query": "test2"}}')
'{"name": "searchTool", "parameters": {"query": "test2"}}'
)
result = parser.extract_tool_calls(model_output, None)
assert result.tools_called is True
@@ -118,11 +121,12 @@ def test_extract_tool_calls_multiple_json_with_whitespace(parser):
def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
# Test with multiple JSONs and surrounding text
model_output = (
'Here are the results: '
"Here are the results: "
'{"name": "searchTool", "parameters": {"query": "test1"}}; '
'{"name": "getOpenIncidentsTool", "parameters": {}}; '
'{"name": "searchTool", "parameters": {"query": "test2"}} '
'Would you like to know more?')
"Would you like to know more?"
)
result = parser.extract_tool_calls(model_output, None)
assert result.tools_called is True

View File

@@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction, run_tool_extraction_streaming)
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
@@ -16,12 +18,14 @@ SIMPLE_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "LA", "metric": "C"}',
)
MORE_TYPES_FUNCTION_OUTPUT = ("[register_user(name='Doe', "
"age=9, "
"address={'city': 'LA', 'state': 'CA'}, "
"role=None, "
"passed_test=True, "
"aliases=['John', 'Johnny'])]")
MORE_TYPES_FUNCTION_OUTPUT = (
"[register_user(name='Doe', "
"age=9, "
"address={'city': 'LA', 'state': 'CA'}, "
"role=None, "
"passed_test=True, "
"aliases=['John', 'Johnny'])]"
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
name="register_user",
arguments='{"name": "Doe", '
@@ -34,7 +38,7 @@ MORE_TYPES_FUNCTION_CALL = FunctionCall(
PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{}',
arguments="{}",
)
EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
@@ -47,25 +51,28 @@ EMPTY_LIST_FUNCTION_CALL = FunctionCall(
arguments='{"steps": []}',
)
ESCAPED_STRING_FUNCTION_OUTPUT = (
r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]")
r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
PYTHON_TAG_FUNCTION_OUTPUT = (
"<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>")
"<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>"
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
"llama4_pythonic")(mock_tokenizer)
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
mock_tokenizer
)
model_output = "How can I help you today?"
content, tool_calls = run_tool_extraction(tool_parser,
model_output,
streaming=streaming)
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content == model_output
assert len(tool_calls) == 0
@@ -75,98 +82,139 @@ test_str = "<|python_start|>"
test_str += "[get_weather(city='LA', metric='C'),"
test_str += "register_user(name='Doe', age=9)]"
TEST_CASES = [
pytest.param(True,
ESCAPED_STRING_FUNCTION_OUTPUT,
[ESCAPED_STRING_FUNCTION_CALL],
id="simple_streaming"),
pytest.param(False,
SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
id="simple_nonstreaming"),
pytest.param(True,
MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL],
id="more_types_streaming"),
pytest.param(False,
MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL],
id="more_types_nonstreaming"),
pytest.param(True,
PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL],
id="parameterless_streaming"),
pytest.param(False,
PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL],
id="parameterless_nonstreaming"),
pytest.param(True,
EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_streaming"),
pytest.param(False,
EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_nonstreaming"),
pytest.param(True,
EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL],
id="empty_list_streaming"),
pytest.param(False,
EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL],
id="empty_list_nonstreaming"),
pytest.param(True,
ESCAPED_STRING_FUNCTION_OUTPUT,
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_streaming"),
pytest.param(False,
ESCAPED_STRING_FUNCTION_OUTPUT,
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_nonstreaming"),
pytest.param(
True,
ESCAPED_STRING_FUNCTION_OUTPUT,
[ESCAPED_STRING_FUNCTION_CALL],
id="simple_streaming",
),
pytest.param(
False, SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], id="simple_nonstreaming"
),
pytest.param(
True,
MORE_TYPES_FUNCTION_OUTPUT,
[MORE_TYPES_FUNCTION_CALL],
id="more_types_streaming",
),
pytest.param(
False,
MORE_TYPES_FUNCTION_OUTPUT,
[MORE_TYPES_FUNCTION_CALL],
id="more_types_nonstreaming",
),
pytest.param(
True,
PARAMETERLESS_FUNCTION_OUTPUT,
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_streaming",
),
pytest.param(
False,
PARAMETERLESS_FUNCTION_OUTPUT,
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_nonstreaming",
),
pytest.param(
True,
EMPTY_DICT_FUNCTION_OUTPUT,
[EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_streaming",
),
pytest.param(
False,
EMPTY_DICT_FUNCTION_OUTPUT,
[EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_nonstreaming",
),
pytest.param(
True,
EMPTY_LIST_FUNCTION_OUTPUT,
[EMPTY_LIST_FUNCTION_CALL],
id="empty_list_streaming",
),
pytest.param(
False,
EMPTY_LIST_FUNCTION_OUTPUT,
[EMPTY_LIST_FUNCTION_CALL],
id="empty_list_nonstreaming",
),
pytest.param(
True,
ESCAPED_STRING_FUNCTION_OUTPUT,
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_streaming",
),
pytest.param(
False,
ESCAPED_STRING_FUNCTION_OUTPUT,
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_nonstreaming",
),
pytest.param(
True,
"[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
[
SIMPLE_FUNCTION_CALL,
FunctionCall(name="register_user",
arguments='{"name": "Doe", "age": 9}')
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
],
id="parallel_calls_streaming"),
id="parallel_calls_streaming",
),
pytest.param(
False,
"[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
[
SIMPLE_FUNCTION_CALL,
FunctionCall(name="register_user",
arguments='{"name": "Doe", "age": 9}')
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
],
id="parallel_calls_nonstreaming"),
pytest.param(True,
PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
id="python_tag_streaming"),
pytest.param(False,
PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
id="python_tag_nonstreaming"),
pytest.param(True,
test_str, [
SIMPLE_FUNCTION_CALL,
FunctionCall(name="register_user",
arguments='{"name": "Doe", "age": 9}')
],
id="parallel_calls_streaming"),
pytest.param(False,
"<|python_start|>[get_weather(city='LA', metric='C'), " +
"register_user(name='Doe', age=9)]", [
SIMPLE_FUNCTION_CALL,
FunctionCall(name="register_user",
arguments='{"name": "Doe", "age": 9}')
],
id="parallel_calls_nonstreaming"),
id="parallel_calls_nonstreaming",
),
pytest.param(
True,
PYTHON_TAG_FUNCTION_OUTPUT,
[SIMPLE_FUNCTION_CALL],
id="python_tag_streaming",
),
pytest.param(
False,
PYTHON_TAG_FUNCTION_OUTPUT,
[SIMPLE_FUNCTION_CALL],
id="python_tag_nonstreaming",
),
pytest.param(
True,
test_str,
[
SIMPLE_FUNCTION_CALL,
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
],
id="parallel_calls_streaming",
),
pytest.param(
False,
"<|python_start|>[get_weather(city='LA', metric='C'), "
+ "register_user(name='Doe', age=9)]",
[
SIMPLE_FUNCTION_CALL,
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
],
id="parallel_calls_nonstreaming",
),
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
TEST_CASES)
def test_tool_call(streaming: bool, model_output: str,
expected_tool_calls: list[FunctionCall]):
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
"llama4_pythonic")(mock_tokenizer)
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
mock_tokenizer
)
content, tool_calls = run_tool_extraction(tool_parser,
model_output,
streaming=streaming)
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert len(tool_calls) == len(expected_tool_calls)
for actual, expected in zip(tool_calls, expected_tool_calls):
@@ -176,8 +224,9 @@ def test_tool_call(streaming: bool, model_output: str,
def test_streaming_tool_call_with_large_steps():
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
"llama4_pythonic")(mock_tokenizer)
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
mock_tokenizer
)
model_output_deltas = [
"<|python_start|>[get_weather(city='LA', metric='C'), "
"get_weather(), "
@@ -185,7 +234,8 @@ def test_streaming_tool_call_with_large_steps():
]
reconstructor = run_tool_extraction_streaming(
tool_parser, model_output_deltas, assert_one_tool_per_delta=False)
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
)
assert reconstructor.other_content == ""
assert len(reconstructor.tool_calls) == 3
@@ -198,8 +248,9 @@ def test_streaming_tool_call_with_large_steps():
def test_regex_timeout_handling(streaming: bool):
"""test regex timeout is handled gracefully"""
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
"llama4_pythonic")(mock_tokenizer)
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
mock_tokenizer
)
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
@@ -207,10 +258,10 @@ def test_regex_timeout_handling(streaming: bool):
mock_regex = MagicMock()
mock_regex.match.side_effect = TimeoutError("Regex timeout")
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
content, tool_calls = run_tool_extraction(tool_parser,
fake_problematic_input,
streaming=streaming)
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
content, tool_calls = run_tool_extraction(
tool_parser, fake_problematic_input, streaming=streaming
)
# should treat as regular text when regex times out
assert content == fake_problematic_input

View File

@@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction, run_tool_extraction_streaming)
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
@@ -22,7 +24,8 @@ MORE_TYPES_FUNCTION_OUTPUT = (
"address={'city': 'San Francisco', 'state': 'CA'}, "
"role=None, "
"passed_test=True, "
"aliases=['John', 'Johnny'])")
"aliases=['John', 'Johnny'])"
)
MORE_TYPES_FUNCTION_CALL = FunctionCall(
name="register_user",
arguments='{"name": "John Doe", '
@@ -35,7 +38,7 @@ MORE_TYPES_FUNCTION_CALL = FunctionCall(
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{}',
arguments="{}",
)
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
@@ -48,7 +51,8 @@ EMPTY_LIST_FUNCTION_CALL = FunctionCall(
arguments='{"steps": []}',
)
ESCAPED_STRING_FUNCTION_OUTPUT = (
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')")
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
)
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
@@ -59,80 +63,118 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
def test_no_tool_call(streaming: bool):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
mock_tokenizer
)
model_output = "How can I help you today?"
content, tool_calls = run_tool_extraction(tool_parser,
model_output,
streaming=streaming)
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content == model_output
assert len(tool_calls) == 0
TEST_CASES = [
pytest.param(True,
f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
id="simple_streaming"),
pytest.param(False,
f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
id="simple_nonstreaming"),
pytest.param(True,
f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
id="more_types_streaming"),
pytest.param(False,
f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
id="more_types_nonstreaming"),
pytest.param(True,
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_streaming"),
pytest.param(False,
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_nonstreaming"),
pytest.param(True,
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_streaming"),
pytest.param(False,
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_nonstreaming"),
pytest.param(True,
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
id="empty_list_streaming"),
pytest.param(False,
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
id="empty_list_nonstreaming"),
pytest.param(True,
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_streaming"),
pytest.param(False,
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_nonstreaming"),
pytest.param(True,
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_streaming"),
pytest.param(False,
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_nonstreaming"),
pytest.param(
True,
f"[{SIMPLE_FUNCTION_OUTPUT}]",
[SIMPLE_FUNCTION_CALL],
id="simple_streaming",
),
pytest.param(
False,
f"[{SIMPLE_FUNCTION_OUTPUT}]",
[SIMPLE_FUNCTION_CALL],
id="simple_nonstreaming",
),
pytest.param(
True,
f"[{MORE_TYPES_FUNCTION_OUTPUT}]",
[MORE_TYPES_FUNCTION_CALL],
id="more_types_streaming",
),
pytest.param(
False,
f"[{MORE_TYPES_FUNCTION_OUTPUT}]",
[MORE_TYPES_FUNCTION_CALL],
id="more_types_nonstreaming",
),
pytest.param(
True,
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_streaming",
),
pytest.param(
False,
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_nonstreaming",
),
pytest.param(
True,
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]",
[EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_streaming",
),
pytest.param(
False,
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]",
[EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_nonstreaming",
),
pytest.param(
True,
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]",
[EMPTY_LIST_FUNCTION_CALL],
id="empty_list_streaming",
),
pytest.param(
False,
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]",
[EMPTY_LIST_FUNCTION_CALL],
id="empty_list_nonstreaming",
),
pytest.param(
True,
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_streaming",
),
pytest.param(
False,
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_nonstreaming",
),
pytest.param(
True,
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_streaming",
),
pytest.param(
False,
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_nonstreaming",
),
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
TEST_CASES)
def test_tool_call(streaming: bool, model_output: str,
expected_tool_calls: list[FunctionCall]):
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
mock_tokenizer
)
content, tool_calls = run_tool_extraction(tool_parser,
model_output,
streaming=streaming)
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content is None
assert len(tool_calls) == len(expected_tool_calls)
@@ -144,7 +186,8 @@ def test_tool_call(streaming: bool, model_output: str,
def test_streaming_tool_call_with_large_steps():
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
mock_tokenizer
)
model_output_deltas = [
"[get_weather(city='San",
" Francisco', metric='celsius'), "
@@ -153,7 +196,8 @@ def test_streaming_tool_call_with_large_steps():
]
reconstructor = run_tool_extraction_streaming(
tool_parser, model_output_deltas, assert_one_tool_per_delta=False)
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
)
assert reconstructor.other_content == ""
assert len(reconstructor.tool_calls) == 3
@@ -166,8 +210,9 @@ def test_streaming_tool_call_with_large_steps():
def test_regex_timeout_handling(streaming: bool):
"""test regex timeout is handled gracefully"""
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
"llama4_pythonic")(mock_tokenizer)
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
mock_tokenizer
)
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
@@ -175,10 +220,10 @@ def test_regex_timeout_handling(streaming: bool):
mock_regex = MagicMock()
mock_regex.match.side_effect = TimeoutError("Regex timeout")
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
content, tool_calls = run_tool_extraction(tool_parser,
fake_problematic_input,
streaming=streaming)
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
content, tool_calls = run_tool_extraction(
tool_parser, fake_problematic_input, streaming=streaming
)
# should treat as regular text when regex times out
assert content == fake_problematic_input

View File

@@ -4,15 +4,17 @@
from collections.abc import Iterable
from typing import Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers import ToolParser
class StreamingToolReconstructor:
def __init__(self, assert_one_tool_per_delta: bool = True):
self.tool_calls: list[ToolCall] = []
self.other_content: str = ""
@@ -23,49 +25,60 @@ class StreamingToolReconstructor:
self.other_content += delta.content
else:
assert delta.tool_calls, (
"Streaming results should have either content or tool calls "
"(or both)")
"Streaming results should have either content or tool calls (or both)"
)
if self._assert_one_tool_per_delta:
# Note: This isn't strictly required by the API and may not be
# possible to adhere to depending on the token space and number of
# tokens per streamed response from the model, but it is required
# by tool_use tests, so we enforce it here by default also.
assert len(delta.tool_calls) < 2, (
"Streaming should include only one tool call per update.")
"Streaming should include only one tool call per update."
)
for call_delta in delta.tool_calls:
assert call_delta.type is None or call_delta.type == "function", (
"Streaming tool calls should only emit function calls. Got "
f"{call_delta.type}")
current_tool_call = self.tool_calls[
call_delta.index] if call_delta.index < len(
self.tool_calls) else None
f"{call_delta.type}"
)
current_tool_call = (
self.tool_calls[call_delta.index]
if call_delta.index < len(self.tool_calls)
else None
)
if current_tool_call:
assert (not call_delta.function.name), (
assert not call_delta.function.name, (
"Streaming tool calls should emit the full function name "
f"exactly once. Got {call_delta.function.name}")
assert (not call_delta.id), (
f"exactly once. Got {call_delta.function.name}"
)
assert not call_delta.id, (
"Streaming tool calls must emit function id only once. Got "
f"{call_delta.id}")
assert (call_delta.index == len(self.tool_calls) - 1), (
f"{call_delta.id}"
)
assert call_delta.index == len(self.tool_calls) - 1, (
f"Incorrect index for tool delta. Got {call_delta.index}, "
f"expected {len(self.tool_calls) - 1}")
current_tool_call.function.arguments += (
call_delta.function.arguments)
f"expected {len(self.tool_calls) - 1}"
)
current_tool_call.function.arguments += call_delta.function.arguments
else:
assert call_delta.id is not None, (
"Streaming tool calls must have an id on first appearance")
"Streaming tool calls must have an id on first appearance"
)
assert call_delta.function.name is not None, (
"Streaming tool calls must have a function name on first "
"appearance")
"Streaming tool calls must have a function name on first appearance"
)
assert call_delta.index == len(self.tool_calls), (
f"Incorrect index for tool delta. Got {call_delta.index}, "
f"expected {len(self.tool_calls)}")
f"expected {len(self.tool_calls)}"
)
self.tool_calls.append(
ToolCall(id=call_delta.id,
function=FunctionCall(
name=call_delta.function.name,
arguments=call_delta.function.arguments
or "")))
ToolCall(
id=call_delta.id,
function=FunctionCall(
name=call_delta.function.name,
arguments=call_delta.function.arguments or "",
),
)
)
def run_tool_extraction(
@@ -80,11 +93,11 @@ def run_tool_extraction(
tool_parser,
model_output,
request,
assert_one_tool_per_delta=assert_one_tool_per_delta)
assert_one_tool_per_delta=assert_one_tool_per_delta,
)
return reconstructor.other_content or None, reconstructor.tool_calls
else:
extracted = run_tool_extraction_nonstreaming(tool_parser, model_output,
request)
extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, request)
assert extracted.tools_called == bool(extracted.tool_calls)
return extracted.content, extracted.tool_calls
@@ -92,7 +105,7 @@ def run_tool_extraction(
def run_tool_extraction_nonstreaming(
tool_parser: ToolParser,
model_output: str,
request: Union[ChatCompletionRequest, None] = None
request: Union[ChatCompletionRequest, None] = None,
) -> ExtractedToolCallInformation:
request = request or ChatCompletionRequest(messages=[], model="test-model")
return tool_parser.extract_tool_calls(model_output, request)
@@ -106,7 +119,8 @@ def run_tool_extraction_streaming(
) -> StreamingToolReconstructor:
request = request or ChatCompletionRequest(messages=[], model="test-model")
reconstructor = StreamingToolReconstructor(
assert_one_tool_per_delta=assert_one_tool_per_delta)
assert_one_tool_per_delta=assert_one_tool_per_delta
)
previous_text = ""
previous_tokens: list[int] = []
for delta in model_deltas:
@@ -118,8 +132,14 @@ def run_tool_extraction_streaming(
current_text = previous_text + delta
current_tokens = previous_tokens + token_delta
delta_message = tool_parser.extract_tool_calls_streaming(
previous_text, current_text, delta, previous_tokens,
current_tokens, token_delta, request)
previous_text,
current_text,
delta,
previous_tokens,
current_tokens,
token_delta,
request,
)
if delta_message is not None:
reconstructor.append_delta(delta_message)
previous_text = current_text