439 lines
14 KiB
Python
439 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import ast
|
|
import json
|
|
from json import JSONDecodeError, JSONDecoder
|
|
from typing import Any, TypeAlias
|
|
|
|
import partial_json_parser
|
|
from openai.types.responses import (
|
|
FunctionTool,
|
|
ToolChoiceFunction,
|
|
)
|
|
from openai.types.responses.tool import Tool as ResponsesTool
|
|
from partial_json_parser.core.options import Allow
|
|
|
|
from vllm.entrypoints.openai.chat_completion.protocol import (
|
|
ChatCompletionNamedToolChoiceParam,
|
|
ChatCompletionToolsParam,
|
|
)
|
|
from vllm.entrypoints.openai.engine.protocol import (
|
|
DeltaFunctionCall,
|
|
DeltaToolCall,
|
|
FunctionCall,
|
|
ToolCall,
|
|
)
|
|
from vllm.logger import init_logger
|
|
|
|
Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
def partial_tag_overlap(text: str, tag: str) -> int:
    """Return how many leading characters of *tag* close out *text*.

    Only proper prefixes of *tag* are considered (at most ``len(tag) - 1``
    characters) and the longest match wins.  For example, a text ending in
    ``"<tool_"`` yields 6 for the tag ``"<tool_call>"``.

    Returns 0 when no prefix of *tag* is a suffix of *text*.
    """
    longest_candidate = min(len(text), len(tag) - 1)
    # Walk from the longest candidate downwards so the first hit is the
    # maximum overlap; the default covers the "no overlap" case.
    return next(
        (
            size
            for size in range(longest_candidate, 0, -1)
            if text.endswith(tag[:size])
        ),
        0,
    )
|
|
|
|
|
|
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Return the leading run of characters that two strings have in common.

    The arguments are interchangeable; an empty string is returned when the
    strings differ at the very first character (or either is empty).

    This is a UTILITY for working with JSON produced by partial_json_parser
    while streaming: it helps make sure close-quotes, close-brackets and
    close-braces are not sent to the client prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'
    """
    matched = 0
    # zip() stops at the shorter string, which bounds the comparison.
    for left, right in zip(s1, s2):
        if left != right:
            break
        matched += 1
    return s1[:matched]
|
|
|
|
|
|
def find_common_suffix(s1: str, s2: str) -> str:
    """
    Return the trailing run of characters that two strings have in common.

    The arguments are interchangeable.  Scanning runs from the end of both
    strings and stops at the first position where the characters differ OR
    where the shared character is alphanumeric, so the result only ever
    contains non-alphanumeric characters.

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    shared = 0
    for tail_a, tail_b in zip(reversed(s1), reversed(s2)):
        if tail_a != tail_b or tail_a.isalnum():
            break
        shared += 1
    # Slice from an explicit start index: ``s1[-0:]`` would be the whole
    # string rather than the empty suffix.
    return s1[len(s1) - shared :]
|
|
|
|
|
|
def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Return the middle part of *curr* that is new relative to *old*.

    Both strings are expected to share a common prefix and/or suffix.  The
    shared (non-alphanumeric) suffix is stripped first, then the prefix that
    *curr* shares with the suffix-stripped *old*; what remains is the text
    that should be streamed to the client.

    This is a UTILITY for working with JSON produced by partial_json_parser
    while streaming, so that close-quotes, close-brackets and close-braces
    are not sent prematurely.  The order of arguments IS important: the new
    version of the partially-parsed JSON must be the first argument, and
    the second argument must be the one from the previous generation.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
    -> 'ple'
    """
    shared_tail = find_common_suffix(curr, old)

    # Both strings end with ``shared_tail`` by construction, so dropping it
    # from the end is a plain suffix removal (a no-op when it is empty).
    old_without_tail = old.removesuffix(shared_tail)
    shared_head = find_common_prefix(curr, old_without_tail)

    middle = curr.removesuffix(shared_tail)
    # Strip the head once, from the front only; if the head is longer than
    # what is left after removing the tail, nothing is removed.
    return middle.removeprefix(shared_head)
|
|
|
|
|
|
# partial_json_parser rejects input with extra trailing data, while
# JSONDecoder.raw_decode cannot handle partial JSON, so both are combined.
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
    """Parse possibly-partial JSON and report how much input was consumed.

    Returns:
        A ``(value, end)`` tuple.  When partial_json_parser accepts the
        input, ``end`` is ``len(input_str)``.  When it fails with an
        "Extra data" error, the result of ``JSONDecoder.raw_decode`` is
        returned instead.

    Raises:
        JSONDecodeError: For any decode error other than "Extra data", or
            if the ``raw_decode`` fallback itself fails.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if "Extra data" not in err.msg:
            raise
        return JSONDecoder().raw_decode(input_str)
    return parsed, len(input_str)
