[Test] Consolidate tool parser unit tests to tests/tool_parsers (#37834)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
Ben Browning
2026-03-23 00:24:25 -04:00
committed by GitHub
parent 6e04e79326
commit 3bbe2e1e6e
11 changed files with 376 additions and 353 deletions

View File

@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from vllm.tokenizers import TokenizerLike
@pytest.fixture(scope="function")
def default_tokenizer() -> TokenizerLike:
    """Return a newly loaded "gpt2" tokenizer for each test function."""
    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return gpt2_tokenizer

View File

@@ -1,343 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from tests.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.engine.protocol import FunctionCall
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# Special tokens that appear in the simulated model output.
MSG_SEP_TOKEN = "<|message_sep|>\n\n"
ROLE_SEP_TOKEN = "<|role_sep|>\n"
EOS_TOKEN = "</s>"

# Text that precedes the tool-call JSON in each of the two output formats
# exercised below (labelled "gigachat3" and "gigachat31" in the test ids).
TOOL_HEADER_GIGACHAT3 = f"function call{ROLE_SEP_TOKEN}"
TOOL_HEADER_GIGACHAT31 = "<|function_call|>"

# --- Simple call: two flat string arguments -------------------------------
SIMPLE_ARGS_DICT = {
    "action": "create",
    "id": "preferences",
}
SIMPLE_FUNCTION_JSON = json.dumps(
    {
        "name": "manage_user_memory",
        "arguments": SIMPLE_ARGS_DICT,
    },
    ensure_ascii=False,
)
SIMPLE_FUNCTION_OUTPUT_GIGACHAT3 = (
    f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{SIMPLE_FUNCTION_JSON}"
)
SIMPLE_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{SIMPLE_FUNCTION_JSON}"
# Expected parse result; `arguments` is the JSON-encoded argument dict.
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="manage_user_memory",
    arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False),
)

# --- Parameterless call: empty argument object ----------------------------
PARAMETERLESS_FUNCTION_JSON = json.dumps(
    {
        "name": "manage_user_memory",
        "arguments": {},
    },
    ensure_ascii=False,
)
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3 = (
    f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{PARAMETERLESS_FUNCTION_JSON}"
)
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31 = (
    f"{TOOL_HEADER_GIGACHAT31}{PARAMETERLESS_FUNCTION_JSON}"
)
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
    name="manage_user_memory",
    arguments=json.dumps({}, ensure_ascii=False),
)

# --- Complex call: arguments contain a nested object with booleans --------
COMPLEX_ARGS_DICT = {
    "action": "create",
    "id": "preferences",
    "content": {
        "short_answers": True,
        "hate_emojis": True,
        "english_ui": False,
        "russian_math_explanations": True,
    },
}
COMPLEX_FUNCTION_JSON = json.dumps(
    {
        "name": "manage_user_memory",
        "arguments": COMPLEX_ARGS_DICT,
    },
    ensure_ascii=False,
)
COMPLEX_FUNCTION_OUTPUT_GIGACHAT3 = (
    f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{COMPLEX_FUNCTION_JSON}"
)
COMPLEX_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{COMPLEX_FUNCTION_JSON}"
COMPLEX_FUNCTION_CALL = FunctionCall(
    name="manage_user_memory",
    arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False),
)

# --- Mixed output: plain content followed by the simple tool call ---------
CONTENT_TEXT = "I'll check that for you."
MIXED_OUTPUT_GIGACHAT3 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT3}"
MIXED_OUTPUT_GIGACHAT31 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT31}"
@pytest.fixture(name="gigachat_tokenizer")
def fixture_gigachat_tokenizer(default_tokenizer: TokenizerLike):
    """Add the GigaChat control tokens to the default tokenizer.

    The tokens are added to the passed-in tokenizer, and that same object is
    returned.
    """
    gigachat_tokens = [
        MSG_SEP_TOKEN,
        ROLE_SEP_TOKEN,
        TOOL_HEADER_GIGACHAT31,
        EOS_TOKEN,
    ]
    default_tokenizer.add_tokens(gigachat_tokens)
    return default_tokenizer
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, gigachat_tokenizer: TokenizerLike):
    """Plain text yields the text itself as content and no tool calls."""
    plain_text = "How can I help you today?"
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    tool_parser: ToolParser = parser_cls(gigachat_tokenizer)

    content, tool_calls = run_tool_extraction(
        tool_parser, plain_text, streaming=streaming
    )

    assert content == plain_text
    assert len(tool_calls) == 0
def _gigachat_cases(version, simple, parameterless, complex_output, mixed):
    """Build the parametrized cases for one output format.

    Each scenario is emitted twice, streaming first and then non-streaming.
    Ids follow the pattern ``<scenario>_<mode>[_with_eos]_<version>``.
    """
    scenarios = [
        ("simple", "", simple, SIMPLE_FUNCTION_CALL, None),
        ("parameterless", "", parameterless, PARAMETERLESS_FUNCTION_CALL, None),
        ("complex", "", complex_output, COMPLEX_FUNCTION_CALL, None),
        ("mixed_content", "", mixed, SIMPLE_FUNCTION_CALL, CONTENT_TEXT),
        (
            "mixed_content",
            "_with_eos",
            mixed + EOS_TOKEN,
            SIMPLE_FUNCTION_CALL,
            CONTENT_TEXT,
        ),
    ]
    cases = []
    for label, suffix, model_output, expected_call, expected_content in scenarios:
        for streaming, mode in ((True, "streaming"), (False, "nonstreaming")):
            cases.append(
                pytest.param(
                    streaming,
                    model_output,
                    # A fresh list per case, as in a hand-written table.
                    [expected_call],
                    expected_content,
                    id=f"{label}_{mode}{suffix}_{version}",
                )
            )
    return cases


# (streaming, model_output, expected_tool_calls, expected_content) for both
# output formats: all "gigachat3" cases first, then all "gigachat31" cases.
TEST_CASES = [
    *_gigachat_cases(
        "gigachat3",
        SIMPLE_FUNCTION_OUTPUT_GIGACHAT3,
        PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3,
        COMPLEX_FUNCTION_OUTPUT_GIGACHAT3,
        MIXED_OUTPUT_GIGACHAT3,
    ),
    *_gigachat_cases(
        "gigachat31",
        SIMPLE_FUNCTION_OUTPUT_GIGACHAT31,
        PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31,
        COMPLEX_FUNCTION_OUTPUT_GIGACHAT31,
        MIXED_OUTPUT_GIGACHAT31,
    ),
]
@pytest.mark.parametrize(
    "streaming, model_output, expected_tool_calls, expected_content", TEST_CASES
)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    expected_content: str | None,
    gigachat_tokenizer: TokenizerLike,
):
    """Check the extracted content and tool calls for each entry in TEST_CASES."""
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    tool_parser: ToolParser = parser_cls(gigachat_tokenizer)

    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming
    )

    # An empty content string is treated the same as "no content".
    normalized_content = None if content == "" else content
    assert normalized_content == expected_content

    assert len(tool_calls) == len(expected_tool_calls)
    for parsed_call, wanted in zip(tool_calls, expected_tool_calls):
        assert parsed_call.type == "function"
        parsed_function = parsed_call.function
        assert parsed_function.name == wanted.name
        # Compare decoded JSON so serialization details do not matter.
        parsed_args = json.loads(parsed_function.arguments)
        wanted_args = json.loads(wanted.arguments)
        assert parsed_args == wanted_args
