# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
import json
from json import JSONDecodeError, JSONDecoder
from typing import Any, TypeAlias

import partial_json_parser
from openai.types.responses import (
    FunctionTool,
    ToolChoiceFunction,
)
from openai.types.responses.tool import Tool as ResponsesTool
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionNamedToolChoiceParam,
    ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaToolCall,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger

Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool

logger = init_logger(__name__)


def partial_tag_overlap(text: str, tag: str) -> int:
    """Length of the longest prefix of *tag* that matches a suffix of *text*.

    Only proper prefixes of *tag* are considered (at most ``len(tag) - 1``
    characters), so a text that already ends with the full tag still reports
    at most ``len(tag) - 1``.

    E.g. text ending in ``"<tool_"`` returns 6 when tag is ``"<tool_call>"``.
    Returns 0 when there is no overlap.
    """
    max_check = min(len(tag) - 1, len(text))
    # Try the longest candidate first so the first hit is the answer.
    for k in range(max_check, 0, -1):
        if text.endswith(tag[:k]):
            return k
    return 0


def find_common_prefix(s1: str, s2: str) -> str:
    """
    Finds a common prefix that is shared between two strings, if there is one.
    Order of arguments is NOT important.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'
    """
    matched = 0
    for left, right in zip(s1, s2):
        if left != right:
            break
        matched += 1
    return s1[:matched]


def find_common_suffix(s1: str, s2: str) -> str:
    """
    Finds a common suffix shared between two strings, if there is one. Order of
    arguments is NOT important.
    Stops when the suffix ends OR it hits an alphanumeric character

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    matched = 0
    for left, right in zip(reversed(s1), reversed(s2)):
        if left != right or left.isalnum():
            break
        matched += 1
    # Slice from an explicit start index: ``s1[-0:]`` would be the whole string.
    return s1[len(s1) - matched :]


def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Given two strings, extract the difference in the middle between two strings
    that are known to have a common prefix and/or suffix.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely. The order of arguments IS
    important - the new version of the partially-parsed JSON must be the first
    argument, and the second argument must be from the previous generation.

    What it returns, is tokens that should be streamed to the client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
        -> 'ple'
    """
    suffix = find_common_suffix(curr, old)

    # Both strings end with ``suffix`` by construction, so dropping it from the
    # end is exactly what the suffix removal needs to do.
    old = old.removesuffix(suffix)
    prefix = find_common_prefix(curr, old)

    diff = curr.removesuffix(suffix)
    # The prefix is removed at most once, and only from the very start, so a
    # mirrored copy of it later in the string is left untouched.
    return diff.removeprefix(prefix)


# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
    """Parse possibly-partial JSON, tolerating trailing extra data.

    Returns:
        A tuple of (parsed_object, end_index). When partial_json_parser
        accepts the input, end_index is ``len(input_str)``; when it fails with
        an "Extra data" error, the first JSON value is decoded with
        ``JSONDecoder.raw_decode`` and end_index is where that value ends.

    Raises:
        JSONDecodeError: For any decode error other than "Extra data".
    """
    try:
        return (partial_json_parser.loads(input_str, flags), len(input_str))
    except JSONDecodeError as e:
        if "Extra data" in e.msg:
            dec = JSONDecoder()
            return dec.raw_decode(input_str)
        raise


def is_complete_json(input_str: str) -> bool:
    """Return True if *input_str* parses as JSON with ``json.loads``."""
    try:
        json.loads(input_str)
        return True
    except JSONDecodeError:
        return False


def consume_space(i: int, s: str) -> int:
    """Return the first index at or after *i* in *s* that is not whitespace.

    Returns ``len(s)`` (or *i*, if it is already past the end) when only
    whitespace remains.
    """
    while i < len(s) and s[i].isspace():
        i += 1
    return i


