180 lines
5.9 KiB
Python
180 lines
5.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# ruff: noqa: E501
|
|
|
|
import json
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from tests.tool_parsers.utils import (
|
|
run_tool_extraction,
|
|
run_tool_extraction_streaming,
|
|
)
|
|
from vllm.entrypoints.openai.engine.protocol import FunctionCall, ToolCall
|
|
from vllm.tool_parsers import ToolParser, ToolParserManager
|
|
|
|
|
|
def make_tool_call(name, arguments):
    """Build the ``ToolCall`` a test expects for a single function call.

    Args:
        name: Function name placed on the ``FunctionCall``.
        arguments: JSON-serializable arguments; they are stored as the
            ``json.dumps`` string of this value, not as the raw object.

    Returns:
        A ``ToolCall`` of type ``"function"`` wrapping the function call.
    """
    serialized_arguments = json.dumps(arguments)
    function_call = FunctionCall(name=name, arguments=serialized_arguments)
    return ToolCall(type="function", function=function_call)
|
|
|
|
|
|
# TODO: add test cases with a reasoning prefix and suffix.
|
|
|
|
|
|
@pytest.mark.parametrize(
    "model_output,expected_tool_calls,expected_content",
    [
        # No tool call
        ("How can I help you today?", [], "How can I help you today?"),
        # Single tool call, no content
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]</tool_calls>',  # noqa: E501
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                )
            ],
            None,
        ),
        # Multiple tool calls
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]</tool_calls>',  # noqa: E501
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                ),
                make_tool_call(
                    "register_user",
                    {
                        "name": "John Doe",
                        "age": 37,
                        "address": {"city": "San Francisco", "state": "CA"},
                        "role": None,
                        "passed_test": True,
                        "aliases": ["John", "Johnny"],
                    },
                ),
            ],
            None,
        ),
        # Content before tool call
        (
            'I will call the tool now. <tool_calls>[{"name": "get_weather", "arguments": {"city": "Boston"}}]</tool_calls>',  # noqa: E501
            [make_tool_call("get_weather", {"city": "Boston"})],
            "I will call the tool now. ",
        ),
        # Content after tool call (should be stripped)
        (
            '<tool_calls>[{"name": "get_weather", "arguments": {"city": "Seattle"}}]</tool_calls>\nThank you!',  # noqa: E501
            [make_tool_call("get_weather", {"city": "Seattle"})],
            None,
        ),
        # Deeply nested arguments
        (
            '<tool_calls>[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]</tool_calls>',
            [
                make_tool_call(
                    "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
                )
            ],
            None,
        ),
    ],
)
def test_hunyuan_a13b_tool_parser_extract(
    model_output, expected_tool_calls, expected_content
):
    """Check non-streaming extraction of tool calls and remaining content.

    Each case feeds a complete ``model_output`` string through the
    ``hunyuan_a13b`` parser and compares the extracted tool calls and
    content against the expected values.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
        mock_tokenizer
    )
    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=False
    )

    # Compare counts first so an unexpected number of tool calls fails with a
    # readable assertion instead of an IndexError while aligning ids below.
    assert len(tool_calls) == len(expected_tool_calls)

    # Tool call ids are random, so copy the expected id onto each parsed call
    # before comparing the objects for equality.
    for actual_call, expected_call in zip(tool_calls, expected_tool_calls):
        actual_call.id = expected_call.id
    assert tool_calls == expected_tool_calls
    assert content == expected_content
|
|
|
|
|
|
# Streaming test: simulate incremental output
|
|
@pytest.mark.parametrize(
    "model_deltas,expected_tool_calls",
    [
        # Single tool call whose arguments are split across deltas
        (
            [
                '<tool_calls>[{"name": "get_weather", ',
                '"arguments": {"city": "San Francisco", ',
                '"metric": "celsius"}}]',
                "</tool_calls>",
            ],
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                )
            ],
        ),
        # Name and arguments arrive in separate deltas
        (
            [
                '<tool_calls>[{"name":',
                ' "get_weather",',
                ' "arguments":',
                ' {"city": "Boston"}',
                "}]",
                "</tool_calls>",
            ],
            [make_tool_call("get_weather", {"city": "Boston"})],
        ),
        # Leading empty delta and trailing text after the closing tag
        (
            [
                "",
                '<tool_calls>[{"name":',
                ' "get_weather",',
                ' "arguments":',
                ' {"city": "Boston"}',
                "}]",
                "</tool_calls>",
                "\n</answer>",
            ],
            [make_tool_call("get_weather", {"city": "Boston"})],
        ),
        # Nested JSON arguments (marked as an expected failure)
        pytest.param(
            [
                '<tool_calls>[{"name": "complex_tool",',
                ' "arguments": ',
                ' {"level1": {"level2": ',
                '{"level3": {"value": 123}}}}}',
                "]</tool_calls>",
            ],
            [
                make_tool_call(
                    "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
                )
            ],
            marks=pytest.mark.xfail(
                reason="stream parsing not support nested json yet."
            ),
        ),
    ],
)
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
    """Check streaming extraction by feeding the output as incremental deltas.

    The deltas are passed to ``run_tool_extraction_streaming`` and the tool
    calls on the returned reconstructor are compared against the expected
    ones.
    """
    mock_tokenizer = MagicMock()

    tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
        mock_tokenizer
    )
    reconstructor = run_tool_extraction_streaming(
        tool_parser, model_deltas, assert_one_tool_per_delta=False
    )

    # Compare counts first so an unexpected number of tool calls fails with a
    # readable assertion instead of an IndexError while aligning ids below.
    assert len(reconstructor.tool_calls) == len(expected_tool_calls)

    # Tool call ids are random, so copy the expected id onto each
    # reconstructed call before comparing the objects for equality.
    for actual_call, expected_call in zip(
        reconstructor.tool_calls, expected_tool_calls
    ):
        actual_call.id = expected_call.id

    assert reconstructor.tool_calls == expected_tool_calls
|