@pytest.mark.parametrize(
    "model_output_deltas",
    [
        pytest.param(
            [
                CONTENT_TEXT[:3],
                CONTENT_TEXT[3:5],
                CONTENT_TEXT[5:],
                MSG_SEP_TOKEN,
                TOOL_HEADER_GIGACHAT3,
                COMPLEX_FUNCTION_JSON[:40],
                COMPLEX_FUNCTION_JSON[40:-1],
                COMPLEX_FUNCTION_JSON[-1],
            ],
            id="gigachat3",
        ),
        pytest.param(
            [
                CONTENT_TEXT[:3],
                CONTENT_TEXT[3:5],
                CONTENT_TEXT[5:],
                TOOL_HEADER_GIGACHAT31,
                COMPLEX_FUNCTION_JSON[:40],
                COMPLEX_FUNCTION_JSON[40:-1],
                COMPLEX_FUNCTION_JSON[-1],
            ],
            id="gigachat31",
        ),
    ],
)
def test_streaming_tool_call_with_large_steps(
    model_output_deltas: list[str],
    gigachat_tokenizer: TokenizerLike,
):
    """
    Test that the closing braces are streamed correctly.

    The JSON is delivered in large chunks, with its final character arriving
    as a separate delta.
    """
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    tool_parser: ToolParser = parser_cls(gigachat_tokenizer)

    reconstructor = run_tool_extraction_streaming(
        tool_parser,
        model_output_deltas,
        assert_one_tool_per_delta=False,
    )

    streamed_calls = reconstructor.tool_calls
    assert len(streamed_calls) == 1

    only_call = streamed_calls[0]
    assert only_call.type == "function"
    assert only_call.function.name == "manage_user_memory"
    assert json.loads(only_call.function.arguments) == COMPLEX_ARGS_DICT

View File

@@ -1,18 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import random
from typing import Any
import openai
import pytest
from transformers import AutoTokenizer
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
)
from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
from ....utils import RemoteOpenAIServer
@@ -38,137 +29,6 @@ def server():
yield server
def create_complex_input(create_string_args: bool):
    """Return three tool-call dicts: find_bbox, get_stock_price, find_bbox.

    When ``create_string_args`` is true, the ``find_bbox`` arguments are
    given as a JSON string instead of a dict; the ``get_stock_price``
    arguments are always a dict.
    """
    bbox_args: dict | str = {
        "coordinates": [[23.54, 43.1], [-12.2, 54.3], [4, 5]],
        "coordinate_type": "latlong",
    }
    if create_string_args:
        # test granite behavior
        bbox_args = json.dumps(bbox_args)

    def bbox_call():
        # A new dict per call; both entries share the same arguments object.
        return {"name": "find_bbox", "arguments": bbox_args}

    stock_call = {
        "name": "get_stock_price",
        "arguments": {
            "symbol": "AAPL",
            "start_date": "2021-01-01",
            "end_date": "2021-12-31",
        },
    }
    return [bbox_call(), stock_call, bbox_call()]
def random_chunks(s: str, min_len: int, max_len: int):
chunks = []
i = 0
n = len(s)
while i < n:
size = random.randint(min_len, max_len)
chunks.append(s[i : i + size])
i += size
return chunks
@pytest.fixture(scope="module")
def tokenizer():
    """Load the tokenizer for ``MODEL`` once per test module."""
    model_tokenizer = AutoTokenizer.from_pretrained(MODEL)
    return model_tokenizer