def _extract_tool_info(
    tool: Tool,
) -> tuple[str, dict[str, Any] | None]:
    """Return ``(name, parameters)`` for a Responses or ChatCompletion tool.

    Raises:
        TypeError: If *tool* is neither a FunctionTool nor a
            ChatCompletionToolsParam.
    """
    if isinstance(tool, FunctionTool):
        return tool.name, tool.parameters
    elif isinstance(tool, ChatCompletionToolsParam):
        return tool.function.name, tool.function.parameters
    else:
        raise TypeError(f"Unsupported tool type: {type(tool)}")


def _get_tool_schema_from_tool(tool: Tool) -> dict:
    """Build the JSON schema describing one call to *tool*.

    A tool with missing or empty parameters is given an empty object schema.
    """
    name, params = _extract_tool_info(tool)
    params = params if params else {"type": "object", "properties": {}}
    return {
        "properties": {
            "name": {"type": "string", "enum": [name]},
            "parameters": params,
        },
        "required": ["name", "parameters"],
    }


def _get_tool_schema_defs(
    tools: list[Tool],
) -> dict:
    """Collect the ``$defs`` entries of every tool into one mapping.

    The tools' parameter dicts are only read, never modified, so calling this
    repeatedly on the same tools gives the same result.

    Raises:
        ValueError: If two tools define the same ``$defs`` name with
            different schemas.
    """
    all_defs: dict[str, dict[str, Any]] = {}
    for tool in tools:
        _, params = _extract_tool_info(tool)
        if params is None:
            continue
        defs = params.get("$defs") or {}
        for def_name, def_schema in defs.items():
            if def_name in all_defs and all_defs[def_name] != def_schema:
                raise ValueError(
                    f"Tool definition '{def_name}' has multiple schemas, "
                    "which is not supported."
                )
            all_defs[def_name] = def_schema
    return all_defs


def _get_json_schema_from_tools(
    tools: list[Tool],
) -> dict:
    """Build the schema for a non-empty array of calls to any of *tools*.

    Each tool's ``$defs`` are hoisted to the root of the returned schema and
    left out of the per-tool parameter schemas. The per-tool schema is built
    from a shallow copy in that case, so the caller's tool objects keep their
    ``$defs``.
    """
    any_of = []
    for tool in tools:
        tool_schema = _get_tool_schema_from_tool(tool)
        params = tool_schema["properties"]["parameters"]
        if "$defs" in params:
            tool_schema["properties"]["parameters"] = {
                key: value for key, value in params.items() if key != "$defs"
            }
        any_of.append(tool_schema)

    json_schema: dict[str, Any] = {
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": any_of,
        },
    }
    json_schema_defs = _get_tool_schema_defs(tools)
    if json_schema_defs:
        json_schema["$defs"] = json_schema_defs
    return json_schema


def get_json_schema_from_tools(
    tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
    tools: list[Tool] | None,
) -> str | dict | None:
    """Return the JSON schema implied by *tool_choice*, if any.

    Returns:
        ``None`` for ``"none"``, ``None``, ``"auto"`` (or any other string
        that is not ``"required"``) and when *tools* is ``None``; the named
        tool's parameters for a forced function choice; an array-of-calls
        schema over all tools for ``"required"``.

    Raises:
        ValueError: If a forced function is not present in *tools*.
    """
    # tool_choice: "none"
    if tool_choice in ("none", None) or tools is None:
        return None

    # tool_choice: Forced Function (Responses)
    if isinstance(tool_choice, ToolChoiceFunction):
        tool_name = tool_choice.name
        tool_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
        if tool_name not in tool_map:
            raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
        return tool_map[tool_name].parameters

    # tool_choice: Forced Function (ChatCompletion)
    if isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
        tool_name = tool_choice.function.name
        tool_map = {
            tool.function.name: tool
            for tool in tools
            if isinstance(tool, ChatCompletionToolsParam)
        }
        if tool_name not in tool_map:
            raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
        return tool_map[tool_name].function.parameters

    # tool_choice: "required"
    if tool_choice == "required":
        return _get_json_schema_from_tools(tools)

    # tool_choice: "auto"
    return None


# ---------------------------------------------------------------------------
# Shared utilities for pythonic-style tool call parsers
# (PythonicToolParser, Llama4PythonicToolParser, Olmo3PythonicToolParser)
# ---------------------------------------------------------------------------