|
|
|
|
|
|
def is_complete_json(input_str: str) -> bool:
    """Report whether *input_str* parses as a complete JSON document.

    Returns:
        True when ``json.loads`` accepts the string, False when it raises
        ``JSONDecodeError``.
    """
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    else:
        return True
|
|
|
|
|
|
def consume_space(i: int, s: str) -> int:
    """Advance past whitespace in *s* starting at index *i*.

    Returns:
        The index of the first non-whitespace character at or after *i*,
        or an index that is not below ``len(s)`` when only whitespace (or
        nothing) remains.
    """
    end = len(s)
    cursor = i
    while cursor < end and s[cursor].isspace():
        cursor += 1
    return cursor
|
|
|
|
|
|
def _extract_tool_info(
    tool: Tool,
) -> tuple[str, dict[str, Any] | None]:
    """Return the ``(name, parameters)`` pair of a tool definition.

    ``FunctionTool`` carries both fields directly on the tool, while
    ``ChatCompletionToolsParam`` nests them under ``tool.function``.  The
    parameters object is returned as-is, not copied.

    Raises:
        TypeError: If *tool* is neither of the two supported types.
    """
    if isinstance(tool, FunctionTool):
        return tool.name, tool.parameters
    if isinstance(tool, ChatCompletionToolsParam):
        function = tool.function
        return function.name, function.parameters
    raise TypeError(f"Unsupported tool type: {type(tool)}")
|
|
|
|
|
|
def _get_tool_schema_from_tool(tool: Tool) -> dict:
    """Build the schema fragment describing one call to *tool*.

    The fragment requires a ``name`` property restricted to the tool's own
    name and a ``parameters`` property following the tool's parameter
    schema.  A tool whose parameters are missing or empty gets an object
    schema with no properties instead.
    """
    tool_name, parameters = _extract_tool_info(tool)
    if not parameters:
        parameters = {"type": "object", "properties": {}}
    name_schema = {"type": "string", "enum": [tool_name]}
    return {
        "properties": {
            "name": name_schema,
            "parameters": parameters,
        },
        "required": ["name", "parameters"],
    }
|
|
|
|
|
|
def _get_tool_schema_defs(
    tools: list[Tool],
) -> dict:
    """Collect the ``$defs`` entries of all tools into a single mapping.

    Each tool's ``$defs`` is removed from its parameters dict in place
    (via ``pop``) and merged into the result.  Tools without parameters
    are skipped.

    Raises:
        ValueError: If two tools define the same ``$defs`` name with
            different schemas.
    """
    merged_defs: dict[str, dict[str, Any]] = {}
    for tool in tools:
        params = _extract_tool_info(tool)[1]
        if params is None:
            continue
        # NOTE: pop() mutates the parameters object returned by
        # _extract_tool_info, i.e. the tool's own parameters dict.
        for def_name, def_schema in params.pop("$defs", {}).items():
            # Defaulting to the incoming schema makes a first-time name
            # compare equal to itself, so only real conflicts raise.
            already_seen = merged_defs.get(def_name, def_schema)
            if already_seen != def_schema:
                raise ValueError(
                    f"Tool definition '{def_name}' has multiple schemas, "
                    "which is not supported."
                )
            merged_defs[def_name] = def_schema
    return merged_defs
|
|
|
|
|
|
def _get_json_schema_from_tools(
    tools: list[Tool],
) -> dict:
    """Build a JSON schema for a non-empty array of tool calls.

    Every array item must be an object matching the per-tool schema of one
    of *tools*.  If any tool declares ``$defs``, they are attached at the
    top level of the returned schema.
    """
    # The per-tool schemas are built before the $defs are collected; keep
    # this order, because collecting the $defs pops them out of the very
    # parameter dicts the per-tool schemas reference.
    per_tool_schemas = [_get_tool_schema_from_tool(tool) for tool in tools]
    shared_defs = _get_tool_schema_defs(tools)

    schema = {
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": per_tool_schemas,
        },
    }
    if shared_defs:
        schema["$defs"] = shared_defs
    return schema