# create a variety of input chunk sizes
@pytest.mark.parametrize(
    "min_chunk, max_chunk",
    [
        (1, 1),
        (1, 2),
        (5, 7),
        (6, 20),
    ],
)
def test_tool_call_parser_complex(min_chunk: int, max_chunk: int, tokenizer):
    """Stream text interleaved with three tool calls in random-sized chunks.

    The deltas returned by the parser are reassembled, and both the plain
    content and the decoded tool calls are compared with the input.
    """
    text_messages = [
        "Here goes the bbox call: \n",
        " Now the stock price call: \n ",
        " Now another bbox call: \n ",
        " See? I'm a helpful assistant.",
    ]
    formatted_tcs = [
        "<tool_call> " + json.dumps(call) + " </tool_call>"
        for call in create_complex_input(True)
    ]
    # Alternate text and tool call, then append the trailing text message.
    test_input = (
        "".join(
            message + tool_call
            for message, tool_call in zip(text_messages, formatted_tcs)
        )
        + text_messages[3]
    )

    any_chat_request = ChatCompletionRequest(
        seed=42,
        model=MODEL,
        messages=[],
    )
    parser = Granite4ToolParser(tokenizer=tokenizer)

    delta_messages: list[DeltaMessage] = []
    for chunk in random_chunks(test_input, min_chunk, max_chunk):
        delta = parser.extract_tool_calls_streaming(
            previous_text="",
            current_text="",
            delta_text=chunk,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        if delta is not None:
            delta_messages.append(delta)

    content_parts: list[str] = []
    tool_calls: list[dict[str, Any]] = []
    current_name = "__start__"  # sentinel: no tool call seen yet
    current_args = ""

    for msg in delta_messages:
        if msg.content:
            content_parts.append(msg.content)
        for tool_call in msg.tool_calls:
            delta_func = tool_call.function
            if not delta_func:
                continue
            new_name = delta_func.name
            if new_name is not None:
                if current_name == "__start__":
                    current_name = new_name
                elif new_name != current_name:
                    # A different name starts the next call; store the
                    # one accumulated so far.
                    tool_calls.append(
                        {
                            "name": current_name,
                            "arguments": json.loads(current_args),
                        }
                    )
                    current_name = new_name
                    current_args = ""
            if delta_func.arguments:
                current_args += delta_func.arguments

    # Store the last accumulated call, if any call was seen at all.
    if current_name != "__start__":
        tool_calls.append({"name": current_name, "arguments": json.loads(current_args)})

    assert "".join(content_parts) == "".join(text_messages)
    assert tool_calls == create_complex_input(False)
tools = [
{
"type": "function",

View File

@@ -9,8 +9,6 @@ import pytest_asyncio
from huggingface_hub import snapshot_download
from typing_extensions import TypedDict
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
@@ -325,202 +323,3 @@ async def test_streaming_product_tool_call(
print("\n[Streaming Product Test Passed]")
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
print(f"Reconstructed Arguments: {arguments}")
@pytest.fixture
def qwen_tokenizer() -> TokenizerLike:
    """Return the tokenizer for "Qwen/Qwen3-32B"."""
    # Imported inside the fixture, as in the original code.
    from vllm.tokenizers import get_tokenizer

    tokenizer = get_tokenizer("Qwen/Qwen3-32B")
    return tokenizer
@pytest.fixture(params=CONFIGS.keys())
def hermes_parser(request, qwen_tokenizer: TokenizerLike) -> ToolParser:
    """Build one parser per entry in ``CONFIGS``, using the Qwen tokenizer."""
    parser_cls = CONFIGS[request.param]["tool_parser"]
    return parser_cls(qwen_tokenizer)
@pytest.fixture
def any_chat_request() -> ChatCompletionRequest:
    """Return a chat completion request with no messages and a fixed seed."""
    request_kwargs = {
        "seed": 42,
        "model": "Qwen/Qwen3-32B",
        "messages": [],
    }
    return ChatCompletionRequest(**request_kwargs)
def test_hermes_parser_streaming_just_forward_text(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Text without tool calls is forwarded as content, token by token."""
    text = """This is some prior text that has nothing to do with tool calling."""

    delta_messages = []
    streamed_so_far = ""
    for token in qwen_tokenizer.encode(text):
        delta_text = qwen_tokenizer.decode([token])
        current_text = streamed_so_far + delta_text
        delta_messages.append(
            hermes_parser.extract_tool_calls_streaming(
                previous_text=streamed_so_far,
                current_text=current_text,
                delta_text=delta_text,
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[],
                request=any_chat_request,
            )
        )
        streamed_so_far = current_text

    # Every token must yield a delta, and none may carry a tool call.
    for delta in delta_messages:
        assert delta is not None
        assert not delta.tool_calls

    print(delta_messages)
    forwarded = "".join(delta.content for delta in delta_messages)
    assert forwarded == text
def test_hermes_parser_streaming_failure_case_bug_19056(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream a single tool call token by token and check name and arguments.

    The first non-``None`` delta must carry the function name, and the
    argument fragments of all deltas must concatenate to the exact argument
    JSON.
    """
    text = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}
</tool_call>"""
    tokens = qwen_tokenizer.encode(text)
    previous_text = ""
    delta_messages = []
    for token in tokens:
        # Named ``delta_text`` so the input ``text`` is not overwritten.
        delta_text = qwen_tokenizer.decode([token])
        current_text = previous_text + delta_text
        delta = hermes_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        previous_text = current_text
        if delta is not None:
            delta_messages.append(delta)

    # Fail with a readable message instead of an IndexError below.
    assert delta_messages, "parser produced no deltas for a tool call"
    assert delta_messages[0].tool_calls[0].function.name == "final_answer"
    tool_call_args = "".join(
        delta.tool_calls[0].function.arguments or "" for delta in delta_messages
    )
    assert tool_call_args == '{"trigger": true}'
def test_hermes_parser_streaming(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream a tool call with two string arguments token by token.

    The first non-``None`` delta must carry the function name, and the
    concatenated argument fragments must decode to the expected dict.
    """
    text = '<tool_call>\
{"name": "get_current_temperature",\
"arguments": {"location":\
"San Francisco, California, United States", "unit": "celsius"}}\
</tool_call>'
    tokens = qwen_tokenizer.encode(text)
    previous_text = ""
    delta_messages = []
    for token in tokens:
        # Named ``delta_text`` so the input ``text`` is not overwritten.
        delta_text = qwen_tokenizer.decode([token])
        current_text = previous_text + delta_text
        delta = hermes_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        previous_text = current_text
        if delta is not None:
            delta_messages.append(delta)
    print(delta_messages)

    # Fail with a readable message instead of an IndexError below.
    assert delta_messages, "parser produced no deltas for a tool call"
    assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature"
    # load to normalize whitespace
    tool_call_args = json.loads(
        "".join(
            delta.tool_calls[0].function.arguments or "" for delta in delta_messages
        )
    )
    assert tool_call_args == {
        "location": "San Francisco, California, United States",
        "unit": "celsius",
    }
def test_hermes_parser_non_streaming_no_tool_call(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Plain text yields a result object that reports no tools called."""
    plain_text = """This is not a tool call."""
    extracted = hermes_parser.extract_tool_calls(
        model_output=plain_text,
        request=any_chat_request,
    )

    assert extracted is not None
    assert not extracted.tools_called
def test_hermes_parser_non_streaming_tool_call_between_tags(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A call wrapped in opening and closing tags is extracted."""
    model_output = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}
</tool_call>"""
    extracted = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extracted is not None
    assert extracted.tools_called
    first_function = extracted.tool_calls[0].function
    assert first_function.name == "final_answer"
    assert first_function.arguments == '{"trigger": true}'
def test_hermes_parser_non_streaming_tool_call_until_eos(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A call with no closing tag is still extracted (skipped for Granite4)."""
    if isinstance(hermes_parser, Granite4ToolParser):
        pytest.skip(reason="The Granite4 tool parser enforces a complete response")

    model_output = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}}"""
    extracted = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extracted is not None
    assert extracted.tools_called
    first_function = extracted.tool_calls[0].function
    assert first_function.name == "final_answer"
    assert first_function.arguments == '{"trigger": true}'
def test_hermes_parser_non_streaming_tool_call_invalid_json(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Malformed JSON after the opening tag is reported as no tool call."""
    # Missing closing brace to trigger exception
    broken_output = """<tool_call>
{"name": "final_answer", "arguments": {"trigger": true}"""
    extracted = hermes_parser.extract_tool_calls(
        model_output=broken_output,
        request=any_chat_request,
    )

    assert extracted is not None
    assert not extracted.tools_called

View File

@@ -1,179 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
from unittest.mock import MagicMock
import pytest
from tests.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.engine.protocol import FunctionCall, ToolCall
from vllm.tool_parsers import ToolParser, ToolParserManager
def make_tool_call(name, arguments):
    """Build a function-type ToolCall whose arguments are JSON-encoded."""
    function = FunctionCall(name=name, arguments=json.dumps(arguments))
    return ToolCall(type="function", function=function)
# TODO: add reason prefix and suffix.
@pytest.mark.parametrize(
    "model_output,expected_tool_calls,expected_content",
    [
        # No tool call
        ("How can I help you today?", [], "How can I help you today?"),
        # Single tool call, no content
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]</tool_calls>',  # noqa: E501
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                )
            ],
            None,
        ),
        # Multiple tool calls
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]</tool_calls>',  # noqa: E501
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                ),
                make_tool_call(
                    "register_user",
                    {
                        "name": "John Doe",
                        "age": 37,
                        "address": {"city": "San Francisco", "state": "CA"},
                        "role": None,
                        "passed_test": True,
                        "aliases": ["John", "Johnny"],
                    },
                ),
            ],
            None,
        ),
        # Content before tool call
        (
            'I will call the tool now. <tool_calls>[{"name": "get_weather", "arguments": {"city": "Boston"}}]</tool_calls>',  # noqa: E501
            [make_tool_call("get_weather", {"city": "Boston"})],
            "I will call the tool now. ",
        ),
        # Content after tool call (should be stripped)
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "Seattle"}}]</tool_calls>\nThank you!',  # noqa: E501
            [make_tool_call("get_weather", {"city": "Seattle"})],
            None,
        ),
        # Deeply nested arguments
        (
            '<tool_calls>[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]</tool_calls>',
            [
                make_tool_call(
                    "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
                )
            ],
            None,
        ),
    ],
)
def test_hunyuan_a13b_tool_parser_extract(
    model_output, expected_tool_calls, expected_content
):
    """Check non-streaming extraction for the "hunyuan_a13b" parser.

    Tool-call ids differ between the parsed and the expected calls, so each
    parsed call's id is overwritten with the expected one before comparing.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
        mock_tokenizer
    )
    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=False
    )

    # Check the count first: a parser that returns extra calls then fails
    # this assertion instead of raising IndexError while aligning ids.
    assert len(tool_calls) == len(expected_tool_calls)

    # align the random id.
    for parsed_call, expected_call in zip(tool_calls, expected_tool_calls):
        parsed_call.id = expected_call.id
    assert tool_calls == expected_tool_calls
    assert content == expected_content
# Streaming test: simulate incremental output
@pytest.mark.parametrize(
    "model_deltas,expected_tool_calls",
    [
        (
            [
                '<tool_calls>[{"name": "get_weather", ',
                '"arguments": {"city": "San Francisco", ',
                '"metric": "celsius"}}]',
                "</tool_calls>",
            ],
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                )
            ],
        ),
        (
            [
                '<tool_calls>[{"name":',
                ' "get_weather",',
                ' "arguments":',
                ' {"city": "Boston"}',
                "}]",
                "</tool_calls>",
            ],
            [make_tool_call("get_weather", {"city": "Boston"})],
        ),
        (
            [
                "",
                '<tool_calls>[{"name":',
                ' "get_weather",',
                ' "arguments":',
                ' {"city": "Boston"}',
                "}]",
                "</tool_calls>",
                "\n</answer>",
            ],
            [make_tool_call("get_weather", {"city": "Boston"})],
        ),
        pytest.param(
            [
                '<tool_calls>[{"name": "complex_tool",',
                ' "arguments": ',
                ' {"level1": {"level2": ',
                '{"level3": {"value": 123}}}}}',
                "]</tool_calls>",
            ],
            [
                make_tool_call(
                    "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
                )
            ],
            marks=pytest.mark.xfail(
                reason="stream parsing not support nested json yet."
            ),
        ),
    ],
)
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
    """Check streaming extraction for the "hunyuan_a13b" parser.

    The deltas are fed in order, and the reconstructed tool calls are
    compared with the expected ones after their ids have been aligned.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
        mock_tokenizer
    )
    reconstructor = run_tool_extraction_streaming(
        tool_parser, model_deltas, assert_one_tool_per_delta=False
    )

    # Check the count first: extra reconstructed calls then fail this
    # assertion instead of raising IndexError while aligning ids.
    assert len(reconstructor.tool_calls) == len(expected_tool_calls)

    # align the random id.
    for streamed_call, expected_call in zip(
        reconstructor.tool_calls, expected_tool_calls
    ):
        streamed_call.id = expected_call.id

    assert reconstructor.tool_calls == expected_tool_calls

