[Misc] Refactored 5 duplicate helper functions that were copied-pasted across multiple parsers (#36436)

Signed-off-by: Taneem Ibrahim <taneem.ibrahim@gmail.com>
This commit is contained in:
Taneem Ibrahim
2026-03-09 14:14:11 -04:00
committed by GitHub
parent 4b87ffbefb
commit 8d6b3d5dda
4 changed files with 247 additions and 452 deletions

View File

@@ -1,9 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from transformers import PreTrainedTokenizerBase
@@ -13,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.tool_parsers.utils import (
UnexpectedAstError,
compute_tool_delta,
handle_single_tool,
make_valid_python,
)
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
class Llama4PythonicToolParser(ToolParser):
"""
Toolcall parser for Llama4 that produce tool calls in a pythonic style
@@ -103,15 +100,13 @@ class Llama4PythonicToolParser(ToolParser):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None,
)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
raise UnexpectedAstError("Tool output must be a list of function calls")
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
@@ -140,7 +135,7 @@ class Llama4PythonicToolParser(ToolParser):
current_text = current_text[len("<|python_start|>") :]
if current_text.endswith("<|python_end|>"):
current_text = current_text[: current_text.rfind("<|python_end|>")]
valid_and_added_text = _make_valid_python(current_text)
valid_and_added_text = make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
@@ -150,11 +145,9 @@ class Llama4PythonicToolParser(ToolParser):
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts
):
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
raise UnexpectedAstError("Tool output must be a list of function calls")
tool_calls = [
_handle_single_tool(e) # type: ignore
handle_single_tool(e) # type: ignore
for e in parsed.elts
]
@@ -180,7 +173,7 @@ class Llama4PythonicToolParser(ToolParser):
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(
delta = compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
)
@@ -214,130 +207,3 @@ class Llama4PythonicToolParser(ToolParser):
"Skipping chunk as a result of tool streaming extraction error"
)
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(name=function_name, arguments=json.dumps(arguments)),
)
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)

View File

@@ -1,9 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from transformers import PreTrainedTokenizerBase
@@ -13,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.tool_parsers.utils import (
UnexpectedAstError,
compute_tool_delta,
handle_single_tool,
make_valid_python,
)
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
class Olmo3PythonicToolParser(ToolParser):
"""
Tool call parser for Olmo 3 models that produce tool calls as
@@ -113,15 +110,13 @@ class Olmo3PythonicToolParser(ToolParser):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None,
)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
raise UnexpectedAstError("Tool output must be a list of function calls")
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
@@ -151,7 +146,7 @@ class Olmo3PythonicToolParser(ToolParser):
if current_text.endswith("</function_calls>"):
current_text = current_text[: -len("</function_calls>")]
valid_and_added_text = _make_valid_python(current_text)
valid_and_added_text = make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
@@ -166,11 +161,11 @@ class Olmo3PythonicToolParser(ToolParser):
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts
):
raise _UnexpectedAstError(
raise UnexpectedAstError(
"Tool output must be a sequence of newline-separated calls"
)
tool_calls = [
_handle_single_tool(e) # type: ignore
handle_single_tool(e) # type: ignore
for e in parsed.elts
]
@@ -194,7 +189,7 @@ class Olmo3PythonicToolParser(ToolParser):
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(
delta = compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
)
@@ -228,141 +223,3 @@ class Olmo3PythonicToolParser(ToolParser):
"Skipping chunk as a result of tool streaming extraction error"
)
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
# The model may return function calls where the values are null/true/false
# because the system prompt has API description in json.
elif isinstance(val, ast.Name) and val.id in ["null", "true", "false"]:
if val.id == "null":
return None
elif val.id == "true":
return True
elif val.id == "false":
return False
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
),
)
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)

View File