|
|
|
|
|
|
def get_json_schema_from_tools(
    tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
    tools: list[Tool] | None,
) -> str | dict | None:
    """Derive the JSON schema implied by a request's ``tool_choice``.

    Returns:
        * ``None`` when ``tool_choice`` is ``"none"``/``None``, when no
          tools were given, or for any choice not handled below (such as
          ``"auto"``).
        * The named tool's ``parameters`` when a specific function is
          forced (Responses or ChatCompletion style).
        * A schema for an array of tool calls when ``tool_choice`` is
          ``"required"``.

    Raises:
        ValueError: If a forced function name does not match any tool of
            the corresponding type in *tools*.
    """
    # tool_choice: "none"
    if tool_choice in ("none", None) or tools is None:
        return None

    # tool_choice: Forced Function (Responses)
    if isinstance(tool_choice, ToolChoiceFunction):
        requested = tool_choice.name
        responses_tools = {}
        for tool in tools:
            if isinstance(tool, FunctionTool):
                # Later tools with the same name replace earlier ones.
                responses_tools[tool.name] = tool
        if requested not in responses_tools:
            raise ValueError(f"Tool '{requested}' has not been passed in `tools`.")
        return responses_tools[requested].parameters

    # tool_choice: Forced Function (ChatCompletion)
    if isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
        requested = tool_choice.function.name
        chat_tools = {}
        for tool in tools:
            if isinstance(tool, ChatCompletionToolsParam):
                chat_tools[tool.function.name] = tool
        if requested not in chat_tools:
            raise ValueError(f"Tool '{requested}' has not been passed in `tools`.")
        return chat_tools[requested].function.parameters

    # tool_choice: "required" builds the array schema; anything else
    # (e.g. "auto") leaves the output unconstrained.
    if tool_choice != "required":
        return None
    return _get_json_schema_from_tools(tools)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shared utilities for pythonic-style tool call parsers
|
|
# (PythonicToolParser, Llama4PythonicToolParser, Olmo3PythonicToolParser)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class UnexpectedAstError(Exception):
    """Signal that parsed Python does not have the shape of a pythonic
    tool call (for example a non-literal argument or mismatched brackets).
    """
|
|
|
|
|
|
# JSON spellings that some models emit in place of the Python literals,
# mapped to the Python value each one stands for.
_JSON_NAME_LITERALS = {
    "null": None,
    "true": True,
    "false": False,
}


def get_parameter_value(val: ast.expr) -> Any:
    """Extract a Python literal value from an AST expression node.

    Supported nodes:

    * constants (strings, numbers, ``None``, ``True``, ``False``, ...);
    * signed numbers such as ``-1`` or ``+2.5``;
    * dicts whose keys are all constants (values are converted
      recursively);
    * lists (elements are converted recursively);
    * the JSON-style names ``null``, ``true`` and ``false`` that some
      models produce instead of ``None``, ``True`` and ``False``.

    Raises:
        UnexpectedAstError: If the AST node is not a supported literal type.
    """
    if isinstance(val, ast.Constant):
        return val.value

    if isinstance(val, ast.Dict):
        # A ``**spread`` entry shows up as a ``None`` key and is rejected
        # here together with any other non-constant key.
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            logger.warning(
                "Dict argument keys are not all literals: %s",
                ast.dump(val),
            )
            raise UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            k.value: get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }

    if isinstance(val, ast.List):
        return [get_parameter_value(v) for v in val.elts]

    if isinstance(val, ast.Name) and val.id in _JSON_NAME_LITERALS:
        return _JSON_NAME_LITERALS[val.id]

    # The parser represents ``-1`` as UnaryOp(USub, Constant(1)) rather than
    # as a Constant, so signed numbers need their own case.  Only a sign
    # applied directly to an int/float constant is accepted; bool is
    # excluded so that ``-True`` is not silently turned into ``-1``.
    if isinstance(val, ast.UnaryOp) and isinstance(val.op, (ast.UAdd, ast.USub)):
        operand = val.operand
        if (
            isinstance(operand, ast.Constant)
            and isinstance(operand.value, (int, float))
            and not isinstance(operand.value, bool)
        ):
            number = operand.value
            return -number if isinstance(val.op, ast.USub) else +number

    logger.warning(
        "Unsupported AST node type in tool call arguments: %s",
        ast.dump(val),
    )
    raise UnexpectedAstError("Tool call arguments must be literals")
|
|
|
|
|
|
def handle_single_tool(call: ast.Call) -> ToolCall:
    """Convert a single AST function call node into a ToolCall object.

    The function name becomes the tool name and the keyword arguments are
    serialised to a JSON object string.  Only ``call.keywords`` is read;
    positional arguments in ``call.args`` are not inspected.

    Raises:
        UnexpectedAstError: If the call node does not have a simple
            function name (e.g. it's an attribute access or subscript),
            if an argument is passed via ``**`` unpacking, or if an
            argument value is not a supported literal.
    """
    if not isinstance(call.func, ast.Name):
        logger.warning(
            "Tool call has non-simple function name: %s",
            ast.dump(call.func),
        )
        raise UnexpectedAstError("Invalid tool call name")

    function_name = call.func.id
    arguments = {}
    for keyword in call.keywords:
        if keyword.arg is None:
            # ``f(**mapping)`` has no parameter name; storing it under the
            # key None would be serialised by json.dumps as a bogus
            # argument called "null".
            logger.warning(
                "Tool call uses ** argument unpacking: %s",
                ast.dump(keyword.value),
            )
            raise UnexpectedAstError("Tool call arguments must be named")
        arguments[keyword.arg] = get_parameter_value(keyword.value)

    return ToolCall(
        type="function",
        function=FunctionCall(
            name=function_name,
            arguments=json.dumps(arguments, ensure_ascii=False),
        ),
    )