View File

@@ -1,262 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from vllm.entrypoints.openai.engine.protocol import ExtractedToolCallInformation
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.llama_tool_parser import Llama3JsonToolParser
@pytest.fixture
def parser(default_tokenizer: TokenizerLike):
    """Return a Llama3JsonToolParser built on the default tokenizer."""
    llama_parser = Llama3JsonToolParser(default_tokenizer)
    return llama_parser
def test_extract_tool_calls_simple(parser):
    """A single parameterless call inside prose is extracted; content is None."""
    # Test with a simple tool call
    model_output = (
        'Here is the result: {"name": "getOpenIncidentsTool", '
        '"parameters": {}} Would you like to know more?'
    )
    result = parser.extract_tool_calls(model_output, None)

    assert isinstance(result, ExtractedToolCallInformation)
    assert result.tools_called is True
    assert len(result.tool_calls) == 1

    only_call = result.tool_calls[0]
    assert only_call.type == "function"
    assert only_call.function.name == "getOpenIncidentsTool"
    assert only_call.function.arguments == "{}"
    assert result.content is None
def test_extract_tool_calls_with_arguments(parser):
    """A call with "parameters" keeps both parameters in its arguments."""
    # Test with a tool call that has arguments
    model_output = (
        '{"name": "searchTool", "parameters": {"query": "test query", "limit": 10}}'
    )
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 1

    function = result.tool_calls[0].function
    assert function.name == "searchTool"
    assert '"query": "test query"' in function.arguments
    assert '"limit": 10' in function.arguments
def test_extract_tool_calls_no_json(parser):
    """Text with no JSON object is returned unchanged as content."""
    # Test with text that doesn't contain a JSON object
    plain_text = "This is just some text without any tool calls"
    result = parser.extract_tool_calls(plain_text, None)

    assert result.tools_called is False
    assert len(result.tool_calls) == 0
    assert result.content == plain_text
def test_extract_tool_calls_invalid_json(parser):
    """Invalid JSON is not treated as a tool call and is returned as content."""
    # Test with invalid JSON
    broken_output = '{"name": "invalidTool", "parameters": {invalid json}'
    result = parser.extract_tool_calls(broken_output, None)

    assert result.tools_called is False
    assert len(result.tool_calls) == 0
    assert result.content == broken_output
