Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -6,8 +6,7 @@ import json
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import (
|
||||
Hermes2ProToolParser)
|
||||
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
@@ -27,61 +26,64 @@ SERVER_ARGS = [
|
||||
f"{LORA_MODEL}",
|
||||
]
|
||||
|
||||
TOOLS = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
PRODUCT_TOOLS = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_product_info",
|
||||
"description": "Get detailed information of a product based on its "
|
||||
"product ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"inserted": {
|
||||
"type": "boolean",
|
||||
"description": "inserted.",
|
||||
},
|
||||
"product_id": {
|
||||
"type": "integer",
|
||||
"description": "The product ID of the product.",
|
||||
PRODUCT_TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_product_info",
|
||||
"description": "Get detailed information of a product based on its "
|
||||
"product ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"inserted": {
|
||||
"type": "boolean",
|
||||
"description": "inserted.",
|
||||
},
|
||||
"product_id": {
|
||||
"type": "integer",
|
||||
"description": "The product ID of the product.",
|
||||
},
|
||||
},
|
||||
"required": ["product_id", "inserted"],
|
||||
},
|
||||
"required": ["product_id", "inserted"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
}
|
||||
]
|
||||
|
||||
MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
|
||||
|
||||
PRODUCT_MESSAGES = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Hi! Do you have any detailed information about the product id "
|
||||
"7355608 and inserted true?",
|
||||
}]
|
||||
PRODUCT_MESSAGES = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi! Do you have any detailed information about the product id "
|
||||
"7355608 and inserted true?",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -150,7 +152,8 @@ async def test_streaming_tool_call():
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index]["arguments"] += (
|
||||
tool_chunk.function.arguments)
|
||||
tool_chunk.function.arguments
|
||||
)
|
||||
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
@@ -240,7 +243,8 @@ async def test_streaming_product_tool_call():
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index]["arguments"] += (
|
||||
tool_chunk.function.arguments)
|
||||
tool_chunk.function.arguments
|
||||
)
|
||||
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
@@ -291,9 +295,7 @@ def test_hermes_parser_streaming_just_forward_text(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = (
|
||||
"""This is some prior text that has nothing to do with tool calling."""
|
||||
)
|
||||
text = """This is some prior text that has nothing to do with tool calling."""
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
@@ -348,8 +350,9 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
delta_messages.append(delta)
|
||||
|
||||
assert delta_messages[0].tool_calls[0].function.name == "final_answer"
|
||||
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
|
||||
for delta in delta_messages)
|
||||
tool_call_args = "".join(
|
||||
delta.tool_calls[0].function.arguments or "" for delta in delta_messages
|
||||
)
|
||||
assert tool_call_args == '{"trigger": true}'
|
||||
|
||||
|
||||
@@ -383,13 +386,13 @@ def test_hermes_parser_streaming(
|
||||
if delta is not None:
|
||||
delta_messages.append(delta)
|
||||
print(delta_messages)
|
||||
assert (delta_messages[0].tool_calls[0].function.name ==
|
||||
"get_current_temperature")
|
||||
tool_call_args = "".join(delta.tool_calls[0].function.arguments or ""
|
||||
for delta in delta_messages)
|
||||
assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature"
|
||||
tool_call_args = "".join(
|
||||
delta.tool_calls[0].function.arguments or "" for delta in delta_messages
|
||||
)
|
||||
assert tool_call_args == (
|
||||
'{"location":"San Francisco, California, United States", '
|
||||
'"unit": "celsius"}')
|
||||
'{"location":"San Francisco, California, United States", "unit": "celsius"}'
|
||||
)
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_no_tool_call(
|
||||
|
||||
@@ -8,15 +8,18 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction, run_tool_extraction_streaming)
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
|
||||
def make_tool_call(name, arguments):
|
||||
return ToolCall(type="function",
|
||||
function=FunctionCall(name=name,
|
||||
arguments=json.dumps(arguments)))
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=name, arguments=json.dumps(arguments)),
|
||||
)
|
||||
|
||||
|
||||
# TODO: add reason prefix and suffix.
|
||||
@@ -29,70 +32,68 @@ def make_tool_call(name, arguments):
|
||||
("How can I help you today?", [], "How can I help you today?"),
|
||||
# Single tool call, no content
|
||||
(
|
||||
"<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>", #noqa: E501
|
||||
[
|
||||
make_tool_call("get_weather", {
|
||||
"city": "San Francisco",
|
||||
"metric": "celsius"
|
||||
})
|
||||
],
|
||||
None),
|
||||
# Multiple tool calls
|
||||
(
|
||||
"<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>", #noqa: E501
|
||||
[
|
||||
make_tool_call("get_weather", {
|
||||
"city": "San Francisco",
|
||||
"metric": "celsius"
|
||||
}),
|
||||
make_tool_call(
|
||||
"register_user", {
|
||||
"name": "John Doe",
|
||||
"age": 37,
|
||||
"address": {
|
||||
"city": "San Francisco",
|
||||
"state": "CA"
|
||||
},
|
||||
"role": None,
|
||||
"passed_test": True,
|
||||
"aliases": ["John", "Johnny"]
|
||||
})
|
||||
],
|
||||
None),
|
||||
# Content before tool call
|
||||
(
|
||||
"I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>", #noqa: E501
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
"I will call the tool now. "),
|
||||
# Content after tool call (should be stripped)
|
||||
(
|
||||
"<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!", #noqa: E501
|
||||
[make_tool_call("get_weather", {"city": "Seattle"})],
|
||||
None),
|
||||
(
|
||||
"<tool_calls>[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]</tool_calls>",
|
||||
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]</tool_calls>', # noqa: E501
|
||||
[
|
||||
make_tool_call(
|
||||
"complex_tool",
|
||||
{"level1": {
|
||||
"level2": {
|
||||
"level3": {
|
||||
"value": 123
|
||||
}
|
||||
}
|
||||
}})
|
||||
"get_weather", {"city": "San Francisco", "metric": "celsius"}
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
])
|
||||
def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
# Multiple tool calls
|
||||
(
|
||||
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]</tool_calls>', # noqa: E501
|
||||
[
|
||||
make_tool_call(
|
||||
"get_weather", {"city": "San Francisco", "metric": "celsius"}
|
||||
),
|
||||
make_tool_call(
|
||||
"register_user",
|
||||
{
|
||||
"name": "John Doe",
|
||||
"age": 37,
|
||||
"address": {"city": "San Francisco", "state": "CA"},
|
||||
"role": None,
|
||||
"passed_test": True,
|
||||
"aliases": ["John", "Johnny"],
|
||||
},
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
# Content before tool call
|
||||
(
|
||||
'I will call the tool now. <tool_calls>[{"name": "get_weather", "arguments": {"city": "Boston"}}]</tool_calls>', # noqa: E501
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
"I will call the tool now. ",
|
||||
),
|
||||
# Content after tool call (should be stripped)
|
||||
(
|
||||
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "Seattle"}}]</tool_calls>\nThank you!', # noqa: E501
|
||||
[make_tool_call("get_weather", {"city": "Seattle"})],
|
||||
None,
|
||||
),
|
||||
(
|
||||
'<tool_calls>[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]</tool_calls>',
|
||||
[
|
||||
make_tool_call(
|
||||
"complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_hunyuan_a13b_tool_parser_extract(
|
||||
model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"hunyuan_a13b")(mock_tokenizer)
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
model_output,
|
||||
streaming=False)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
|
||||
mock_tokenizer
|
||||
)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=False
|
||||
)
|
||||
|
||||
# align the random id.
|
||||
for idx in range(len(tool_calls)):
|
||||
@@ -102,49 +103,74 @@ def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
|
||||
|
||||
|
||||
# Streaming test: simulate incremental output
|
||||
@pytest.mark.parametrize("model_deltas,expected_tool_calls", [
|
||||
([
|
||||
"<tool_calls>[{\"name\": \"get_weather\", ",
|
||||
"\"arguments\": {\"city\": \"San Francisco\", ",
|
||||
"\"metric\": \"celsius\"}}]", "</tool_calls>"
|
||||
], [
|
||||
make_tool_call("get_weather", {
|
||||
"city": "San Francisco",
|
||||
"metric": "celsius"
|
||||
})
|
||||
]),
|
||||
([
|
||||
"<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
|
||||
" {\"city\": \"Boston\"}", "}]", "</tool_calls>"
|
||||
], [make_tool_call("get_weather", {"city": "Boston"})]),
|
||||
([
|
||||
"", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
|
||||
" {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>"
|
||||
], [make_tool_call("get_weather", {"city": "Boston"})]),
|
||||
pytest.param([
|
||||
"<tool_calls>[{\"name\": \"complex_tool\",", " \"arguments\": ",
|
||||
" {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}",
|
||||
"]</tool_calls>"
|
||||
], [
|
||||
make_tool_call("complex_tool",
|
||||
{"level1": {
|
||||
"level2": {
|
||||
"level3": {
|
||||
"value": 123
|
||||
}
|
||||
}
|
||||
}})
|
||||
@pytest.mark.parametrize(
|
||||
"model_deltas,expected_tool_calls",
|
||||
[
|
||||
(
|
||||
[
|
||||
'<tool_calls>[{"name": "get_weather", ',
|
||||
'"arguments": {"city": "San Francisco", ',
|
||||
'"metric": "celsius"}}]',
|
||||
"</tool_calls>",
|
||||
],
|
||||
[
|
||||
make_tool_call(
|
||||
"get_weather", {"city": "San Francisco", "metric": "celsius"}
|
||||
)
|
||||
],
|
||||
),
|
||||
(
|
||||
[
|
||||
'<tool_calls>[{"name":',
|
||||
' "get_weather",',
|
||||
' "arguments":',
|
||||
' {"city": "Boston"}',
|
||||
"}]",
|
||||
"</tool_calls>",
|
||||
],
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
),
|
||||
(
|
||||
[
|
||||
"",
|
||||
'<tool_calls>[{"name":',
|
||||
' "get_weather",',
|
||||
' "arguments":',
|
||||
' {"city": "Boston"}',
|
||||
"}]",
|
||||
"</tool_calls>",
|
||||
"\n</answer>",
|
||||
],
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
'<tool_calls>[{"name": "complex_tool",',
|
||||
' "arguments": ',
|
||||
' {"level1": {"level2": ',
|
||||
'{"level3": {"value": 123}}}}}',
|
||||
"]</tool_calls>",
|
||||
],
|
||||
[
|
||||
make_tool_call(
|
||||
"complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
|
||||
)
|
||||
],
|
||||
marks=pytest.mark.xfail(
|
||||
reason="stream parsing not support nested json yet."
|
||||
),
|
||||
),
|
||||
],
|
||||
marks=pytest.mark.xfail(
|
||||
reason="stream parsing not support nested json yet.")),
|
||||
])
|
||||
)
|
||||
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
|
||||
mock_tokenizer = MagicMock()
|
||||
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"hunyuan_a13b")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
|
||||
mock_tokenizer
|
||||
)
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_deltas, assert_one_tool_per_delta=False)
|
||||
tool_parser, model_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
# align the random id.
|
||||
for idx in range(len(reconstructor.tool_calls)):
|
||||
|
||||
@@ -5,8 +5,7 @@ import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
|
||||
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import (
|
||||
Llama3JsonToolParser)
|
||||
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -18,8 +17,10 @@ def parser():
|
||||
|
||||
def test_extract_tool_calls_simple(parser):
|
||||
# Test with a simple tool call
|
||||
model_output = ('Here is the result: {"name": "getOpenIncidentsTool", '
|
||||
'"parameters": {}} Would you like to know more?')
|
||||
model_output = (
|
||||
'Here is the result: {"name": "getOpenIncidentsTool", '
|
||||
'"parameters": {}} Would you like to know more?'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert isinstance(result, ExtractedToolCallInformation)
|
||||
@@ -34,8 +35,8 @@ def test_extract_tool_calls_simple(parser):
|
||||
def test_extract_tool_calls_with_arguments(parser):
|
||||
# Test with a tool call that has arguments
|
||||
model_output = (
|
||||
'{"name": "searchTool", "parameters": {"query": "test query", '
|
||||
'"limit": 10}}')
|
||||
'{"name": "searchTool", "parameters": {"query": "test query", "limit": 10}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
@@ -81,7 +82,8 @@ def test_extract_tool_calls_multiple_json(parser):
|
||||
model_output = (
|
||||
'{"name": "searchTool", "parameters": {"query": "test1"}}; '
|
||||
'{"name": "getOpenIncidentsTool", "parameters": {}}; '
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}}')
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
@@ -105,7 +107,8 @@ def test_extract_tool_calls_multiple_json_with_whitespace(parser):
|
||||
model_output = (
|
||||
'{"name": "searchTool", "parameters": {"query": "test1"}} ; '
|
||||
'{"name": "getOpenIncidentsTool", "parameters": {}} ; '
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}}')
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
@@ -118,11 +121,12 @@ def test_extract_tool_calls_multiple_json_with_whitespace(parser):
|
||||
def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
|
||||
# Test with multiple JSONs and surrounding text
|
||||
model_output = (
|
||||
'Here are the results: '
|
||||
"Here are the results: "
|
||||
'{"name": "searchTool", "parameters": {"query": "test1"}}; '
|
||||
'{"name": "getOpenIncidentsTool", "parameters": {}}; '
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}} '
|
||||
'Would you like to know more?')
|
||||
"Would you like to know more?"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
|
||||
@@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction, run_tool_extraction_streaming)
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
@@ -16,12 +18,14 @@ SIMPLE_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "LA", "metric": "C"}',
|
||||
)
|
||||
MORE_TYPES_FUNCTION_OUTPUT = ("[register_user(name='Doe', "
|
||||
"age=9, "
|
||||
"address={'city': 'LA', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])]")
|
||||
MORE_TYPES_FUNCTION_OUTPUT = (
|
||||
"[register_user(name='Doe', "
|
||||
"age=9, "
|
||||
"address={'city': 'LA', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])]"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
name="register_user",
|
||||
arguments='{"name": "Doe", '
|
||||
@@ -34,7 +38,7 @@ MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]"
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{}',
|
||||
arguments="{}",
|
||||
)
|
||||
EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]"
|
||||
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
|
||||
@@ -47,25 +51,28 @@ EMPTY_LIST_FUNCTION_CALL = FunctionCall(
|
||||
arguments='{"steps": []}',
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT = (
|
||||
r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]")
|
||||
r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]"
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
|
||||
)
|
||||
PYTHON_TAG_FUNCTION_OUTPUT = (
|
||||
"<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>")
|
||||
"<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
mock_tokenizer
|
||||
)
|
||||
model_output = "How can I help you today?"
|
||||
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
model_output,
|
||||
streaming=streaming)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
@@ -75,98 +82,139 @@ test_str = "<|python_start|>"
|
||||
test_str += "[get_weather(city='LA', metric='C'),"
|
||||
test_str += "register_user(name='Doe', age=9)]"
|
||||
TEST_CASES = [
|
||||
pytest.param(True,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="simple_streaming"),
|
||||
pytest.param(False,
|
||||
SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
|
||||
id="simple_nonstreaming"),
|
||||
pytest.param(True,
|
||||
MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming"),
|
||||
pytest.param(False,
|
||||
MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming"),
|
||||
pytest.param(True,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming"),
|
||||
pytest.param(False,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming"),
|
||||
pytest.param(True,
|
||||
EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming"),
|
||||
pytest.param(False,
|
||||
EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming"),
|
||||
pytest.param(True,
|
||||
EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming"),
|
||||
pytest.param(False,
|
||||
EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming"),
|
||||
pytest.param(True,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming"),
|
||||
pytest.param(False,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming"),
|
||||
pytest.param(
|
||||
True,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="simple_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False, SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], id="simple_nonstreaming"
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MORE_TYPES_FUNCTION_OUTPUT,
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MORE_TYPES_FUNCTION_OUTPUT,
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
EMPTY_DICT_FUNCTION_OUTPUT,
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
EMPTY_DICT_FUNCTION_OUTPUT,
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
EMPTY_LIST_FUNCTION_OUTPUT,
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
EMPTY_LIST_FUNCTION_OUTPUT,
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
"[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user",
|
||||
arguments='{"name": "Doe", "age": 9}')
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_streaming"),
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
"[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user",
|
||||
arguments='{"name": "Doe", "age": 9}')
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_nonstreaming"),
|
||||
pytest.param(True,
|
||||
PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
|
||||
id="python_tag_streaming"),
|
||||
pytest.param(False,
|
||||
PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
|
||||
id="python_tag_nonstreaming"),
|
||||
pytest.param(True,
|
||||
test_str, [
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user",
|
||||
arguments='{"name": "Doe", "age": 9}')
|
||||
],
|
||||
id="parallel_calls_streaming"),
|
||||
pytest.param(False,
|
||||
"<|python_start|>[get_weather(city='LA', metric='C'), " +
|
||||
"register_user(name='Doe', age=9)]", [
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user",
|
||||
arguments='{"name": "Doe", "age": 9}')
|
||||
],
|
||||
id="parallel_calls_nonstreaming"),
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PYTHON_TAG_FUNCTION_OUTPUT,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="python_tag_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PYTHON_TAG_FUNCTION_OUTPUT,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="python_tag_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
test_str,
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
"<|python_start|>[get_weather(city='LA', metric='C'), "
|
||||
+ "register_user(name='Doe', age=9)]",
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
|
||||
TEST_CASES)
|
||||
def test_tool_call(streaming: bool, model_output: str,
|
||||
expected_tool_calls: list[FunctionCall]):
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
|
||||
def test_tool_call(
|
||||
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
|
||||
):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
mock_tokenizer
|
||||
)
|
||||
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
model_output,
|
||||
streaming=streaming)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
for actual, expected in zip(tool_calls, expected_tool_calls):
|
||||
@@ -176,8 +224,9 @@ def test_tool_call(streaming: bool, model_output: str,
|
||||
|
||||
def test_streaming_tool_call_with_large_steps():
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
mock_tokenizer
|
||||
)
|
||||
model_output_deltas = [
|
||||
"<|python_start|>[get_weather(city='LA', metric='C'), "
|
||||
"get_weather(), "
|
||||
@@ -185,7 +234,8 @@ def test_streaming_tool_call_with_large_steps():
|
||||
]
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False)
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
assert reconstructor.other_content == ""
|
||||
assert len(reconstructor.tool_calls) == 3
|
||||
@@ -198,8 +248,9 @@ def test_streaming_tool_call_with_large_steps():
|
||||
def test_regex_timeout_handling(streaming: bool):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
mock_tokenizer
|
||||
)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
@@ -207,10 +258,10 @@ def test_regex_timeout_handling(streaming: bool):
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
fake_problematic_input,
|
||||
streaming=streaming)
|
||||
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, fake_problematic_input, streaming=streaming
|
||||
)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
|
||||
@@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from tests.entrypoints.openai.tool_parsers.utils import (
|
||||
run_tool_extraction, run_tool_extraction_streaming)
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
@@ -22,7 +24,8 @@ MORE_TYPES_FUNCTION_OUTPUT = (
|
||||
"address={'city': 'San Francisco', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])")
|
||||
"aliases=['John', 'Johnny'])"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
name="register_user",
|
||||
arguments='{"name": "John Doe", '
|
||||
@@ -35,7 +38,7 @@ MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{}',
|
||||
arguments="{}",
|
||||
)
|
||||
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
|
||||
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
|
||||
@@ -48,7 +51,8 @@ EMPTY_LIST_FUNCTION_CALL = FunctionCall(
|
||||
arguments='{"steps": []}',
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT = (
|
||||
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')")
|
||||
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
|
||||
@@ -59,80 +63,118 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
def test_no_tool_call(streaming: bool):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
mock_tokenizer)
|
||||
mock_tokenizer
|
||||
)
|
||||
model_output = "How can I help you today?"
|
||||
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
model_output,
|
||||
streaming=streaming)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(True,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
|
||||
id="simple_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
|
||||
id="simple_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming"),
|
||||
pytest.param(True,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_streaming"),
|
||||
pytest.param(False,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_nonstreaming"),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
|
||||
TEST_CASES)
|
||||
def test_tool_call(streaming: bool, model_output: str,
|
||||
expected_tool_calls: list[FunctionCall]):
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
|
||||
def test_tool_call(
|
||||
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
|
||||
):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
mock_tokenizer)
|
||||
mock_tokenizer
|
||||
)
|
||||
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
model_output,
|
||||
streaming=streaming)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content is None
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
@@ -144,7 +186,8 @@ def test_tool_call(streaming: bool, model_output: str,
|
||||
def test_streaming_tool_call_with_large_steps():
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
mock_tokenizer)
|
||||
mock_tokenizer
|
||||
)
|
||||
model_output_deltas = [
|
||||
"[get_weather(city='San",
|
||||
" Francisco', metric='celsius'), "
|
||||
@@ -153,7 +196,8 @@ def test_streaming_tool_call_with_large_steps():
|
||||
]
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False)
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
assert reconstructor.other_content == ""
|
||||
assert len(reconstructor.tool_calls) == 3
|
||||
@@ -166,8 +210,9 @@ def test_streaming_tool_call_with_large_steps():
|
||||
def test_regex_timeout_handling(streaming: bool):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser(
|
||||
"llama4_pythonic")(mock_tokenizer)
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
mock_tokenizer
|
||||
)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
@@ -175,10 +220,10 @@ def test_regex_timeout_handling(streaming: bool):
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex):
|
||||
content, tool_calls = run_tool_extraction(tool_parser,
|
||||
fake_problematic_input,
|
||||
streaming=streaming)
|
||||
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, fake_problematic_input, streaming=streaming
|
||||
)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
|
||||
@@ -4,15 +4,17 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||
|
||||
|
||||
class StreamingToolReconstructor:
|
||||
|
||||
def __init__(self, assert_one_tool_per_delta: bool = True):
|
||||
self.tool_calls: list[ToolCall] = []
|
||||
self.other_content: str = ""
|
||||
@@ -23,49 +25,60 @@ class StreamingToolReconstructor:
|
||||
self.other_content += delta.content
|
||||
else:
|
||||
assert delta.tool_calls, (
|
||||
"Streaming results should have either content or tool calls "
|
||||
"(or both)")
|
||||
"Streaming results should have either content or tool calls (or both)"
|
||||
)
|
||||
if self._assert_one_tool_per_delta:
|
||||
# Note: This isn't strictly required by the API and may not be
|
||||
# possible to adhere to depending on the token space and number of
|
||||
# tokens per streamed response from the model, but it is required
|
||||
# by tool_use tests, so we enforce it here by default also.
|
||||
assert len(delta.tool_calls) < 2, (
|
||||
"Streaming should include only one tool call per update.")
|
||||
"Streaming should include only one tool call per update."
|
||||
)
|
||||
for call_delta in delta.tool_calls:
|
||||
assert call_delta.type is None or call_delta.type == "function", (
|
||||
"Streaming tool calls should only emit function calls. Got "
|
||||
f"{call_delta.type}")
|
||||
current_tool_call = self.tool_calls[
|
||||
call_delta.index] if call_delta.index < len(
|
||||
self.tool_calls) else None
|
||||
f"{call_delta.type}"
|
||||
)
|
||||
current_tool_call = (
|
||||
self.tool_calls[call_delta.index]
|
||||
if call_delta.index < len(self.tool_calls)
|
||||
else None
|
||||
)
|
||||
if current_tool_call:
|
||||
assert (not call_delta.function.name), (
|
||||
assert not call_delta.function.name, (
|
||||
"Streaming tool calls should emit the full function name "
|
||||
f"exactly once. Got {call_delta.function.name}")
|
||||
assert (not call_delta.id), (
|
||||
f"exactly once. Got {call_delta.function.name}"
|
||||
)
|
||||
assert not call_delta.id, (
|
||||
"Streaming tool calls must emit function id only once. Got "
|
||||
f"{call_delta.id}")
|
||||
assert (call_delta.index == len(self.tool_calls) - 1), (
|
||||
f"{call_delta.id}"
|
||||
)
|
||||
assert call_delta.index == len(self.tool_calls) - 1, (
|
||||
f"Incorrect index for tool delta. Got {call_delta.index}, "
|
||||
f"expected {len(self.tool_calls) - 1}")
|
||||
current_tool_call.function.arguments += (
|
||||
call_delta.function.arguments)
|
||||
f"expected {len(self.tool_calls) - 1}"
|
||||
)
|
||||
current_tool_call.function.arguments += call_delta.function.arguments
|
||||
else:
|
||||
assert call_delta.id is not None, (
|
||||
"Streaming tool calls must have an id on first appearance")
|
||||
"Streaming tool calls must have an id on first appearance"
|
||||
)
|
||||
assert call_delta.function.name is not None, (
|
||||
"Streaming tool calls must have a function name on first "
|
||||
"appearance")
|
||||
"Streaming tool calls must have a function name on first appearance"
|
||||
)
|
||||
assert call_delta.index == len(self.tool_calls), (
|
||||
f"Incorrect index for tool delta. Got {call_delta.index}, "
|
||||
f"expected {len(self.tool_calls)}")
|
||||
f"expected {len(self.tool_calls)}"
|
||||
)
|
||||
self.tool_calls.append(
|
||||
ToolCall(id=call_delta.id,
|
||||
function=FunctionCall(
|
||||
name=call_delta.function.name,
|
||||
arguments=call_delta.function.arguments
|
||||
or "")))
|
||||
ToolCall(
|
||||
id=call_delta.id,
|
||||
function=FunctionCall(
|
||||
name=call_delta.function.name,
|
||||
arguments=call_delta.function.arguments or "",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_tool_extraction(
|
||||
@@ -80,11 +93,11 @@ def run_tool_extraction(
|
||||
tool_parser,
|
||||
model_output,
|
||||
request,
|
||||
assert_one_tool_per_delta=assert_one_tool_per_delta)
|
||||
assert_one_tool_per_delta=assert_one_tool_per_delta,
|
||||
)
|
||||
return reconstructor.other_content or None, reconstructor.tool_calls
|
||||
else:
|
||||
extracted = run_tool_extraction_nonstreaming(tool_parser, model_output,
|
||||
request)
|
||||
extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, request)
|
||||
assert extracted.tools_called == bool(extracted.tool_calls)
|
||||
return extracted.content, extracted.tool_calls
|
||||
|
||||
@@ -92,7 +105,7 @@ def run_tool_extraction(
|
||||
def run_tool_extraction_nonstreaming(
|
||||
tool_parser: ToolParser,
|
||||
model_output: str,
|
||||
request: Union[ChatCompletionRequest, None] = None
|
||||
request: Union[ChatCompletionRequest, None] = None,
|
||||
) -> ExtractedToolCallInformation:
|
||||
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
||||
return tool_parser.extract_tool_calls(model_output, request)
|
||||
@@ -106,7 +119,8 @@ def run_tool_extraction_streaming(
|
||||
) -> StreamingToolReconstructor:
|
||||
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
||||
reconstructor = StreamingToolReconstructor(
|
||||
assert_one_tool_per_delta=assert_one_tool_per_delta)
|
||||
assert_one_tool_per_delta=assert_one_tool_per_delta
|
||||
)
|
||||
previous_text = ""
|
||||
previous_tokens: list[int] = []
|
||||
for delta in model_deltas:
|
||||
@@ -118,8 +132,14 @@ def run_tool_extraction_streaming(
|
||||
current_text = previous_text + delta
|
||||
current_tokens = previous_tokens + token_delta
|
||||
delta_message = tool_parser.extract_tool_calls_streaming(
|
||||
previous_text, current_text, delta, previous_tokens,
|
||||
current_tokens, token_delta, request)
|
||||
previous_text,
|
||||
current_text,
|
||||
delta,
|
||||
previous_tokens,
|
||||
current_tokens,
|
||||
token_delta,
|
||||
request,
|
||||
)
|
||||
if delta_message is not None:
|
||||
reconstructor.append_delta(delta_message)
|
||||
previous_text = current_text
|
||||
|
||||
Reference in New Issue
Block a user