|
|
|
|
|
|
def make_valid_python(text: str) -> tuple[str, str] | None:
|
|
"""Attempt to close all open brackets/quotes to make partial Python valid.
|
|
|
|
Used during streaming to parse incomplete tool call expressions by
|
|
appending the necessary closing characters.
|
|
|
|
Returns:
|
|
A tuple of (completed_text, added_suffix) if the text can be
|
|
made valid, or None if the text is too incomplete to complete
|
|
meaningfully (e.g. mid-parameter-name or mid-dict-key).
|
|
|
|
Raises:
|
|
UnexpectedAstError: If mismatched brackets or parentheses
|
|
are detected.
|
|
"""
|
|
bracket_stack: list[str] = []
|
|
for index, char in enumerate(text):
|
|
if char in {"[", "(", "{"}:
|
|
bracket_stack.append(char)
|
|
elif char == "]":
|
|
if not bracket_stack or bracket_stack.pop() != "[":
|
|
raise UnexpectedAstError("Mismatched square brackets")
|
|
elif char == ")":
|
|
if not bracket_stack or bracket_stack.pop() != "(":
|
|
raise UnexpectedAstError("Mismatched parentheses")
|
|
elif char == "}":
|
|
if not bracket_stack or bracket_stack.pop() != "{":
|
|
raise UnexpectedAstError("Mismatched curly braces")
|
|
elif char in {"'", '"'}:
|
|
if bracket_stack and bracket_stack[-1] == char:
|
|
if index > 0 and text[index - 1] == "\\":
|
|
pass
|
|
else:
|
|
bracket_stack.pop()
|
|
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
|
pass
|
|
else:
|
|
bracket_stack.append(char)
|
|
|
|
text = text.rstrip()
|
|
if text.endswith("=") or text.endswith(":"):
|
|
return None
|
|
if bracket_stack and bracket_stack[-1] == "{":
|
|
trailing_dict_text = text[: text.rfind("{")]
|
|
num_keys = trailing_dict_text.count(":")
|
|
num_values = trailing_dict_text.count(",")
|
|
if num_keys <= num_values:
|
|
return None
|
|
if bracket_stack and bracket_stack[-1] == "(":
|
|
trailing_params_text = text[: text.rfind("(")]
|
|
num_full_param_names = trailing_params_text.count("=")
|
|
num_full_param_values = trailing_params_text.count(",")
|
|
if num_full_param_names <= num_full_param_values:
|
|
return None
|
|
if text.endswith(","):
|
|
text = text[:-1]
|
|
if (
|
|
bracket_stack
|
|
and bracket_stack[-1] == "["
|
|
and not text.endswith("[")
|
|
and not text.endswith(")")
|
|
):
|
|
return None
|
|
|
|
_CLOSING = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'}
|
|
added_text = ""
|
|
for char in reversed(bracket_stack):
|
|
added_text += _CLOSING[char]
|
|
|
|
return text + added_text, added_text
|
|
|
|
|
|
def compute_tool_delta(
    previously_sent_args: str,
    new_call: ToolCall,
    index: int,
    withheld_suffix: str,
) -> DeltaToolCall | None:
    """Work out which part of a tool call's arguments still has to be sent.

    The withheld suffix, when given, is cut off the end of the current
    arguments first.  If nothing was sent before, the delta carries the
    call id, the function name and all remaining arguments; otherwise it
    carries only the characters beyond ``len(previously_sent_args)``.

    Returns:
        A DeltaToolCall with the characters to stream, or None when
        something was sent before and there is nothing new to add.

    Raises:
        ValueError: If *withheld_suffix* is non-empty and the current
            arguments do not end with it.
    """
    streamable_args = new_call.function.arguments
    if withheld_suffix:
        if not streamable_args.endswith(withheld_suffix):
            msg = (
                f"Tool call arguments '{streamable_args}' do not end with "
                f"expected withheld suffix '{withheld_suffix}'"
            )
            logger.error(msg)
            raise ValueError(msg)
        streamable_args = streamable_args[: -len(withheld_suffix)]

    # First delta for this call: include the identifying fields.
    if not previously_sent_args:
        first_function_delta = DeltaFunctionCall(
            name=new_call.function.name,
            arguments=streamable_args,
        )
        return DeltaToolCall(
            id=new_call.id,
            type="function",
            index=index,
            function=first_function_delta,
        )

    unsent_args = streamable_args[len(previously_sent_args) :]
    if not unsent_args:
        return None
    return DeltaToolCall(
        id=None,
        index=index,
        function=DeltaFunctionCall(arguments=unsent_args),
    )
|