def test_extract_tool_calls_with_arguments_key(parser):
    """The "arguments" key is accepted in place of "parameters"."""
    # Test with a tool call that uses "arguments" instead of "parameters"
    model_output = '{"name": "searchTool", "arguments": {"query": "test"}}'
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 1

    function = result.tool_calls[0].function
    assert function.name == "searchTool"
    assert '"query": "test"' in function.arguments
def test_extract_tool_calls_multiple_json(parser):
    """Three semicolon-separated JSON objects give three calls in order."""
    # Test with multiple JSONs separated by semicolons
    model_output = (
        '{"name": "searchTool", "parameters": {"query": "test1"}}; '
        '{"name": "getOpenIncidentsTool", "parameters": {}}; '
        '{"name": "searchTool", "parameters": {"query": "test2"}}'
    )
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 3
    first, second, third = (call.function for call in result.tool_calls)

    # Check first tool call
    assert first.name == "searchTool"
    assert '"query": "test1"' in first.arguments

    # Check second tool call
    assert second.name == "getOpenIncidentsTool"
    assert second.arguments == "{}"

    # Check third tool call
    assert third.name == "searchTool"
    assert '"query": "test2"' in third.arguments
def test_extract_tool_calls_multiple_json_with_whitespace(parser):
    """Extra whitespace around the semicolons still gives three calls."""
    # Test with multiple JSONs separated by semicolons and extra whitespace
    model_output = (
        '{"name": "searchTool", "parameters": {"query": "test1"}} ; '
        '{"name": "getOpenIncidentsTool", "parameters": {}} ; '
        '{"name": "searchTool", "parameters": {"query": "test2"}}'
    )
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 3
    first, second, third = result.tool_calls
    assert first.function.name == "searchTool"
    assert second.function.name == "getOpenIncidentsTool"
    assert third.function.name == "searchTool"
def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
    """Prose before and after the JSON objects still gives three calls."""
    # Test with multiple JSONs and surrounding text
    model_output = (
        "Here are the results: "
        '{"name": "searchTool", "parameters": {"query": "test1"}}; '
        '{"name": "getOpenIncidentsTool", "parameters": {}}; '
        '{"name": "searchTool", "parameters": {"query": "test2"}} '
        "Would you like to know more?"
    )
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 3
    first, second, third = result.tool_calls
    assert first.function.name == "searchTool"
    assert second.function.name == "getOpenIncidentsTool"
    assert third.function.name == "searchTool"
def test_extract_tool_calls_deeply_nested_json(parser):
    """Nested parameters are preserved in the extracted arguments."""
    import json

    # Test with deeply nested JSON parameters (5 levels)
    model_output = (
        '{"name": "complexTool", '
        '"parameters": {'
        '"level1": {'
        '"level2": {'
        '"level3": {'
        '"level4": {'
        '"value": "deep"'
        "}}}}}}"
    )
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 1

    function = result.tool_calls[0].function
    assert function.name == "complexTool"

    # Verify the nested structure is preserved in the arguments
    decoded = json.loads(function.arguments)
    assert decoded["level1"]["level2"]["level3"]["level4"]["value"] == "deep"
def test_extract_tool_calls_multiple_with_deep_nesting(parser):
    """A flat call followed by a deeply nested call are both extracted."""
    import json

    # Test with multiple tool calls where some have deeply nested parameters
    model_output = (
        '{"name": "simpleTool", "parameters": {"value": "test"}}; '
        '{"name": "complexTool", "parameters": '
        '{"config": {"database": {"connection": {"pool": {"size": 10}}}}}}'
    )
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 2
    simple_fn, complex_fn = (call.function for call in result.tool_calls)

    # Check first tool call
    assert simple_fn.name == "simpleTool"
    simple_args = json.loads(simple_fn.arguments)
    assert simple_args["value"] == "test"

    # Check second tool call with deep nesting
    assert complex_fn.name == "complexTool"
    complex_args = json.loads(complex_fn.arguments)
    assert complex_args["config"]["database"]["connection"]["pool"]["size"] == 10
def test_extract_tool_calls_with_quotes_and_brackets_in_string(parser):
    """Braces and brackets inside string values are preserved."""
    import json

    # Test with quotes and brackets inside quoted string values
    model_output = (
        '{"name": "searchTool", '
        '"parameters": {'
        '"query": "test {value} [complex]",'
        '"nested": {"inner": "more {brackets}"}'
        "}}"
    )
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 1

    function = result.tool_calls[0].function
    assert function.name == "searchTool"

    # Verify the string values are preserved including brackets and quotes
    decoded = json.loads(function.arguments)
    assert decoded["query"] == "test {value} [complex]"
    assert decoded["nested"]["inner"] == "more {brackets}"
def test_extract_tool_calls_with_escaped_quotes_in_nested_json(parser):
    """Escaped quotes inside a string value are preserved after decoding."""
    import json

    # Test with escaped quotes in deeply nested JSON
    model_output = (
        '{"name": "parserTool", "parameters": {"text": "He said \\"Hello {world}\\""}}'
    )
    result = parser.extract_tool_calls(model_output, None)

    assert result.tools_called is True
    assert len(result.tool_calls) == 1

    function = result.tool_calls[0].function
    assert function.name == "parserTool"

    # Verify escaped quotes are preserved
    decoded = json.loads(function.arguments)
    assert decoded["text"] == 'He said "Hello {world}"'
def test_extract_tool_calls_missing_name_key(parser):
    """JSON without a "name" key is returned as content, not as a tool call."""
    # Test that missing "name" key returns content
    nameless_output = '{"parameters": {}}'
    result = parser.extract_tool_calls(nameless_output, None)

    assert result.tools_called is False
    assert len(result.tool_calls) == 0
    assert result.content == nameless_output
def test_extract_tool_calls_missing_parameters_and_arguments_key(parser):
    """JSON with a name but no "parameters"/"arguments" is returned as content."""
    # Test that missing both "parameters" and "arguments" keys returns content
    name_only_output = '{"name": "toolWithoutParams"}'
    result = parser.extract_tool_calls(name_only_output, None)

    assert result.tools_called is False
    assert len(result.tool_calls) == 0
    assert result.content == name_only_output
def test_regex_timeout_handling(parser):
    """Test regex timeout is handled gracefully"""
    fake_problematic_input = "{hello world[A(A=" + "\t)A(A=,\t" * 2

    # create a mock regex that raises TimeoutError
    timing_out_regex = MagicMock()
    timing_out_regex.finditer.side_effect = TimeoutError("Regex timeout")

    with patch.object(parser, "tool_call_start_regex", timing_out_regex):
        result = parser.extract_tool_calls(fake_problematic_input, None)

        # should treat as regular text when regex times out
        assert result.content == fake_problematic_input
        assert result.tools_called is False
        assert len(result.tool_calls) == 0
        timing_out_regex.finditer.assert_called_once()