@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from transformers import PreTrainedTokenizerBase
@@ -14,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.tool_parsers.utils import (
UnexpectedAstError,
compute_tool_delta,
handle_single_tool,
make_valid_python,
)
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
class PythonicToolParser(ToolParser):
"""
Tool call parser for models that produce tool calls in a pythonic style,
@@ -99,15 +95,13 @@ class PythonicToolParser(ToolParser):
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None,
)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
raise UnexpectedAstError("Tool output must be a list of function calls")
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
@@ -129,7 +123,7 @@ class PythonicToolParser(ToolParser):
return DeltaMessage(content=delta_text)
try:
valid_and_added_text = _make_valid_python(current_text)
valid_and_added_text = make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
@@ -139,11 +133,9 @@ class PythonicToolParser(ToolParser):
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts
):
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
raise UnexpectedAstError("Tool output must be a list of function calls")
tool_calls = [
_handle_single_tool(e) # type: ignore
handle_single_tool(e) # type: ignore
for e in parsed.elts
]
@@ -169,7 +161,7 @@ class PythonicToolParser(ToolParser):
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(
delta = compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
)
@@ -203,132 +195,3 @@ class PythonicToolParser(ToolParser):
"Skipping chunk as a result of tool streaming extraction error"
)
return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
),
)
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from json import JSONDecodeError, JSONDecoder
from typing import Any
@@ -17,6 +18,15 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaToolCall,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
def find_common_prefix(s1: str, s2: str) -> str:
@@ -212,3 +222,202 @@ def get_json_schema_from_tools(
return _get_json_schema_from_tools(tools)
# tool_choice: "auto"
return None
# ---------------------------------------------------------------------------
# Shared utilities for pythonic-style tool call parsers
# (PythonicToolParser, Llama4PythonicToolParser, Olmo3PythonicToolParser)
# ---------------------------------------------------------------------------
class UnexpectedAstError(Exception):
"""Raised when the AST structure does not match the expected
pythonic tool call format."""
pass
_JSON_NAME_LITERALS = {
"null": None,
"true": True,
"false": False,
}
def get_parameter_value(val: ast.expr) -> Any:
"""Extract a Python literal value from an AST expression node.
Handles constants, dicts, lists, and JSON-style name literals
(null, true, false) that some models produce instead of Python
literals (None, True, False).
Raises:
UnexpectedAstError: If the AST node is not a supported literal type.
"""
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
logger.warning(
"Dict argument keys are not all literals: %s",
ast.dump(val),
)
raise UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [get_parameter_value(v) for v in val.elts]
elif isinstance(val, ast.Name) and val.id in _JSON_NAME_LITERALS:
return _JSON_NAME_LITERALS[val.id]
else:
logger.warning(
"Unsupported AST node type in tool call arguments: %s",
ast.dump(val),
)
raise UnexpectedAstError("Tool call arguments must be literals")
def handle_single_tool(call: ast.Call) -> ToolCall:
"""Convert a single AST function call node into a ToolCall object.
Raises:
UnexpectedAstError: If the call node does not have a simple
function name (e.g. it's an attribute access or subscript).
"""
if not isinstance(call.func, ast.Name):
logger.warning(
"Tool call has non-simple function name: %s",
ast.dump(call.func),
)
raise UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name,
arguments=json.dumps(arguments, ensure_ascii=False),
),
)
def make_valid_python(text: str) -> tuple[str, str] | None:
"""Attempt to close all open brackets/quotes to make partial Python valid.
Used during streaming to parse incomplete tool call expressions by
appending the necessary closing characters.
Returns:
A tuple of (completed_text, added_suffix) if the text can be
made valid, or None if the text is too incomplete to complete
meaningfully (e.g. mid-parameter-name or mid-dict-key).
Raises:
UnexpectedAstError: If mismatched brackets or parentheses
are detected.
"""
bracket_stack: list[str] = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None
_CLOSING = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'}
added_text = ""
for char in reversed(bracket_stack):
added_text += _CLOSING[char]
return text + added_text, added_text
def compute_tool_delta(
previously_sent_args: str,
new_call: ToolCall,
index: int,
withheld_suffix: str,
) -> DeltaToolCall | None:
"""Compute the incremental delta between previously streamed arguments
and the current tool call state.
Returns:
A DeltaToolCall with only the new argument characters, or None
if there is no difference from what was previously sent.
"""
new_call_args = new_call.function.arguments
if withheld_suffix:
if not new_call_args.endswith(withheld_suffix):
msg = (
f"Tool call arguments '{new_call_args}' do not end with "
f"expected withheld suffix '{withheld_suffix}'"
)
logger.error(msg)
raise ValueError(msg)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None,
index=index,
function=DeltaFunctionCall(arguments=arg_diff),
)
if arg_diff
else None
)