# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import json
from unittest.mock import MagicMock

import pytest

from tests.tool_parsers.utils import (
    run_tool_extraction,
    run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.engine.protocol import FunctionCall, ToolCall
from vllm.tool_parsers import ToolParser, ToolParserManager


def make_tool_call(name, arguments):
    """Build the expected ``ToolCall`` for a function invocation.

    Args:
        name: Function name to place in the call.
        arguments: JSON-serialisable object; it is stored on the call as the
            string produced by ``json.dumps``.

    Returns:
        A ``ToolCall`` of type ``"function"`` wrapping a ``FunctionCall``.
    """
    return ToolCall(
        type="function",
        function=FunctionCall(name=name, arguments=json.dumps(arguments)),
    )


def _align_tool_call_ids(actual_tool_calls, expected_tool_calls):
    """Copy each expected id onto the tool call at the same position.

    Tool call ids are random, so they must be aligned before the two lists
    can be compared with ``==``. The lengths are checked first so that a
    parser returning too many or too few calls fails with a readable
    assertion instead of an ``IndexError`` raised while indexing.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls), (
        f"expected {len(expected_tool_calls)} tool call(s), "
        f"got {len(actual_tool_calls)}"
    )
    for actual, expected in zip(actual_tool_calls, expected_tool_calls):
        actual.id = expected.id


# TODO: add reason prefix and suffix.


# NOTE(review): the fixtures below are bare JSON arrays (and the streaming
# cases contain empty-string deltas). If the parser expects wrapper markers
# around the tool-call payload, confirm none were lost from these literals.
@pytest.mark.parametrize(
    "model_output,expected_tool_calls,expected_content",
    [
        # No tool call
        ("How can I help you today?", [], "How can I help you today?"),
        # Single tool call, no content
        (
            '[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]',  # noqa: E501
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                )
            ],
            None,
        ),
        # Multiple tool calls
        (
            '[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]',  # noqa: E501
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                ),
                make_tool_call(
                    "register_user",
                    {
                        "name": "John Doe",
                        "age": 37,
                        "address": {"city": "San Francisco", "state": "CA"},
                        "role": None,
                        "passed_test": True,
                        "aliases": ["John", "Johnny"],
                    },
                ),
            ],
            None,
        ),
        # Content before tool call
        (
            'I will call the tool now. [{"name": "get_weather", "arguments": {"city": "Boston"}}]',  # noqa: E501
            [make_tool_call("get_weather", {"city": "Boston"})],
            "I will call the tool now. ",
        ),
        # Content after tool call (should be stripped)
        (
            '[{"name": "get_weather", "arguments": {"city": "Seattle"}}]\nThank you!',  # noqa: E501
            [make_tool_call("get_weather", {"city": "Seattle"})],
            None,
        ),
        # Deeply nested arguments
        (
            '[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]',
            [
                make_tool_call(
                    "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
                )
            ],
            None,
        ),
    ],
)
def test_hunyuan_a13b_tool_parser_extract(
    model_output, expected_tool_calls, expected_content
):
    """Check non-streaming extraction of content and tool calls.

    The complete ``model_output`` is passed to ``run_tool_extraction`` with
    ``streaming=False``; the returned tool calls (after id alignment) and
    content must equal the parametrized expectations.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
        mock_tokenizer
    )
    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=False
    )

    # align the random id.
    _align_tool_call_ids(tool_calls, expected_tool_calls)
    assert tool_calls == expected_tool_calls
    assert content == expected_content


# Streaming test: simulate incremental output
@pytest.mark.parametrize(
    "model_deltas,expected_tool_calls",
    [
        (
            [
                '[{"name": "get_weather", ',
                '"arguments": {"city": "San Francisco", ',
                '"metric": "celsius"}}]',
                "",
            ],
            [
                make_tool_call(
                    "get_weather", {"city": "San Francisco", "metric": "celsius"}
                )
            ],
        ),
        (
            [
                '[{"name":',
                ' "get_weather",',
                ' "arguments":',
                ' {"city": "Boston"}',
                "}]",
                "",
            ],
            [make_tool_call("get_weather", {"city": "Boston"})],
        ),
        (
            [
                "",
                '[{"name":',
                ' "get_weather",',
                ' "arguments":',
                ' {"city": "Boston"}',
                "}]",
                "",
                "\n",
            ],
            [make_tool_call("get_weather", {"city": "Boston"})],
        ),
        # Nested arguments are a known gap in streaming mode, hence xfail.
        pytest.param(
            [
                '[{"name": "complex_tool",',
                ' "arguments": ',
                ' {"level1": {"level2": ',
                '{"level3": {"value": 123}}}}}',
                "]",
            ],
            [
                make_tool_call(
                    "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
                )
            ],
            marks=pytest.mark.xfail(
                reason="stream parsing not support nested json yet."
            ),
        ),
    ],
)
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
    """Check tool calls reconstructed from incrementally streamed deltas.

    ``model_deltas`` is handed to ``run_tool_extraction_streaming`` and the
    ``tool_calls`` collected on the returned reconstructor (after id
    alignment) must equal ``expected_tool_calls``.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
        mock_tokenizer
    )
    reconstructor = run_tool_extraction_streaming(
        tool_parser, model_deltas, assert_one_tool_per_delta=False
    )

    # align the random id.
    _align_tool_call_ids(reconstructor.tool_calls, expected_tool_calls)
    assert reconstructor.tool_calls == expected_tool_calls