View File

@@ -1,269 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from tests.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.engine.protocol import FunctionCall
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# Fixtures for the "llama4_pythonic" parser tests, similar to the pythonic
# parser's cases but in the Llama4-specific format. Each *_OUTPUT constant is
# raw model text (already wrapped in the surrounding [...] list) and the
# matching *_CALL constant is the FunctionCall the tests expect for it.

# One call with two string keyword arguments.
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
SIMPLE_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "LA", "metric": "C"}',
)
# One call mixing str, int, dict, None, bool and list argument values.
MORE_TYPES_FUNCTION_OUTPUT = (
"[register_user(name='Doe', "
"age=9, "
"address={'city': 'LA', 'state': 'CA'}, "
"role=None, "
"passed_test=True, "
"aliases=['John', 'Johnny'])]"
)
# Expected arguments use the JSON spellings (null / true) of None / True.
MORE_TYPES_FUNCTION_CALL = FunctionCall(
name="register_user",
arguments='{"name": "Doe", '
'"age": 9, '
'"address": {"city": "LA", "state": "CA"}, '
'"role": null, '
'"passed_test": true, '
'"aliases": ["John", "Johnny"]}',
)
# A call without any arguments; the expected arguments are an empty object.
PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments="{}",
)
# A call whose only argument is an empty dict.
EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"additional_data": {}}',
)
# A call whose only argument is an empty list.
EMPTY_LIST_FUNCTION_OUTPUT = "[do_something_cool(steps=[])]"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"steps": []}',
)
# Raw string literal: the backslash escapes are part of the model text itself.
ESCAPED_STRING_FUNCTION_OUTPUT = (
r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]"
)
# Expected arguments keep the apostrophe and JSON-escape the double quotes.
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
# The simple call wrapped in <|python_start|> / <|python_end|> tags; the test
# cases pair it with SIMPLE_FUNCTION_CALL.
PYTHON_TAG_FUNCTION_OUTPUT = (
"<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>"
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text without tool-call syntax comes back unchanged as content."""
    parser_cls = ToolParserManager.get_tool_parser("llama4_pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)
    plain_text = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        tool_parser, plain_text, streaming=streaming
    )

    assert len(tool_calls) == 0
    assert content == plain_text
# Two calls behind an opening <|python_start|> tag with no closing tag; the
# calls are separated by a bare comma (no space after it).
test_str = (
    "<|python_start|>"
    "[get_weather(city='LA', metric='C'),"
    "register_user(name='Doe', age=9)]"
)

# Each entry is (streaming, model_output, expected_tool_calls), matching the
# parametrize signature of test_tool_call.
#
# Changes from the previous table:
#   * "simple_streaming" now feeds SIMPLE_FUNCTION_OUTPUT. It used to reuse the
#     escaped-string fixture, so it duplicated "escaped_string_streaming" and
#     the bare simple call was never exercised in streaming mode.
#   * The <|python_start|> parallel-call cases have their own ids. They used to
#     share "parallel_calls_*" with the untagged cases, which left two pairs of
#     cases with identical ids.
TEST_CASES = [
    pytest.param(
        True,
        SIMPLE_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        id="simple_streaming",
    ),
    pytest.param(
        False,
        SIMPLE_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        id="simple_nonstreaming",
    ),
    pytest.param(
        True,
        MORE_TYPES_FUNCTION_OUTPUT,
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_streaming",
    ),
    pytest.param(
        False,
        MORE_TYPES_FUNCTION_OUTPUT,
        [MORE_TYPES_FUNCTION_CALL],
        id="more_types_nonstreaming",
    ),
    pytest.param(
        True,
        PARAMETERLESS_FUNCTION_OUTPUT,
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_streaming",
    ),
    pytest.param(
        False,
        PARAMETERLESS_FUNCTION_OUTPUT,
        [PARAMETERLESS_FUNCTION_CALL],
        id="parameterless_nonstreaming",
    ),
    pytest.param(
        True,
        EMPTY_DICT_FUNCTION_OUTPUT,
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_streaming",
    ),
    pytest.param(
        False,
        EMPTY_DICT_FUNCTION_OUTPUT,
        [EMPTY_DICT_FUNCTION_CALL],
        id="empty_dict_nonstreaming",
    ),
    pytest.param(
        True,
        EMPTY_LIST_FUNCTION_OUTPUT,
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_streaming",
    ),
    pytest.param(
        False,
        EMPTY_LIST_FUNCTION_OUTPUT,
        [EMPTY_LIST_FUNCTION_CALL],
        id="empty_list_nonstreaming",
    ),
    pytest.param(
        True,
        ESCAPED_STRING_FUNCTION_OUTPUT,
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_streaming",
    ),
    pytest.param(
        False,
        ESCAPED_STRING_FUNCTION_OUTPUT,
        [ESCAPED_STRING_FUNCTION_CALL],
        id="escaped_string_nonstreaming",
    ),
    # Two calls in one list, written without spaces after the commas.
    pytest.param(
        True,
        "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
        [
            SIMPLE_FUNCTION_CALL,
            FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
        ],
        id="parallel_calls_streaming",
    ),
    pytest.param(
        False,
        "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
        [
            SIMPLE_FUNCTION_CALL,
            FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
        ],
        id="parallel_calls_nonstreaming",
    ),
    pytest.param(
        True,
        PYTHON_TAG_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        id="python_tag_streaming",
    ),
    pytest.param(
        False,
        PYTHON_TAG_FUNCTION_OUTPUT,
        [SIMPLE_FUNCTION_CALL],
        id="python_tag_nonstreaming",
    ),
    # Two calls behind an opening <|python_start|> tag and no closing tag.
    pytest.param(
        True,
        test_str,
        [
            SIMPLE_FUNCTION_CALL,
            FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
        ],
        id="python_tag_parallel_calls_streaming",
    ),
    pytest.param(
        False,
        "<|python_start|>[get_weather(city='LA', metric='C'), "
        "register_user(name='Doe', age=9)]",
        [
            SIMPLE_FUNCTION_CALL,
            FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
        ],
        id="python_tag_parallel_calls_nonstreaming",
    ),
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    default_tokenizer: TokenizerLike,
):
    """Every TEST_CASES output yields exactly the expected function calls."""
    parser_cls = ToolParserManager.get_tool_parser("llama4_pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    # The returned content is not checked by this test.
    _content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming
    )

    assert len(tool_calls) == len(expected_tool_calls)
    for extracted_call, expected_call in zip(tool_calls, expected_tool_calls):
        assert extracted_call.type == "function"
        assert extracted_call.function == expected_call
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """A single streamed delta carrying three calls yields all three calls."""
    parser_cls = ToolParserManager.get_tool_parser("llama4_pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    # One delta only: the three literals below are concatenated into one string.
    deltas = [
        "<|python_start|>[get_weather(city='LA', metric='C'), "
        "get_weather(), "
        "do_something_cool(steps=[])]<|python_end|>",
    ]
    expected_calls = [
        SIMPLE_FUNCTION_CALL,
        PARAMETERLESS_FUNCTION_CALL,
        EMPTY_LIST_FUNCTION_CALL,
    ]

    reconstructor = run_tool_extraction_streaming(
        tool_parser, deltas, assert_one_tool_per_delta=False
    )

    assert reconstructor.other_content == ""
    assert len(reconstructor.tool_calls) == 3
    for streamed_call, expected_call in zip(reconstructor.tool_calls, expected_calls):
        assert streamed_call.function == expected_call
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
    """A TimeoutError from TOOL_CALL_REGEX falls back to plain content."""
    parser_cls = ToolParserManager.get_tool_parser("llama4_pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)
    pathological_text = "hello world[A(A=" + "\t)A(A=,\t" * 2

    # Stand-in for the parser's pattern whose match() always times out.
    timing_out_regex = MagicMock()
    timing_out_regex.match.side_effect = TimeoutError("Regex timeout")

    with patch.object(tool_parser, "TOOL_CALL_REGEX", timing_out_regex):
        content, tool_calls = run_tool_extraction(
            tool_parser, pathological_text, streaming=streaming
        )

        # The input comes back untouched and no tool calls are reported.
        assert content == pathological_text
        assert len(tool_calls) == 0
        timing_out_regex.match.assert_called_once()

View File

@@ -1,251 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from tests.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.engine.protocol import FunctionCall
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# Fixtures for the "olmo3" parser tests. Each *_OUTPUT constant is a bare
# pythonic call (the test cases wrap it in <function_calls> tags) and the
# matching *_CALL constant is the FunctionCall the tests expect for it.
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1

# One call with two string keyword arguments.
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "San Francisco", "metric": "celsius"}',
)
# One call mixing str, int, dict, None, bool and list argument values.
MORE_TYPES_FUNCTION_OUTPUT = (
"register_user(name='John Doe', "
"age=37, "
"address={'city': 'San Francisco', 'state': 'CA'}, "
"role=None, "
"passed_test=True, "
"aliases=['John', 'Johnny'])"
)
# Same call, but written with the JSON literals null / true instead of the
# Python literals None / True; the test cases expect the same parsed call.
MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = (
"register_user(name='John Doe', "
"age=37, "
"address={'city': 'San Francisco', 'state': 'CA'}, "
"role=null, "
"passed_test=true, "
"aliases=['John', 'Johnny'])"
)
# Expected result for both MORE_TYPES outputs above.
MORE_TYPES_FUNCTION_CALL = FunctionCall(
name="register_user",
arguments='{"name": "John Doe", '
'"age": 37, '
'"address": {"city": "San Francisco", "state": "CA"}, '
'"role": null, '
'"passed_test": true, '
'"aliases": ["John", "Johnny"]}',
)
# A call without any arguments; the expected arguments are an empty object.
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments="{}",
)
# A call whose only argument is an empty dict.
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"additional_data": {}}',
)
# A call whose only argument is an empty list.
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"steps": []}',
)
# Raw string literal: the backslash escapes are part of the model text itself.
ESCAPED_STRING_FUNCTION_OUTPUT = (
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
)
# Expected arguments keep the apostrophe and JSON-escape the double quotes.
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text without tool-call syntax comes back unchanged as content."""
    parser_cls = ToolParserManager.get_tool_parser("olmo3")
    tool_parser: ToolParser = parser_cls(default_tokenizer)
    plain_text = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        tool_parser, plain_text, streaming=streaming
    )

    assert len(tool_calls) == 0
    assert content == plain_text
