[Misc] Refactored 5 duplicate helper functions that were copy-pasted across multiple parsers (#36436)
Signed-off-by: Taneem Ibrahim <taneem.ibrahim@gmail.com>
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -13,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import (
|
||||
UnexpectedAstError,
|
||||
compute_tool_delta,
|
||||
handle_single_tool,
|
||||
make_valid_python,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
    """Raised when parsed tool-call text is not the expected pythonic AST."""
|
||||
|
||||
|
||||
class Llama4PythonicToolParser(ToolParser):
|
||||
"""
|
||||
Toolcall parser for Llama4 that produce tool calls in a pythonic style
|
||||
@@ -103,15 +100,13 @@ class Llama4PythonicToolParser(ToolParser):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=[
|
||||
_handle_single_tool(e) # type: ignore
|
||||
handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
else:
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls"
|
||||
)
|
||||
raise UnexpectedAstError("Tool output must be a list of function calls")
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# Treat as regular text
|
||||
@@ -140,7 +135,7 @@ class Llama4PythonicToolParser(ToolParser):
|
||||
current_text = current_text[len("<|python_start|>") :]
|
||||
if current_text.endswith("<|python_end|>"):
|
||||
current_text = current_text[: current_text.rfind("<|python_end|>")]
|
||||
valid_and_added_text = _make_valid_python(current_text)
|
||||
valid_and_added_text = make_valid_python(current_text)
|
||||
if valid_and_added_text is None:
|
||||
return None
|
||||
valid_text, added_text = valid_and_added_text
|
||||
@@ -150,11 +145,9 @@ class Llama4PythonicToolParser(ToolParser):
|
||||
if not isinstance(parsed, ast.List) or not all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts
|
||||
):
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls"
|
||||
)
|
||||
raise UnexpectedAstError("Tool output must be a list of function calls")
|
||||
tool_calls = [
|
||||
_handle_single_tool(e) # type: ignore
|
||||
handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
]
|
||||
|
||||
@@ -180,7 +173,7 @@ class Llama4PythonicToolParser(ToolParser):
|
||||
# Strings get single quotes in the model-produced string.
|
||||
# JSON requires double quotes.
|
||||
withheld_suffix = withheld_suffix.replace("'", '"')
|
||||
delta = _compute_tool_delta(
|
||||
delta = compute_tool_delta(
|
||||
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
|
||||
)
|
||||
|
||||
@@ -214,130 +207,3 @@ class Llama4PythonicToolParser(ToolParser):
|
||||
"Skipping chunk as a result of tool streaming extraction error"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
    """Convert the AST node of a tool-call argument into a plain Python value.

    Supported nodes are constants, signed numeric constants (``-1``,
    ``+2.5``), dict and list displays (converted recursively), and the JSON
    spellings ``null``, ``true`` and ``false``.

    Args:
        val: Expression node taken from a parsed tool call.

    Returns:
        The Python value the node denotes.

    Raises:
        _UnexpectedAstError: If a dict key is not a constant or the node is
            not one of the supported literal forms.
    """
    if isinstance(val, ast.Constant):
        return val.value

    if isinstance(val, ast.UnaryOp) and isinstance(val.op, (ast.USub, ast.UAdd)):
        # Python parses "-1" as UnaryOp(USub, Constant(1)) rather than as a
        # negative Constant, so signed numbers have to be folded by hand.
        operand = val.operand
        if (
            isinstance(operand, ast.Constant)
            and isinstance(operand.value, (int, float))
            and not isinstance(operand.value, bool)
        ):
            return -operand.value if isinstance(val.op, ast.USub) else operand.value

    elif isinstance(val, ast.Dict):
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            k.value: _get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }

    elif isinstance(val, ast.List):
        return [_get_parameter_value(v) for v in val.elts]

    elif isinstance(val, ast.Name):
        # Bare JSON names are accepted in place of None / True / False.
        json_names = {"null": None, "true": True, "false": False}
        if val.id in json_names:
            return json_names[val.id]

    raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
    """Build a ``ToolCall`` from one parsed function-call node.

    Only keyword arguments are read; positional arguments in ``call.args``
    are not inspected by this function.

    Args:
        call: The ``ast.Call`` node for a single tool invocation.

    Returns:
        A function-type ``ToolCall`` whose arguments are the JSON encoding of
        the keyword arguments.

    Raises:
        _UnexpectedAstError: If the callee is not a plain name, or a keyword
            is a ``**mapping`` unpacking that carries no argument name.
    """
    if not isinstance(call.func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")

    arguments = {}
    for keyword in call.keywords:
        if keyword.arg is None:
            # ``**mapping`` has no name; storing it under None would be
            # serialized by json.dumps as the bogus key "null".
            raise _UnexpectedAstError("Tool call arguments must be named")
        arguments[keyword.arg] = _get_parameter_value(keyword.value)

    return ToolCall(
        type="function",
        function=FunctionCall(
            name=call.func.id,
            # ensure_ascii=False keeps non-ASCII argument text as-is instead
            # of \uXXXX escapes, matching the other pythonic parsers.
            arguments=json.dumps(arguments, ensure_ascii=False),
        ),
    )
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> tuple[str, str] | None:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[: text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[: text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if (
|
||||
bracket_stack
|
||||
and bracket_stack[-1] == "["
|
||||
and not text.endswith("[")
|
||||
and not text.endswith(")")
|
||||
):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(
    previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
    """Return the part of a tool call's arguments that has not been sent yet.

    Args:
        previously_sent_args: Argument text already streamed for this call.
        new_call: The tool call as currently parsed.
        index: Position of the tool call, copied into the delta.
        withheld_suffix: Trailing argument text that must not be sent yet;
            it is cut off before the delta is computed.

    Returns:
        A delta carrying the call id, name and arguments when nothing was
        sent before; a delta carrying only the new argument characters
        otherwise; or ``None`` when there is nothing new to send.

    Raises:
        AssertionError: If the arguments do not end with ``withheld_suffix``.
    """
    new_call_args = new_call.function.arguments
    if withheld_suffix:
        # Raised explicitly instead of via an ``assert`` statement so the
        # check still runs when Python is started with -O.
        if not new_call_args.endswith(withheld_suffix):
            raise AssertionError(
                "Tool call arguments do not end with the withheld suffix"
            )
        new_call_args = new_call_args[: -len(withheld_suffix)]

    if not previously_sent_args:
        return DeltaToolCall(
            id=new_call.id,
            type="function",
            index=index,
            function=DeltaFunctionCall(
                name=new_call.function.name,
                arguments=new_call_args,
            ),
        )

    arg_diff = new_call_args[len(previously_sent_args) :]
    if not arg_diff:
        return None
    return DeltaToolCall(
        id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
    )
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -13,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import (
|
||||
UnexpectedAstError,
|
||||
compute_tool_delta,
|
||||
handle_single_tool,
|
||||
make_valid_python,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
    """Raised when parsed tool-call text is not the expected pythonic AST."""
|
||||
|
||||
|
||||
class Olmo3PythonicToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Olmo 3 models that produce tool calls as
|
||||
@@ -113,15 +110,13 @@ class Olmo3PythonicToolParser(ToolParser):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=[
|
||||
_handle_single_tool(e) # type: ignore
|
||||
handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
else:
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls"
|
||||
)
|
||||
raise UnexpectedAstError("Tool output must be a list of function calls")
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# Treat as regular text
|
||||
@@ -151,7 +146,7 @@ class Olmo3PythonicToolParser(ToolParser):
|
||||
if current_text.endswith("</function_calls>"):
|
||||
current_text = current_text[: -len("</function_calls>")]
|
||||
|
||||
valid_and_added_text = _make_valid_python(current_text)
|
||||
valid_and_added_text = make_valid_python(current_text)
|
||||
if valid_and_added_text is None:
|
||||
return None
|
||||
valid_text, added_text = valid_and_added_text
|
||||
@@ -166,11 +161,11 @@ class Olmo3PythonicToolParser(ToolParser):
|
||||
if not isinstance(parsed, ast.List) or not all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts
|
||||
):
|
||||
raise _UnexpectedAstError(
|
||||
raise UnexpectedAstError(
|
||||
"Tool output must be a sequence of newline-separated calls"
|
||||
)
|
||||
tool_calls = [
|
||||
_handle_single_tool(e) # type: ignore
|
||||
handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
]
|
||||
|
||||
@@ -194,7 +189,7 @@ class Olmo3PythonicToolParser(ToolParser):
|
||||
# Strings get single quotes in the model-produced string.
|
||||
# JSON requires double quotes.
|
||||
withheld_suffix = withheld_suffix.replace("'", '"')
|
||||
delta = _compute_tool_delta(
|
||||
delta = compute_tool_delta(
|
||||
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
|
||||
)
|
||||
|
||||
@@ -228,141 +223,3 @@ class Olmo3PythonicToolParser(ToolParser):
|
||||
"Skipping chunk as a result of tool streaming extraction error"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
    """Convert the AST node of a tool-call argument into a plain Python value.

    Supported nodes are constants, signed numeric constants (``-1``,
    ``+2.5``), dict and list displays (converted recursively), and the JSON
    spellings ``null``, ``true`` and ``false``.

    Args:
        val: Expression node taken from a parsed tool call.

    Returns:
        The Python value the node denotes.

    Raises:
        _UnexpectedAstError: If a dict key is not a constant or the node is
            not one of the supported literal forms.
    """
    if isinstance(val, ast.Constant):
        return val.value

    if isinstance(val, ast.UnaryOp) and isinstance(val.op, (ast.USub, ast.UAdd)):
        # Python parses "-1" as UnaryOp(USub, Constant(1)) rather than as a
        # negative Constant, so signed numbers have to be folded by hand.
        operand = val.operand
        if (
            isinstance(operand, ast.Constant)
            and isinstance(operand.value, (int, float))
            and not isinstance(operand.value, bool)
        ):
            return -operand.value if isinstance(val.op, ast.USub) else operand.value

    elif isinstance(val, ast.Dict):
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            k.value: _get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }

    elif isinstance(val, ast.List):
        return [_get_parameter_value(v) for v in val.elts]

    elif isinstance(val, ast.Name):
        # The model may return function calls where the values are
        # null/true/false because the system prompt has API description
        # in json.
        json_names = {"null": None, "true": True, "false": False}
        if val.id in json_names:
            return json_names[val.id]

    raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
    """Build a ``ToolCall`` from one parsed function-call node.

    Only keyword arguments are read; positional arguments in ``call.args``
    are not inspected by this function.

    Args:
        call: The ``ast.Call`` node for a single tool invocation.

    Returns:
        A function-type ``ToolCall`` whose arguments are the JSON encoding
        (non-ASCII characters left unescaped) of the keyword arguments.

    Raises:
        _UnexpectedAstError: If the callee is not a plain name, or a keyword
            is a ``**mapping`` unpacking that carries no argument name.
    """
    if not isinstance(call.func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")

    arguments = {}
    for keyword in call.keywords:
        if keyword.arg is None:
            # ``**mapping`` has no name; storing it under None would be
            # serialized by json.dumps as the bogus key "null".
            raise _UnexpectedAstError("Tool call arguments must be named")
        arguments[keyword.arg] = _get_parameter_value(keyword.value)

    return ToolCall(
        type="function",
        function=FunctionCall(
            name=call.func.id,
            arguments=json.dumps(arguments, ensure_ascii=False),
        ),
    )
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> tuple[str, str] | None:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[: text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[: text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if (
|
||||
bracket_stack
|
||||
and bracket_stack[-1] == "["
|
||||
and not text.endswith("[")
|
||||
and not text.endswith(")")
|
||||
):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(
    previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
    """Return the part of a tool call's arguments that has not been sent yet.

    Args:
        previously_sent_args: Argument text already streamed for this call.
        new_call: The tool call as currently parsed.
        index: Position of the tool call, copied into the delta.
        withheld_suffix: Trailing argument text that must not be sent yet;
            it is cut off before the delta is computed.

    Returns:
        A delta carrying the call id, name and arguments when nothing was
        sent before; a delta carrying only the new argument characters
        otherwise; or ``None`` when there is nothing new to send.

    Raises:
        AssertionError: If the arguments do not end with ``withheld_suffix``.
    """
    new_call_args = new_call.function.arguments
    if withheld_suffix:
        # Raised explicitly instead of via an ``assert`` statement so the
        # check still runs when Python is started with -O.
        if not new_call_args.endswith(withheld_suffix):
            raise AssertionError(
                "Tool call arguments do not end with the withheld suffix"
            )
        new_call_args = new_call_args[: -len(withheld_suffix)]

    if not previously_sent_args:
        return DeltaToolCall(
            id=new_call.id,
            type="function",
            index=index,
            function=DeltaFunctionCall(
                name=new_call.function.name,
                arguments=new_call_args,
            ),
        )

    arg_diff = new_call_args[len(previously_sent_args) :]
    if not arg_diff:
        return None
    return DeltaToolCall(
        id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
    )
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -14,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import (
|
||||
UnexpectedAstError,
|
||||
compute_tool_delta,
|
||||
handle_single_tool,
|
||||
make_valid_python,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _UnexpectedAstError(Exception):
    """Raised when parsed tool-call text is not the expected pythonic AST."""
|
||||
|
||||
|
||||
class PythonicToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for models that produce tool calls in a pythonic style,
|
||||
@@ -99,15 +95,13 @@ class PythonicToolParser(ToolParser):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=[
|
||||
_handle_single_tool(e) # type: ignore
|
||||
handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
else:
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls"
|
||||
)
|
||||
raise UnexpectedAstError("Tool output must be a list of function calls")
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# Treat as regular text
|
||||
@@ -129,7 +123,7 @@ class PythonicToolParser(ToolParser):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
valid_and_added_text = _make_valid_python(current_text)
|
||||
valid_and_added_text = make_valid_python(current_text)
|
||||
if valid_and_added_text is None:
|
||||
return None
|
||||
valid_text, added_text = valid_and_added_text
|
||||
@@ -139,11 +133,9 @@ class PythonicToolParser(ToolParser):
|
||||
if not isinstance(parsed, ast.List) or not all(
|
||||
isinstance(e, ast.Call) for e in parsed.elts
|
||||
):
|
||||
raise _UnexpectedAstError(
|
||||
"Tool output must be a list of function calls"
|
||||
)
|
||||
raise UnexpectedAstError("Tool output must be a list of function calls")
|
||||
tool_calls = [
|
||||
_handle_single_tool(e) # type: ignore
|
||||
handle_single_tool(e) # type: ignore
|
||||
for e in parsed.elts
|
||||
]
|
||||
|
||||
@@ -169,7 +161,7 @@ class PythonicToolParser(ToolParser):
|
||||
# Strings get single quotes in the model-produced string.
|
||||
# JSON requires double quotes.
|
||||
withheld_suffix = withheld_suffix.replace("'", '"')
|
||||
delta = _compute_tool_delta(
|
||||
delta = compute_tool_delta(
|
||||
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
|
||||
)
|
||||
|
||||
@@ -203,132 +195,3 @@ class PythonicToolParser(ToolParser):
|
||||
"Skipping chunk as a result of tool streaming extraction error"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _get_parameter_value(val: ast.expr) -> Any:
    """Convert the AST node of a tool-call argument into a plain Python value.

    Supported nodes are constants, signed numeric constants (``-1``,
    ``+2.5``), dict and list displays (converted recursively), and the JSON
    spellings ``null``, ``true`` and ``false``.

    Args:
        val: Expression node taken from a parsed tool call.

    Returns:
        The Python value the node denotes.

    Raises:
        _UnexpectedAstError: If a dict key is not a constant or the node is
            not one of the supported literal forms.
    """
    if isinstance(val, ast.Constant):
        return val.value

    if isinstance(val, ast.UnaryOp) and isinstance(val.op, (ast.USub, ast.UAdd)):
        # Python parses "-1" as UnaryOp(USub, Constant(1)) rather than as a
        # negative Constant, so signed numbers have to be folded by hand.
        operand = val.operand
        if (
            isinstance(operand, ast.Constant)
            and isinstance(operand.value, (int, float))
            and not isinstance(operand.value, bool)
        ):
            return -operand.value if isinstance(val.op, ast.USub) else operand.value

    elif isinstance(val, ast.Dict):
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            k.value: _get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }

    elif isinstance(val, ast.List):
        return [_get_parameter_value(v) for v in val.elts]

    elif isinstance(val, ast.Name):
        # Bare JSON names are accepted in place of None / True / False.
        json_names = {"null": None, "true": True, "false": False}
        if val.id in json_names:
            return json_names[val.id]

    raise _UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def _handle_single_tool(call: ast.Call) -> ToolCall:
    """Build a ``ToolCall`` from one parsed function-call node.

    Only keyword arguments are read; positional arguments in ``call.args``
    are not inspected by this function.

    Args:
        call: The ``ast.Call`` node for a single tool invocation.

    Returns:
        A function-type ``ToolCall`` whose arguments are the JSON encoding
        (non-ASCII characters left unescaped) of the keyword arguments.

    Raises:
        _UnexpectedAstError: If the callee is not a plain name, or a keyword
            is a ``**mapping`` unpacking that carries no argument name.
    """
    if not isinstance(call.func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")

    arguments = {}
    for keyword in call.keywords:
        if keyword.arg is None:
            # ``**mapping`` has no name; storing it under None would be
            # serialized by json.dumps as the bogus key "null".
            raise _UnexpectedAstError("Tool call arguments must be named")
        arguments[keyword.arg] = _get_parameter_value(keyword.value)

    return ToolCall(
        type="function",
        function=FunctionCall(
            name=call.func.id,
            arguments=json.dumps(arguments, ensure_ascii=False),
        ),
    )
|
||||
|
||||
|
||||
def _make_valid_python(text: str) -> tuple[str, str] | None:
|
||||
bracket_stack = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise _UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise _UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise _UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
# Treat an escaped quote as a regular character
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
# Double quote within a single quote string or vice versa.
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
# Since we have no type information for this property/parameter value,
|
||||
# we can't fill in a valid value.
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[: text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None # Incomplete property name within parameter value
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[: text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None # Incomplete parameter name
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if (
|
||||
bracket_stack
|
||||
and bracket_stack[-1] == "["
|
||||
and not text.endswith("[")
|
||||
and not text.endswith(")")
|
||||
):
|
||||
return None # Incomplete function name
|
||||
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
if char == "[":
|
||||
added_text += "]"
|
||||
elif char == "(":
|
||||
added_text += ")"
|
||||
elif char == "{":
|
||||
added_text += "}"
|
||||
elif char == "'":
|
||||
added_text += "'"
|
||||
elif char == '"':
|
||||
added_text += '"'
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def _compute_tool_delta(
    previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
    """Return the part of a tool call's arguments that has not been sent yet.

    Args:
        previously_sent_args: Argument text already streamed for this call.
        new_call: The tool call as currently parsed.
        index: Position of the tool call, copied into the delta.
        withheld_suffix: Trailing argument text that must not be sent yet;
            it is cut off before the delta is computed.

    Returns:
        A delta carrying the call id, name and arguments when nothing was
        sent before; a delta carrying only the new argument characters
        otherwise; or ``None`` when there is nothing new to send.

    Raises:
        AssertionError: If the arguments do not end with ``withheld_suffix``.
    """
    new_call_args = new_call.function.arguments
    if withheld_suffix:
        # Raised explicitly instead of via an ``assert`` statement so the
        # check still runs when Python is started with -O.
        if not new_call_args.endswith(withheld_suffix):
            raise AssertionError(
                "Tool call arguments do not end with the withheld suffix"
            )
        new_call_args = new_call_args[: -len(withheld_suffix)]

    if not previously_sent_args:
        return DeltaToolCall(
            id=new_call.id,
            type="function",
            index=index,
            function=DeltaFunctionCall(
                name=new_call.function.name,
                arguments=new_call_args,
            ),
        )

    arg_diff = new_call_args[len(previously_sent_args) :]
    if not arg_diff:
        return None
    return DeltaToolCall(
        id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
    )
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any
|
||||
@@ -17,6 +18,15 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaToolCall,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def find_common_prefix(s1: str, s2: str) -> str:
|
||||
@@ -212,3 +222,202 @@ def get_json_schema_from_tools(
|
||||
return _get_json_schema_from_tools(tools)
|
||||
# tool_choice: "auto"
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared utilities for pythonic-style tool call parsers
|
||||
# (PythonicToolParser, Llama4PythonicToolParser, Olmo3PythonicToolParser)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class UnexpectedAstError(Exception):
    """Signal that a parsed AST is not a valid pythonic tool call.

    Raised by the shared pythonic tool-call helpers when the structure they
    are given does not match the expected format.
    """
|
||||
|
||||
|
||||
# JSON spellings that models sometimes emit in place of None / True / False.
_JSON_NAME_LITERALS = {
    "null": None,
    "true": True,
    "false": False,
}


def get_parameter_value(val: ast.expr) -> Any:
    """Extract a Python literal value from an AST expression node.

    Handles constants, signed numeric constants (``-1``, ``+2.5``), dicts,
    lists, and JSON-style name literals (null, true, false) that some models
    produce instead of Python literals (None, True, False).

    Args:
        val: Expression node taken from a parsed tool call.

    Returns:
        The Python value the node denotes; dicts and lists are converted
        recursively.

    Raises:
        UnexpectedAstError: If the AST node is not a supported literal type,
            or a dict has a key that is not a constant.
    """
    if isinstance(val, ast.Constant):
        return val.value

    if isinstance(val, ast.UnaryOp) and isinstance(val.op, (ast.USub, ast.UAdd)):
        # Python parses "-1" as UnaryOp(USub, Constant(1)) rather than as a
        # negative Constant, so signed numbers have to be folded by hand.
        operand = val.operand
        if (
            isinstance(operand, ast.Constant)
            and isinstance(operand.value, (int, float))
            and not isinstance(operand.value, bool)
        ):
            return -operand.value if isinstance(val.op, ast.USub) else operand.value

    elif isinstance(val, ast.Dict):
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            logger.warning(
                "Dict argument keys are not all literals: %s",
                ast.dump(val),
            )
            raise UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            k.value: get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }

    elif isinstance(val, ast.List):
        return [get_parameter_value(v) for v in val.elts]

    elif isinstance(val, ast.Name) and val.id in _JSON_NAME_LITERALS:
        return _JSON_NAME_LITERALS[val.id]

    # Anything that reaches this point is not a supported literal form.
    logger.warning(
        "Unsupported AST node type in tool call arguments: %s",
        ast.dump(val),
    )
    raise UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def handle_single_tool(call: ast.Call) -> ToolCall:
    """Convert a single AST function call node into a ToolCall object.

    Only keyword arguments are read; positional arguments in ``call.args``
    are not inspected by this function.

    Args:
        call: The ``ast.Call`` node for a single tool invocation.

    Returns:
        A function-type ``ToolCall`` whose arguments are the JSON encoding
        (non-ASCII characters left unescaped) of the keyword arguments.

    Raises:
        UnexpectedAstError: If the call node does not have a simple
            function name (e.g. it's an attribute access or subscript), or
            a keyword is a ``**mapping`` unpacking with no argument name.
    """
    if not isinstance(call.func, ast.Name):
        logger.warning(
            "Tool call has non-simple function name: %s",
            ast.dump(call.func),
        )
        raise UnexpectedAstError("Invalid tool call name")

    arguments = {}
    for keyword in call.keywords:
        if keyword.arg is None:
            # ``**mapping`` has no name; storing it under None would be
            # serialized by json.dumps as the bogus key "null".
            logger.warning(
                "Tool call uses ** unpacking instead of named arguments: %s",
                ast.dump(keyword),
            )
            raise UnexpectedAstError("Tool call arguments must be named")
        arguments[keyword.arg] = get_parameter_value(keyword.value)

    return ToolCall(
        type="function",
        function=FunctionCall(
            name=call.func.id,
            arguments=json.dumps(arguments, ensure_ascii=False),
        ),
    )
|
||||
|
||||
|
||||
def make_valid_python(text: str) -> tuple[str, str] | None:
|
||||
"""Attempt to close all open brackets/quotes to make partial Python valid.
|
||||
|
||||
Used during streaming to parse incomplete tool call expressions by
|
||||
appending the necessary closing characters.
|
||||
|
||||
Returns:
|
||||
A tuple of (completed_text, added_suffix) if the text can be
|
||||
made valid, or None if the text is too incomplete to complete
|
||||
meaningfully (e.g. mid-parameter-name or mid-dict-key).
|
||||
|
||||
Raises:
|
||||
UnexpectedAstError: If mismatched brackets or parentheses
|
||||
are detected.
|
||||
"""
|
||||
bracket_stack: list[str] = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[: text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[: text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if (
|
||||
bracket_stack
|
||||
and bracket_stack[-1] == "["
|
||||
and not text.endswith("[")
|
||||
and not text.endswith(")")
|
||||
):
|
||||
return None
|
||||
|
||||
_CLOSING = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'}
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
added_text += _CLOSING[char]
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def compute_tool_delta(
    previously_sent_args: str,
    new_call: ToolCall,
    index: int,
    withheld_suffix: str,
) -> DeltaToolCall | None:
    """Work out which argument characters still have to be streamed.

    Args:
        previously_sent_args: Argument text already streamed for this call.
        new_call: The tool call as currently parsed.
        index: Position of the tool call, copied into the delta.
        withheld_suffix: Trailing argument text that must not be sent yet;
            it is cut off before the delta is computed.

    Returns:
        A delta with the call id, name and arguments when nothing was sent
        before; a delta with only the new argument characters otherwise; or
        None when there is no difference from what was previously sent.

    Raises:
        ValueError: If the arguments do not end with ``withheld_suffix``.
    """
    new_call_args = new_call.function.arguments

    if withheld_suffix:
        if not new_call_args.endswith(withheld_suffix):
            msg = (
                f"Tool call arguments '{new_call_args}' do not end with "
                f"expected withheld suffix '{withheld_suffix}'"
            )
            logger.error(msg)
            raise ValueError(msg)
        new_call_args = new_call_args[: -len(withheld_suffix)]

    if previously_sent_args:
        # Follow-up chunk: only the characters past what was already sent.
        unsent = new_call_args[len(previously_sent_args) :]
        if not unsent:
            return None
        return DeltaToolCall(
            id=None,
            index=index,
            function=DeltaFunctionCall(arguments=unsent),
        )

    # First chunk for this call: include the id and function name as well.
    first_function = DeltaFunctionCall(
        name=new_call.function.name,
        arguments=new_call_args,
    )
    return DeltaToolCall(
        id=new_call.id,
        type="function",
        index=index,
        function=first_function,
    )
|
||||
|
||||
Reference in New Issue
Block a user