# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for ``FunctionGemmaToolParser``.

The model-output fixtures use the FunctionGemma call syntax::

    <start_function_call>call:NAME{key:<escape>value<escape>,...}<end_function_call>

NOTE(review): the marker spelling above is taken from the FunctionGemma chat
format; confirm it against the parser's start/end/escape token definitions if
either side changes.
"""

import json
from unittest.mock import MagicMock

import pytest

from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.tool_parsers.functiongemma_tool_parser import FunctionGemmaToolParser


@pytest.fixture
def mock_tokenizer():
    """Return a stand-in tokenizer with a fixed encoding and an empty vocab."""
    tokenizer = MagicMock()
    tokenizer.encode.return_value = [1, 2, 3]
    tokenizer.get_vocab.return_value = {}
    return tokenizer


@pytest.fixture
def parser(mock_tokenizer):
    """Return a fresh parser bound to the mock tokenizer."""
    return FunctionGemmaToolParser(mock_tokenizer)


@pytest.fixture
def mock_request():
    """Return a request mock with no tools and ``tool_choice="auto"``."""
    request = MagicMock(spec=ChatCompletionRequest)
    request.tools = []
    request.tool_choice = "auto"
    return request


class TestExtractToolCalls:
    """Non-streaming extraction of tool calls from a complete model output."""

    def test_no_tool_calls(self, parser, mock_request):
        model_output = "Hello, how can I help you today?"

        result = parser.extract_tool_calls(model_output, mock_request)

        assert result.tools_called is False
        assert result.tool_calls == []
        assert result.content == model_output

    def test_single_tool_call(self, parser, mock_request):
        model_output = (
            "<start_function_call>call:get_weather{location:<escape>London<escape>}"
            "<end_function_call>"
        )

        result = parser.extract_tool_calls(model_output, mock_request)

        assert result.tools_called is True
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "get_weather"
        # Decode rather than substring-match so the check does not depend on
        # the JSON separators the parser happens to serialize with.
        assert json.loads(result.tool_calls[0].function.arguments) == {
            "location": "London"
        }

    def test_multiple_arguments(self, parser, mock_request):
        model_output = (
            "<start_function_call>call:get_weather{"
            "location:<escape>San Francisco<escape>,"
            "unit:<escape>celsius<escape>}"
            "<end_function_call>"
        )

        result = parser.extract_tool_calls(model_output, mock_request)

        assert result.tools_called is True
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "get_weather"
        args = json.loads(result.tool_calls[0].function.arguments)
        assert args == {"location": "San Francisco", "unit": "celsius"}

    def test_text_before_tool_call(self, parser, mock_request):
        model_output = (
            "Let me check the weather for you. "
            "<start_function_call>call:get_weather{location:<escape>Paris<escape>}"
            "<end_function_call>"
        )

        result = parser.extract_tool_calls(model_output, mock_request)

        assert result.tools_called is True
        # Leading prose is returned as content with surrounding whitespace trimmed.
        assert result.content == "Let me check the weather for you."

    def test_multiple_tool_calls(self, parser, mock_request):
        model_output = (
            "<start_function_call>call:get_weather{location:<escape>London<escape>}"
            "<end_function_call>"
            "<start_function_call>call:get_time{timezone:<escape>UTC<escape>}"
            "<end_function_call>"
        )

        result = parser.extract_tool_calls(model_output, mock_request)

        assert result.tools_called is True
        assert len(result.tool_calls) == 2
        assert result.tool_calls[0].function.name == "get_weather"
        assert result.tool_calls[1].function.name == "get_time"


class TestParseArguments:
    """``_parse_arguments`` turns ``key:<escape>value<escape>`` pairs into a dict."""

    def test_empty_arguments(self, parser):
        result = parser._parse_arguments("")
        assert result == {}

    def test_single_string_argument(self, parser):
        result = parser._parse_arguments("city:<escape>Tokyo<escape>")
        assert result == {"city": "Tokyo"}

    def test_multiple_arguments(self, parser):
        args_str = "city:<escape>Tokyo<escape>,country:<escape>Japan<escape>"
        result = parser._parse_arguments(args_str)
        assert result == {"city": "Tokyo", "country": "Japan"}

    def test_numeric_argument(self, parser):
        result = parser._parse_arguments("count:<escape>42<escape>")
        assert result == {"count": 42}

    def test_boolean_argument(self, parser):
        result = parser._parse_arguments("enabled:<escape>true<escape>")
        assert result == {"enabled": True}

    def test_argument_with_spaces(self, parser):
        result = parser._parse_arguments("message:<escape>Hello World<escape>")
        assert result == {"message": "Hello World"}


class TestAdjustRequest:
    """``adjust_request`` toggles ``skip_special_tokens`` based on tool usage."""

    def test_skip_special_tokens_disabled(self, parser, mock_request):
        mock_request.tools = [{"type": "function", "function": {"name": "test"}}]
        mock_request.tool_choice = "auto"
        mock_request.skip_special_tokens = True

        result = parser.adjust_request(mock_request)

        assert result.skip_special_tokens is False

    def test_skip_special_tokens_when_tool_choice_none(self, parser, mock_request):
        mock_request.tools = [{"type": "function", "function": {"name": "test"}}]
        mock_request.tool_choice = "none"
        mock_request.skip_special_tokens = True

        result = parser.adjust_request(mock_request)

        assert result.skip_special_tokens is True


class TestBufferDeltaText:
    """Streaming buffer that holds back partially-received special tokens."""

    def test_regular_text_not_buffered(self, parser):
        result = parser._buffer_delta_text("hello")
        assert result == "hello"
        assert parser.buffered_delta_text == ""

    def test_complete_tag_flushed(self, parser):
        # A partial marker held over from the previous delta is completed by
        # the closing ">" and should be released as the whole token.
        parser.buffered_delta_text = "<start_function_call"
        result = parser._buffer_delta_text(">")
        assert "<start_function_call>" in result


if __name__ == "__main__":
    pytest.main([__file__, "-v"])