def _in_function_calls_tags(call_text: str) -> str:
    """Wrap *call_text* in the ``<function_calls>`` tags every case uses."""
    return f"<function_calls>{call_text}</function_calls>"


# (id template, bare call text, expected calls). ``{mode}`` in the template is
# replaced by "streaming" or "nonstreaming" when the cases are expanded.
_CASE_SPECS = [
    ("simple_{mode}", SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL]),
    ("more_types_{mode}", MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL]),
    (
        "more_types_{mode}_json_literals",
        MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS,
        [MORE_TYPES_FUNCTION_CALL],
    ),
    (
        "parameterless_{mode}",
        PARAMETERLESS_FUNCTION_OUTPUT,
        [PARAMETERLESS_FUNCTION_CALL],
    ),
    ("empty_dict_{mode}", EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL]),
    ("empty_list_{mode}", EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL]),
    (
        "escaped_string_{mode}",
        ESCAPED_STRING_FUNCTION_OUTPUT,
        [ESCAPED_STRING_FUNCTION_CALL],
    ),
    (
        # Parallel calls are separated by a newline inside one tag pair.
        "parallel_calls_{mode}",
        f"{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}",
        [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
    ),
]

# Each spec expands to a streaming case followed by a nonstreaming case, as
# (streaming, model_output, expected_tool_calls).
TEST_CASES = [
    pytest.param(
        streaming,
        _in_function_calls_tags(call_text),
        list(expected_calls),
        id=id_template.format(mode="streaming" if streaming else "nonstreaming"),
    )
    for id_template, call_text, expected_calls in _CASE_SPECS
    for streaming in (True, False)
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    default_tokenizer: TokenizerLike,
):
    """Every TEST_CASES output yields the expected calls and no content."""
    parser_cls = ToolParserManager.get_tool_parser("olmo3")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming
    )

    # All of the output is tool-call text, so no content is expected.
    assert content is None
    assert len(tool_calls) == len(expected_tool_calls)
    for extracted_call, expected_call in zip(tool_calls, expected_tool_calls):
        assert extracted_call.type == "function"
        assert extracted_call.function == expected_call
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """Two large deltas, the second finishing three calls, yield all three."""
    parser_cls = ToolParserManager.get_tool_parser("olmo3")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    # Exactly two deltas: the first stops mid-argument, the second (three
    # concatenated literals) completes that call and carries two more.
    deltas = [
        "<function_calls>get_weather(city='San",
        " Francisco', metric='celsius')\n"
        f"{PARAMETERLESS_FUNCTION_OUTPUT}\n"
        f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
    ]
    expected_calls = [
        SIMPLE_FUNCTION_CALL,
        PARAMETERLESS_FUNCTION_CALL,
        EMPTY_LIST_FUNCTION_CALL,
    ]

    reconstructor = run_tool_extraction_streaming(
        tool_parser, deltas, assert_one_tool_per_delta=False
    )

    assert reconstructor.other_content == ""
    assert len(reconstructor.tool_calls) == 3
    for streamed_call, expected_call in zip(reconstructor.tool_calls, expected_calls):
        assert streamed_call.function == expected_call
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
    """A TimeoutError from TOOL_CALL_REGEX falls back to plain content."""
    parser_cls = ToolParserManager.get_tool_parser("olmo3")
    tool_parser: ToolParser = parser_cls(default_tokenizer)
    pathological_text = "hello world[A(A=" + "\t)A(A=,\t" * 2

    # Stand-in for the parser's pattern whose match() always times out.
    timing_out_regex = MagicMock()
    timing_out_regex.match.side_effect = TimeoutError("Regex timeout")

    with patch.object(tool_parser, "TOOL_CALL_REGEX", timing_out_regex):
        content, tool_calls = run_tool_extraction(
            tool_parser, pathological_text, streaming=streaming
        )

        # The input comes back untouched and no tool calls are reported.
        assert content == pathological_text
        assert len(tool_calls) == 0
        timing_out_regex.match.assert_called_once()