class UnexpectedAstError(Exception):
    """Raised when the AST structure does not match the expected
    pythonic tool call format."""

    pass


_JSON_NAME_LITERALS = {
    "null": None,
    "true": True,
    "false": False,
}


def get_parameter_value(val: ast.expr) -> Any:
    """Extract a Python literal value from an AST expression node.

    Handles constants, signed numeric constants (``-1``, ``+2.5``), dicts,
    lists, and JSON-style name literals (null, true, false) that some models
    produce instead of Python literals (None, True, False).

    Raises:
        UnexpectedAstError: If the AST node is not a supported literal type.
    """
    if isinstance(val, ast.Constant):
        return val.value
    elif (
        isinstance(val, ast.UnaryOp)
        and isinstance(val.op, (ast.USub, ast.UAdd))
        and isinstance(val.operand, ast.Constant)
        and isinstance(val.operand.value, (int, float))
        and not isinstance(val.operand.value, bool)
    ):
        # A signed number is parsed as UnaryOp(op, Constant), not as a
        # Constant, so it needs its own branch. Bools are excluded so that
        # "-True" is still rejected.
        number = val.operand.value
        return -number if isinstance(val.op, ast.USub) else number
    elif isinstance(val, ast.Dict):
        # A ``**mapping`` entry shows up as a ``None`` key and fails this check.
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            logger.warning(
                "Dict argument keys are not all literals: %s",
                ast.dump(val),
            )
            raise UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            k.value: get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }
    elif isinstance(val, ast.List):
        return [get_parameter_value(v) for v in val.elts]
    elif isinstance(val, ast.Name) and val.id in _JSON_NAME_LITERALS:
        return _JSON_NAME_LITERALS[val.id]
    else:
        logger.warning(
            "Unsupported AST node type in tool call arguments: %s",
            ast.dump(val),
        )
        raise UnexpectedAstError("Tool call arguments must be literals")


def handle_single_tool(call: ast.Call) -> ToolCall:
    """Convert a single AST function call node into a ToolCall object.

    Only keyword arguments are read; positional arguments on the call node
    are ignored. The arguments are serialized to a JSON string.

    Raises:
        UnexpectedAstError: If the call node does not have a simple
            function name (e.g. it's an attribute access or subscript), or
            if an argument value is not a supported literal.
    """
    if not isinstance(call.func, ast.Name):
        logger.warning(
            "Tool call has non-simple function name: %s",
            ast.dump(call.func),
        )
        raise UnexpectedAstError("Invalid tool call name")
    function_name = call.func.id
    arguments = {
        keyword.arg: get_parameter_value(keyword.value) for keyword in call.keywords
    }
    return ToolCall(
        type="function",
        function=FunctionCall(
            name=function_name,
            arguments=json.dumps(arguments, ensure_ascii=False),
        ),
    )


_QUOTE_CHARS = frozenset({"'", '"'})
_OPENING_BRACKETS = frozenset({"[", "(", "{"})
# Closing bracket -> (expected opening bracket, error message on mismatch).
_CLOSING_BRACKETS = {
    "]": ("[", "Mismatched square brackets"),
    ")": ("(", "Mismatched parentheses"),
    "}": ("{", "Mismatched curly braces"),
}
# Opening delimiter -> the text that closes it.
_CLOSING_DELIMITERS = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'}


def _is_escaped(text: str, index: int) -> bool:
    """Return True if text[index] follows an odd number of backslashes."""
    cursor = index - 1
    while cursor >= 0 and text[cursor] == "\\":
        cursor -= 1
    return (index - 1 - cursor) % 2 == 1


