diff --git a/tests/tool_parsers/test_qwen3coder_tool_parser.py b/tests/tool_parsers/test_qwen3coder_tool_parser.py index 7db1b6857..c62e95830 100644 --- a/tests/tool_parsers/test_qwen3coder_tool_parser.py +++ b/tests/tool_parsers/test_qwen3coder_tool_parser.py @@ -5,6 +5,7 @@ import json from collections.abc import Generator import pytest +from openai.types.responses.function_tool import FunctionTool from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, @@ -49,41 +50,62 @@ def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, req return qwen3_xml_tool_parser -@pytest.fixture -def sample_tools(): - return [ - ChatCompletionToolsParam( - type="function", - function={ - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "city": {"type": "string", "description": "The city name"}, - "state": {"type": "string", "description": "The state code"}, - "unit": {"type": "string", "enum": ["fahrenheit", "celsius"]}, - }, - "required": ["city", "state"], +WEATHER_PARAMS = { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city name"}, + "state": {"type": "string", "description": "The state code"}, + "unit": {"type": "string", "enum": ["fahrenheit", "celsius"]}, + }, + "required": ["city", "state"], +} + +AREA_PARAMS = { + "type": "object", + "properties": { + "shape": {"type": "string"}, + "dimensions": {"type": "object"}, + "precision": {"type": "integer"}, + }, +} + + +@pytest.fixture(params=["chat_completion", "responses_api"]) +def sample_tools(request): + if request.param == "chat_completion": + return [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": WEATHER_PARAMS, }, - }, - ), - ChatCompletionToolsParam( - type="function", - function={ - "name": "calculate_area", - "description": "Calculate area of a shape", - "parameters": { - "type": "object", - "properties": { - "shape": {"type": "string"}, - "dimensions": {"type": "object"}, - "precision": {"type": "integer"}, - }, + ), + ChatCompletionToolsParam( + type="function", + function={ + "name": "calculate_area", + "description": "Calculate area of a shape", + "parameters": AREA_PARAMS, }, - }, - ), - ] + ), + ] + else: + return [ + FunctionTool( + type="function", + name="get_current_weather", + description="Get the current weather", + parameters=WEATHER_PARAMS, + ), + FunctionTool( + type="function", + name="calculate_area", + description="Calculate area of a shape", + parameters=AREA_PARAMS, + ), + ] def assert_tool_calls( @@ -337,12 +359,11 @@ circle ) def test_extract_tool_calls( qwen3_tool_parser_parametrized, - sample_tools, model_output, expected_tool_calls, expected_content, ): - request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[]) extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( model_output, request=request ) @@ -354,7 +375,7 @@ def test_extract_tool_calls( def test_extract_tool_calls_fallback_no_tags( - qwen3_tool_parser_parametrized, sample_tools + qwen3_tool_parser_parametrized, ): """Test fallback parsing when XML tags are missing""" model_output = """ @@ -366,7 +387,7 @@ TX """ - request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[]) extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( model_output, request=request ) @@ -607,13 +628,12 @@ circle def test_extract_tool_calls_streaming( qwen3_tool_parser_parametrized, qwen3_tokenizer, - sample_tools, model_output, expected_tool_calls, expected_content, ): """Test incremental streaming behavior including typed parameters""" - request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[]) other_content = "" tool_states = {} # Track state per tool index @@ -683,7 +703,7 @@ def test_extract_tool_calls_streaming( def test_extract_tool_calls_missing_closing_parameter_tag( - qwen3_tool_parser_parametrized, sample_tools + qwen3_tool_parser_parametrized, ): """Test handling of missing closing tag""" # Using get_current_weather from sample_tools but with malformed XML @@ -701,7 +721,7 @@ fahrenheit """ - request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[]) extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( model_output, request=request ) @@ -725,7 +745,7 @@ fahrenheit def test_extract_tool_calls_streaming_missing_closing_tag( - qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools + qwen3_tool_parser_parametrized, qwen3_tokenizer ): """Test streaming with missing closing tag""" # Using get_current_weather from sample_tools but with malformed XML @@ -743,7 +763,7 @@ fahrenheit """ - request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[]) other_content = "" tool_states = {} @@ -800,7 +820,7 @@ fahrenheit def test_extract_tool_calls_streaming_incremental( - qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools + qwen3_tool_parser_parametrized, qwen3_tokenizer ): """Test that streaming is truly incremental""" model_output = """I'll check the weather. @@ -814,7 +834,7 @@ TX """ - request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[]) chunks = [] for delta_message in stream_delta_message_generator( @@ -897,7 +917,7 @@ def test_extract_tool_calls_complex_type_with_single_quote( def test_extract_tool_calls_streaming_missing_opening_tag( - qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools + qwen3_tool_parser_parametrized, qwen3_tokenizer ): """Test streaming with missing opening tag @@ -919,7 +939,7 @@ fahrenheit """ - request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[]) other_content = "" tool_states = {} @@ -976,7 +996,7 @@ fahrenheit assert args["unit"] == "fahrenheit" -def test_malformed_xml_no_gt_delimiter(qwen3_tool_parser, sample_tools): +def test_malformed_xml_no_gt_delimiter(qwen3_tool_parser): """Regression: malformed XML without '>' must not crash (PR #36774).""" model_output = ( "\n" @@ -986,14 +1006,14 @@ def test_malformed_xml_no_gt_delimiter(qwen3_tool_parser, sample_tools): "" ) - request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[]) result = qwen3_tool_parser.extract_tool_calls(model_output, request=request) assert result is not None assert isinstance(result.tool_calls, list) assert all(tc is not None for tc in result.tool_calls) -def test_none_tool_calls_filtered(qwen3_tool_parser, sample_tools): +def test_none_tool_calls_filtered(qwen3_tool_parser): """Regression: None tool calls filtered from output (PR #36774).""" model_output = ( "\n" @@ -1008,7 +1028,7 @@ def test_none_tool_calls_filtered(qwen3_tool_parser, sample_tools): "" ) - request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[]) result = qwen3_tool_parser.extract_tool_calls(model_output, request=request) assert all(tc is not None for tc in result.tool_calls) assert result.tools_called @@ -1058,11 +1078,9 @@ def test_anyof_parameter_not_double_encoded(qwen3_tokenizer): assert args["data"] == {"key": "value", "count": 42} -def test_streaming_multi_param_single_chunk( - qwen3_tool_parser, qwen3_tokenizer, sample_tools -): +def test_streaming_multi_param_single_chunk(qwen3_tool_parser, qwen3_tokenizer): """Regression: speculative decode delivering multiple params at once (PR #35615).""" - request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[]) deltas = [ "", diff --git a/vllm/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py index ea25ea2be..7b089ceff 100644 --- a/vllm/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser.py @@ -25,6 +25,7 @@ from vllm.tool_parsers.abstract_tool_parser import ( Tool, ToolParser, ) +from vllm.tool_parsers.utils import find_tool_properties logger = init_logger(__name__) @@ -109,28 +110,6 @@ class Qwen3CoderToolParser(ToolParser): self.accumulated_params = {} self.streaming_request = None - def _get_arguments_config(self, func_name: str, tools: list[Tool] | None) -> dict: - """Extract argument configuration for a function.""" - if tools is None: - return {} - for config in tools: - if not hasattr(config, "type") or not ( - hasattr(config, "function") and hasattr(config.function, "name") - ): - continue - if config.type == "function" and config.function.name == func_name: - if not hasattr(config.function, "parameters"): - return {} - params = config.function.parameters - if isinstance(params, dict) and "properties" in params: - return params["properties"] - elif isinstance(params, dict): - return params - else: - return {} - logger.debug("Tool '%s' is not defined in the tools list.", func_name) - return {} - def _convert_param_value( self, param_value: str, param_name: str, param_config: dict, func_name: str ) -> Any: @@ -243,16 +222,14 @@ class Qwen3CoderToolParser(ToolParser): ) return param_value - def _parse_xml_function_call( - self, function_call_str: str, tools: list[Tool] | None - ) -> ToolCall | None: + def _parse_xml_function_call(self, function_call_str: str) -> ToolCall | None: # Extract function name end_index = function_call_str.find(">") # If there's no ">" character, this is not a valid xml function call if end_index == -1: return None function_name = function_call_str[:end_index] - param_config = self._get_arguments_config(function_name, tools) + param_config = find_tool_properties(self.tools, function_name) parameters = function_call_str[end_index + 1 :] param_dict = {} for match_text in self.tool_call_parameter_regex.findall(parameters): @@ -314,7 +291,7 @@ class Qwen3CoderToolParser(ToolParser): ) tool_calls = [ - self._parse_xml_function_call(function_call_str, self.tools) + self._parse_xml_function_call(function_call_str) for function_call_str in function_calls ] # Populate prev_tool_call_arr for serving layer to set finish_reason @@ -605,9 +582,8 @@ class Qwen3CoderToolParser(ToolParser): self.current_param_name = current_param_name self.accumulated_params[current_param_name] = param_value - param_config = self._get_arguments_config( - self.current_function_name or "", - self.tools, + param_config = find_tool_properties( + self.tools, self.current_function_name or "" ) converted_value = self._convert_param_value( @@ -666,7 +642,6 @@ class Qwen3CoderToolParser(ToolParser): try: parsed_tool = self._parse_xml_function_call( func_content, - self.tools, ) if parsed_tool and self.current_tool_index < len( self.prev_tool_call_arr diff --git a/vllm/tool_parsers/qwen3xml_tool_parser.py b/vllm/tool_parsers/qwen3xml_tool_parser.py index 6e28c82b1..4ecb96668 100644 --- a/vllm/tool_parsers/qwen3xml_tool_parser.py +++ b/vllm/tool_parsers/qwen3xml_tool_parser.py @@ -26,6 +26,7 @@ from vllm.tool_parsers.abstract_tool_parser import ( Tool, ToolParser, ) +from vllm.tool_parsers.utils import find_tool_properties logger = init_logger(__name__) @@ -1000,33 +1001,11 @@ class StreamingXMLToolCallParser: if not self.tools or not self.current_function_name: return "string" - for tool in self.tools: - if not hasattr(tool, "type") or not ( - hasattr(tool, "function") and hasattr(tool.function, "name") - ): - continue - if ( - tool.type == "function" - and tool.function.name == self.current_function_name - ): - if not hasattr(tool.function, "parameters"): - return "string" - params = tool.function.parameters - if isinstance(params, dict) and "properties" in params: - properties = params["properties"] - if param_name in properties and isinstance( - properties[param_name], dict - ): - return self.repair_param_type( - str(properties[param_name].get("type", "string")) - ) - elif isinstance(params, dict) and param_name in params: - param_config = params[param_name] - if isinstance(param_config, dict): - return self.repair_param_type( - str(param_config.get("type", "string")) - ) - break + properties = find_tool_properties(self.tools, self.current_function_name) + if param_name in properties and isinstance(properties[param_name], dict): + return self.repair_param_type( + str(properties[param_name].get("type", "string")) + ) return "string" def repair_param_type(self, param_type: str) -> str: diff --git a/vllm/tool_parsers/utils.py b/vllm/tool_parsers/utils.py index 82b7eaaab..b25198924 100644 --- a/vllm/tool_parsers/utils.py +++ b/vllm/tool_parsers/utils.py @@ -142,6 +142,20 @@ def _extract_tool_info( raise TypeError(f"Unsupported tool type: {type(tool)}") +def find_tool_properties( + tools: list[Tool] | None, + tool_name: str, +) -> dict[str, Any]: + """Find a tool by name and return its properties dict, or {}.""" + if not tools: + return {} + for tool in tools: + name, params = _extract_tool_info(tool) + if name == tool_name: + return (params or {}).get("properties", {}) + return {} + + def _get_tool_schema_from_tool(tool: Tool) -> dict: name, params = _extract_tool_info(tool) params = params if params else {"type": "object", "properties": {}}