View File

@@ -1,231 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from tests.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.engine.protocol import FunctionCall
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# Fixtures for the "pythonic" parser tests. Each *_OUTPUT constant is a bare
# pythonic call (the test cases wrap it in a [...] list) and the matching
# *_CALL constant is the FunctionCall the tests expect for it.
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1

# One call with two string keyword arguments.
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "San Francisco", "metric": "celsius"}',
)
# One call mixing str, int, dict, None, bool and list argument values.
MORE_TYPES_FUNCTION_OUTPUT = (
"register_user(name='John Doe', "
"age=37, "
"address={'city': 'San Francisco', 'state': 'CA'}, "
"role=None, "
"passed_test=True, "
"aliases=['John', 'Johnny'])"
)
# Expected arguments use the JSON spellings (null / true) of None / True.
MORE_TYPES_FUNCTION_CALL = FunctionCall(
name="register_user",
arguments='{"name": "John Doe", '
'"age": 37, '
'"address": {"city": "San Francisco", "state": "CA"}, '
'"role": null, '
'"passed_test": true, '
'"aliases": ["John", "Johnny"]}',
)
# A call without any arguments; the expected arguments are an empty object.
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments="{}",
)
# A call whose only argument is an empty dict.
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"additional_data": {}}',
)
# A call whose only argument is an empty list.
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"steps": []}',
)
# Raw string literal: the backslash escapes are part of the model text itself.
ESCAPED_STRING_FUNCTION_OUTPUT = (
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
)
# Expected arguments keep the apostrophe and JSON-escape the double quotes.
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text without tool-call syntax comes back unchanged as content."""
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)
    plain_text = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        tool_parser, plain_text, streaming=streaming
    )

    assert len(tool_calls) == 0
    assert content == plain_text
def _as_call_list(calls_text: str) -> str:
    """Wrap *calls_text* in the square brackets every case uses."""
    return f"[{calls_text}]"


# (label, bare call text, expected calls). Each label becomes the case ids
# "<label>_streaming" and "<label>_nonstreaming".
_CASE_SPECS = [
    ("simple", SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL]),
    ("more_types", MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL]),
    ("parameterless", PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL]),
    ("empty_dict", EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL]),
    ("empty_list", EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL]),
    ("escaped_string", ESCAPED_STRING_FUNCTION_OUTPUT, [ESCAPED_STRING_FUNCTION_CALL]),
    (
        # Parallel calls are separated by ", " inside one list.
        "parallel_calls",
        f"{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}",
        [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
    ),
]

# Each spec expands to a streaming case followed by a nonstreaming case, as
# (streaming, model_output, expected_tool_calls).
TEST_CASES = [
    pytest.param(
        streaming,
        _as_call_list(calls_text),
        list(expected_calls),
        id=f"{label}_{'streaming' if streaming else 'nonstreaming'}",
    )
    for label, calls_text, expected_calls in _CASE_SPECS
    for streaming in (True, False)
]
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    default_tokenizer: TokenizerLike,
):
    """Every TEST_CASES output yields the expected calls and no content."""
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming
    )

    # All of the output is tool-call text, so no content is expected.
    assert content is None
    assert len(tool_calls) == len(expected_tool_calls)
    for extracted_call, expected_call in zip(tool_calls, expected_tool_calls):
        assert extracted_call.type == "function"
        assert extracted_call.function == expected_call
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """Two large deltas, the second finishing three calls, yield all three."""
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    # Exactly two deltas: the first stops mid-argument, the second (three
    # concatenated literals) completes that call and carries two more.
    deltas = [
        "[get_weather(city='San",
        " Francisco', metric='celsius'), "
        f"{PARAMETERLESS_FUNCTION_OUTPUT}, "
        f"{EMPTY_LIST_FUNCTION_OUTPUT}]",
    ]
    expected_calls = [
        SIMPLE_FUNCTION_CALL,
        PARAMETERLESS_FUNCTION_CALL,
        EMPTY_LIST_FUNCTION_CALL,
    ]

    reconstructor = run_tool_extraction_streaming(
        tool_parser, deltas, assert_one_tool_per_delta=False
    )

    assert reconstructor.other_content == ""
    assert len(reconstructor.tool_calls) == 3
    for streamed_call, expected_call in zip(reconstructor.tool_calls, expected_calls):
        assert streamed_call.function == expected_call
@pytest.mark.parametrize("streaming", [False])
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
    """A TimeoutError from TOOL_CALL_REGEX falls back to plain content."""
    parser_cls = ToolParserManager.get_tool_parser("pythonic")
    tool_parser: ToolParser = parser_cls(default_tokenizer)
    pathological_text = "hello world[A(A=" + "\t)A(A=,\t" * 2

    # Stand-in for the parser's pattern whose match() always times out.
    timing_out_regex = MagicMock()
    timing_out_regex.match.side_effect = TimeoutError("Regex timeout")

    with patch.object(tool_parser, "TOOL_CALL_REGEX", timing_out_regex):
        content, tool_calls = run_tool_extraction(
            tool_parser, pathological_text, streaming=streaming
        )

        # The input comes back untouched and no tool calls are reported.
        assert content == pathological_text
        assert len(tool_calls) == 0
        timing_out_regex.match.assert_called_once()