def make_valid_python(text: str) -> tuple[str, str] | None:
    """Attempt to close all open brackets/quotes to make partial Python valid.

    Used during streaming to parse incomplete tool call expressions by
    appending the necessary closing characters. Characters inside a string
    literal are treated as plain text, so brackets in string values do not
    affect bracket matching.

    Returns:
        A tuple of (completed_text, added_suffix) if the text can be made
        valid, or None if the text is too incomplete to complete
        meaningfully (e.g. mid-parameter-name or mid-dict-key).

    Raises:
        UnexpectedAstError: If mismatched brackets or parentheses are
            detected.
    """
    # Stack of currently open delimiters: brackets and string quotes.
    bracket_stack: list[str] = []
    for index, char in enumerate(text):
        innermost = bracket_stack[-1] if bracket_stack else None
        if innermost in _QUOTE_CHARS:
            # Inside a string only its own unescaped quote is significant. An
            # even run of backslashes before the quote escapes the backslashes,
            # not the quote.
            if char == innermost and not _is_escaped(text, index):
                bracket_stack.pop()
            continue
        if char in _OPENING_BRACKETS or char in _QUOTE_CHARS:
            bracket_stack.append(char)
        elif char in _CLOSING_BRACKETS:
            expected_opener, mismatch_message = _CLOSING_BRACKETS[char]
            if not bracket_stack or bracket_stack.pop() != expected_opener:
                raise UnexpectedAstError(mismatch_message)

    text = text.rstrip()
    if text.endswith("=") or text.endswith(":"):
        # A parameter or dict key with no value yet: nothing valid to append.
        return None
    # NOTE(review): despite the "trailing_*" names, the two checks below count
    # separators in the text *before* the last "{" / "(". Streaming callers
    # depend on exactly when this function returns None, so the slices are
    # kept as they were - confirm the intended direction before changing them.
    if bracket_stack and bracket_stack[-1] == "{":
        trailing_dict_text = text[: text.rfind("{")]
        num_keys = trailing_dict_text.count(":")
        num_values = trailing_dict_text.count(",")
        if num_keys <= num_values:
            return None
    if bracket_stack and bracket_stack[-1] == "(":
        trailing_params_text = text[: text.rfind("(")]
        num_full_param_names = trailing_params_text.count("=")
        num_full_param_values = trailing_params_text.count(",")
        if num_full_param_names <= num_full_param_values:
            return None
    if text.endswith(","):
        text = text[:-1]
    if (
        bracket_stack
        and bracket_stack[-1] == "["
        and not text.endswith("[")
        and not text.endswith(")")
    ):
        # Innermost open delimiter is a list that ends neither freshly opened
        # nor right after a completed call.
        return None

    # Close the delimiters innermost-first.
    added_text = "".join(_CLOSING_DELIMITERS[char] for char in reversed(bracket_stack))
    return text + added_text, added_text


def compute_tool_delta(
    previously_sent_args: str,
    new_call: ToolCall,
    index: int,
    withheld_suffix: str,
) -> DeltaToolCall | None:
    """Compute the incremental delta between previously streamed arguments
    and the current tool call state.

    *withheld_suffix* is cut from the end of the current arguments before the
    comparison, so it is never part of the returned delta.

    Returns:
        A DeltaToolCall with only the new argument characters, or None
        if there is no difference from what was previously sent. When nothing
        has been sent yet, the delta also carries the call id, type and
        function name.

    Raises:
        ValueError: If the current arguments do not end with a non-empty
            *withheld_suffix*.
    """
    new_call_args = new_call.function.arguments
    if withheld_suffix:
        if not new_call_args.endswith(withheld_suffix):
            msg = (
                f"Tool call arguments '{new_call_args}' do not end with "
                f"expected withheld suffix '{withheld_suffix}'"
            )
            logger.error(msg)
            raise ValueError(msg)
        new_call_args = new_call_args[: -len(withheld_suffix)]

    if not previously_sent_args:
        # First delta for this call: include the identifying fields.
        return DeltaToolCall(
            id=new_call.id,
            type="function",
            index=index,
            function=DeltaFunctionCall(
                name=new_call.function.name,
                arguments=new_call_args,
            ),
        )

    arg_diff = new_call_args[len(previously_sent_args) :]
    if not arg_diff:
        return None
    return DeltaToolCall(
        id=None,
        index=index,
        function=DeltaFunctionCall(arguments=arg_diff),
    )