commit bf66b8708c76a81727a8a17137e90f9762234704 Author: biondizzle Date: Wed Apr 8 18:23:12 2026 +0000 GLM-5.1 tool parser with incremental streaming support diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..b403eee --- /dev/null +++ b/Dockerfile @@ -0,0 +1,5 @@ +ARG BASE_IMAGE=vllm/vllm-openai:glm51-cu130 +FROM ${BASE_IMAGE} + +COPY glm4_moe_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/glm4_moe_tool_parser.py +COPY utils.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/utils.py diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 0000000..1759c06 --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,62 @@ +pipeline { + agent any + + environment { + REGISTRY = 'atl.vultrcr.com/vllm' + IMAGE_NAME = 'vllm-glm51-patched' + BASE_IMAGE = 'vllm/vllm-openai:glm51-cu130' + } + + parameters { + string(name: 'IMAGE_TAG', defaultValue: 'latest', description: 'Docker image tag') + string(name: 'GIT_REPO', defaultValue: '', description: 'Git repository URL (optional, uses workspace if empty)') + string(name: 'GIT_BRANCH', defaultValue: 'master', description: 'Git branch to build') + } + + stages { + stage('Checkout') { + steps { + script { + if (params.GIT_REPO) { + git url: params.GIT_REPO, branch: params.GIT_BRANCH + } + // Otherwise use workspace already checked out + } + } + } + + stage('Build') { + steps { + script { + docker.withRegistry("https://${REGISTRY}", 'ATL_VCR_VLLM') { + sh """ + docker build \ + --build-arg BASE_IMAGE=${BASE_IMAGE} \ + -t ${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG} \ + . 
+ """ + } + } + } + } + + stage('Push') { + steps { + script { + docker.withRegistry("https://${REGISTRY}", 'ATL_VCR_VLLM') { + docker.image("${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG}").push() + } + } + } + } + } + + post { + success { + echo "✅ Image pushed: ${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG}" + } + failure { + echo "❌ Build failed" + } + } +} diff --git a/README.md b/README.md new file mode 100644 index 0000000..6ceff9a --- /dev/null +++ b/README.md @@ -0,0 +1,67 @@ +# vLLM GLM Tool Parser Patch + +## Purpose + +Patches vLLM's GLM-4/GLM-5.1 tool parser to fix a streaming issue where long string parameters are buffered entirely before being emitted, causing multi-second delays. + +## The Problem + +GLM models emit tool calls in a special XML-like format: + +``` +.tool_name +param_nameparam_value +``` + +The upstream parser (as of vLLM issue #32829) buffers string values until the closing tag arrives. For long strings (e.g., 4000+ characters of code), users see nothing until the entire value is complete — not true streaming. 
+ +## The Fix + +`glm4_moe_tool_parser.py` implements incremental string streaming: + +- Re-parses `` regions on each streaming call +- Diffs against previously sent content +- Emits only new characters as they arrive +- String values now stream character-by-character + +## Files + +| File | Description | +|------|-------------| +| `glm4_moe_tool_parser.py` | Fixed tool parser with incremental streaming | +| `utils.py` | Utility functions for partial JSON/tag handling | +| `Dockerfile` | Overlays patched files onto base image | +| `Jenkinsfile` | CI/CD pipeline for building and pushing | + +## Deployment + +### Jenkins Pipeline + +Build via Jenkins: + +```bash +curl -X POST "https://jenkins.sweetapi.com/job/vllm-glm-build/buildWithParameters" \ + -u "admin:TOKEN" \ + -d "IMAGE_TAG=latest" +``` + +Parameters: +- `IMAGE_TAG` - Docker image tag (default: `latest`) +- `GIT_REPO` - Git repository URL (optional, uses workspace if empty) +- `GIT_BRANCH` - Git branch to build (default: `master`) + +### Manual Build + +```bash +docker build -t atl.vultrcr.com/vllm/vllm-glm51-patched:latest . +docker push atl.vultrcr.com/vllm/vllm-glm51-patched:latest +``` + +### Images + +- Base: `vllm/vllm-openai:glm51-cu130` +- Output: `atl.vultrcr.com/vllm/vllm-glm51-patched:` + +## Related + +- vLLM Issue #32829 (streaming long string parameters) diff --git a/glm4_moe_tool_parser.py b/glm4_moe_tool_parser.py new file mode 100644 index 0000000..491c959 --- /dev/null +++ b/glm4_moe_tool_parser.py @@ -0,0 +1,483 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GLM-4 Tool Call Parser with incremental string streaming support. + +This parser fixes the streaming issue reported in Issue #32829 where long string +parameters (e.g., file content with 4000+ characters of code) are buffered until +complete, causing multi-second delays before the user sees any content. 
+ +The fix streams string values incrementally as they arrive, providing a true +streaming experience for long content. +""" + +import ast +import json +from collections.abc import Sequence +from typing import Any + +import regex as re + +from vllm.entrypoints.chat_utils import make_tool_call_id +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, +) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.entrypoints.openai.responses.protocol import ResponsesRequest +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + Tool, + ToolParser, +) +from vllm.tool_parsers.utils import partial_tag_overlap + +logger = init_logger(__name__) + + +class Glm4MoeModelToolParser(ToolParser): + """Tool parser for GLM-4 models with incremental string streaming. + + On every streaming call the parser re-parses ``current_text`` to find + ```` regions, builds the JSON arguments string for each tool + call, and diffs against what was previously sent to emit only new content. 
+ """ + + def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): + super().__init__(tokenizer, tools) + # Stateful streaming fields + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict[str, Any]] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = [] + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + self.arg_key_start: str = "" + self.arg_key_end: str = "" + self.arg_val_start: str = "" + self.arg_val_end: str = "" + + self.tool_calls_start_token = self.tool_call_start_token + + self.func_call_regex = re.compile(r".*?", re.DOTALL) + self.func_detail_regex = re.compile( + r"([^\n]*)\n(.*)", re.DOTALL + ) + self.func_arg_regex = re.compile( + r"(.*?)\s*(.*?)", re.DOTALL + ) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction." + ) + + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + # Pre-compiled pattern for finding the last ... + # before a partial (used in _build_args_json_so_far). + self._arg_key_pattern = re.compile( + re.escape(self.arg_key_start) + r"(.*?)" + re.escape(self.arg_key_end), + re.DOTALL, + ) + + # Streaming state for re-parse-and-diff approach + self._sent_content_idx: int = 0 + self._tool_call_ids: list[str] = [] + + @staticmethod + def _deserialize(value: str) -> Any: + try: + return json.loads(value) + except json.JSONDecodeError: + pass + + try: + return ast.literal_eval(value) + except (ValueError, SyntaxError): + pass + + return value + + @staticmethod + def _json_escape_string_content(s: str) -> str: + """JSON-escape string content for incremental streaming. + + This escapes the content that goes INSIDE a JSON string (between quotes), + not including the surrounding quotes themselves. 
+ """ + if not s: + return "" + return json.dumps(s, ensure_ascii=False)[1:-1] + + @staticmethod + def _is_string_type( + tool_name: str, + arg_name: str, + tools: list[Tool] | None, + ) -> bool: + if tools is None: + return False + for tool in tools: + if tool.function.name != tool_name: + continue + if tool.function.parameters is None: + return False + arg_type = ( + tool.function.parameters.get("properties", {}) + .get(arg_name, {}) + .get("type", None) + ) + return arg_type == "string" + logger.debug("No tool named '%s'.", tool_name) + return False + + @staticmethod + def _tools_enabled(request: ChatCompletionRequest) -> bool: + """Return whether tool parsing should be applied for this request.""" + try: + tools = getattr(request, "tools", None) + tool_choice = getattr(request, "tool_choice", None) + return bool(tools) and tool_choice != "none" + except Exception: + logger.exception("Failed to determine if tools are enabled.") + return False + + def adjust_request( + self, request: ChatCompletionRequest | ResponsesRequest + ) -> ChatCompletionRequest | ResponsesRequest: + """Adjust request parameters for tool call token handling.""" + request = super().adjust_request(request) + if request.tools and request.tool_choice != "none": + # Ensure tool call tokens (, ) are not skipped + # during decoding. Even though they are not marked as special tokens, + # setting skip_special_tokens=False ensures proper handling in + # transformers 5.x where decoding behavior may have changed. 
+ request.skip_special_tokens = False + return request + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + matched_tool_calls = self.func_call_regex.findall(model_output) + logger.debug("model_output: %s", model_output) + try: + tool_calls: list[ToolCall] = [] + for match in matched_tool_calls: + tc_detail = self.func_detail_regex.search(match) + if not tc_detail: + logger.warning( + "Failed to parse tool call details from: %s", + match, + ) + continue + tc_name = tc_detail.group(1).strip() + tc_args = tc_detail.group(2) + pairs = self.func_arg_regex.findall(tc_args) if tc_args else [] + arg_dct: dict[str, Any] = {} + for key, value in pairs: + arg_key = key.strip() + arg_val = value.strip() + if not self._is_string_type(tc_name, arg_key, self.tools): + arg_val = self._deserialize(arg_val) + logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val) + arg_dct[arg_key] = arg_val + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall( + name=tc_name, + arguments=json.dumps(arg_dct, ensure_ascii=False), + ), + ) + ) + except Exception: + logger.exception("Failed to extract tool call spec") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + else: + if len(tool_calls) > 0: + content: str | None = model_output[ + : model_output.find(self.tool_calls_start_token) + ] + # Normalize empty/whitespace-only content to None + if not content or not content.strip(): + content = None + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=content + ) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def _extract_content(self, current_text: str) -> str | None: + """Return unsent non-tool-call text, or None. + + Collects all text outside ``...`` regions, + including text between consecutive tool calls. 
Holds back any
+        suffix that could be a partial ``<tool_call>`` tag.
results.append((raw, False)) + break + return results + + def _extract_tool_name_from_region(self, inner_text: str) -> str | None: + """Extract the tool name from the beginning of a tool-call region. + + The name is everything before the first ``\\n`` or ````. + Returns ``None`` if the name hasn't fully arrived yet. + """ + nl = inner_text.find("\n") + ak = inner_text.find(self.arg_key_start) + candidates = [i for i in [nl, ak] if i != -1] + if not candidates: + return None + cut = min(candidates) + name = inner_text[:cut].strip() + return name if name else None + + def _build_args_json_so_far( + self, + tool_name: str, + inner_text: str, + is_complete: bool, + ) -> str: + """Build the JSON arguments string from the XML pairs seen so far. + + For complete ``/`` pairs the value is fully + formatted. For the last argument whose ```` has been + opened but not closed, the partial string content is included + (JSON-escaped, with an opening ``"`` but no closing ``"``). + + The closing ``}`` is only appended when ``is_complete`` is True + (i.e. the ```` tag has arrived). + """ + # Find all complete arg pairs + pairs = self.func_arg_regex.findall(inner_text) + + parts: list[str] = [] + for key, value in pairs: + key = key.strip() + key_json = json.dumps(key, ensure_ascii=False) + if self._is_string_type(tool_name, key, self.tools): + # Don't strip string values — whitespace is significant + # and must match the partial-value path for diffing. 
+ val_json = json.dumps(value, ensure_ascii=False) + else: + val_json = json.dumps( + self._deserialize(value.strip()), ensure_ascii=False + ) + parts.append(f"{key_json}: {val_json}") + + # Check for a partial (incomplete) arg value + # Find the last that isn't closed + last_val_start = inner_text.rfind(self.arg_val_start) + last_val_end = inner_text.rfind(self.arg_val_end) + has_partial_value = last_val_start != -1 and ( + last_val_end == -1 or last_val_end < last_val_start + ) + + if has_partial_value: + # Find the key for this partial value + # Look for the last ... before this + last_key_match = None + for m in self._arg_key_pattern.finditer(inner_text[:last_val_start]): + last_key_match = m + + if last_key_match: + partial_key = last_key_match.group(1).strip() + partial_content_start = last_val_start + len(self.arg_val_start) + partial_content = inner_text[partial_content_start:] + + # Hold back any partial suffix + overlap = partial_tag_overlap(partial_content, self.arg_val_end) + if overlap: + partial_content = partial_content[:-overlap] + + key_json = json.dumps(partial_key, ensure_ascii=False) + if is_complete: + # Tool call finished but is missing + # (malformed output). Treat partial as complete value + # so the diff naturally closes any open quotes. 
+ if self._is_string_type(tool_name, partial_key, self.tools): + val_json = json.dumps(partial_content, ensure_ascii=False) + else: + val_json = json.dumps( + self._deserialize(partial_content.strip()), + ensure_ascii=False, + ) + parts.append(f"{key_json}: {val_json}") + elif self._is_string_type(tool_name, partial_key, self.tools): + escaped = self._json_escape_string_content(partial_content) + # Open quote but no close — more content may arrive + parts.append(f'{key_json}: "{escaped}') + else: + # Non-string partial: include raw content, no wrapping + parts.append(f"{key_json}: {partial_content}") + + if not parts: + return "{}" if is_complete else "" + + joined = "{" + ", ".join(parts) + if is_complete: + joined += "}" + return joined + + def _compute_args_diff(self, index: int, args_so_far: str) -> str | None: + """Return new argument text not yet sent for tool *index*, or None.""" + if not args_so_far or len(args_so_far) <= len( + self.streamed_args_for_tool[index] + ): + return None + diff = args_so_far[len(self.streamed_args_for_tool[index]) :] + self.streamed_args_for_tool[index] = args_so_far + self.prev_tool_call_arr[index]["arguments"] = args_so_far + return diff + + def _ensure_tool_state_for(self, index: int) -> None: + """Grow state arrays so that *index* is valid.""" + while len(self._tool_call_ids) <= index: + self._tool_call_ids.append( + make_tool_call_id(id_type="random", func_name=None, idx=None) + ) + while len(self.streamed_args_for_tool) <= index: + self.streamed_args_for_tool.append("") + while len(self.prev_tool_call_arr) <= index: + self.prev_tool_call_arr.append({}) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + if not self._tools_enabled(request): + return DeltaMessage(content=delta_text) if delta_text else 
None + + content = self._extract_content(current_text) + regions = self._extract_tool_call_regions(current_text) + tool_call_deltas: list[DeltaToolCall] = [] + + for i, (inner_text, is_complete) in enumerate(regions): + self._ensure_tool_state_for(i) + + # Extract tool name + tool_name = self._extract_tool_name_from_region(inner_text) + if not tool_name: + break + + # Emit tool name (once per tool call) + if "name" not in self.prev_tool_call_arr[i]: + self.prev_tool_call_arr[i]["name"] = tool_name + tool_call_deltas.append( + DeltaToolCall( + index=i, + id=self._tool_call_ids[i], + type="function", + function=DeltaFunctionCall( + name=tool_name, + arguments="", + ).model_dump(exclude_none=True), + ) + ) + + # Build args JSON so far, diff, emit + args_so_far = self._build_args_json_so_far( + tool_name, inner_text, is_complete + ) + diff = self._compute_args_diff(i, args_so_far) + if diff: + tool_call_deltas.append( + DeltaToolCall( + index=i, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ) + + # Update current_tool_id for serving layer compatibility + if regions: + self.current_tool_id = len(regions) - 1 + + if content or tool_call_deltas: + return DeltaMessage( + content=content, + tool_calls=tool_call_deltas, + ) + return None diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..4711f05 --- /dev/null +++ b/utils.py @@ -0,0 +1,438 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +import json +from json import JSONDecodeError, JSONDecoder +from typing import Any, TypeAlias + +import partial_json_parser +from openai.types.responses import ( + FunctionTool, + ToolChoiceFunction, +) +from openai.types.responses.tool import Tool as ResponsesTool +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionNamedToolChoiceParam, + ChatCompletionToolsParam, +) +from 
vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, + DeltaToolCall, + FunctionCall, + ToolCall, +) +from vllm.logger import init_logger + +Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool + +logger = init_logger(__name__) + + +def partial_tag_overlap(text: str, tag: str) -> int: + """Length of the longest prefix of *tag* that matches a suffix of *text*. + + E.g. text ending in ``""``. + Returns 0 when there is no overlap. + """ + max_check = min(len(tag) - 1, len(text)) + for k in range(max_check, 0, -1): + if text.endswith(tag[:k]): + return k + return 0 + + +def find_common_prefix(s1: str, s2: str) -> str: + """ + Finds a common prefix that is shared between two strings, if there is one. + Order of arguments is NOT important. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. + + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> + '{"fruit": "ap' + """ + prefix = "" + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def find_common_suffix(s1: str, s2: str) -> str: + """ + Finds a common suffix shared between two strings, if there is one. Order of + arguments is NOT important. + Stops when the suffix ends OR it hits an alphanumeric character + + e.g. 
find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' + """ + suffix = "" + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i] and not s1[-i].isalnum(): + suffix = s1[-i] + suffix + else: + break + return suffix + + +def extract_intermediate_diff(curr: str, old: str) -> str: + """ + Given two strings, extract the difference in the middle between two strings + that are known to have a common prefix and/or suffix. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS + important - the new version of the partially-parsed JSON must be the first + argument, and the secnod argument must be from the previous generation. + + What it returns, is tokens that should be streamed to the client. + + e.g. 
extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + -> 'ple' + + """ + suffix = find_common_suffix(curr, old) + + old = old[::-1].replace(suffix[::-1], "", 1)[::-1] + prefix = find_common_prefix(curr, old) + diff = curr + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1] + + if len(prefix): + # replace the prefix only once in case it's mirrored + diff = diff.replace(prefix, "", 1) + + return diff + + +# partial_json_parser doesn't support extra data and +# JSONDecoder.raw_decode doesn't support partial JSON +def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]: + try: + return (partial_json_parser.loads(input_str, flags), len(input_str)) + except JSONDecodeError as e: + if "Extra data" in e.msg: + dec = JSONDecoder() + return dec.raw_decode(input_str) + raise + + +def is_complete_json(input_str: str) -> bool: + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False + + +def consume_space(i: int, s: str) -> int: + while i < len(s) and s[i].isspace(): + i += 1 + return i + + +def _extract_tool_info( + tool: Tool, +) -> tuple[str, dict[str, Any] | None]: + if isinstance(tool, FunctionTool): + return tool.name, tool.parameters + elif isinstance(tool, ChatCompletionToolsParam): + return tool.function.name, tool.function.parameters + else: + raise TypeError(f"Unsupported tool type: {type(tool)}") + + +def _get_tool_schema_from_tool(tool: Tool) -> dict: + name, params = _extract_tool_info(tool) + params = params if params else {"type": "object", "properties": {}} + return { + "properties": { + "name": {"type": "string", "enum": [name]}, + "parameters": params, + }, + "required": ["name", "parameters"], + } + + +def _get_tool_schema_defs( + tools: list[Tool], +) -> dict: + all_defs: dict[str, dict[str, Any]] = {} + for tool in tools: + _, params = _extract_tool_info(tool) + if params is None: + continue + defs = params.pop("$defs", {}) + for def_name, def_schema in defs.items(): + if 
def_name in all_defs and all_defs[def_name] != def_schema: + raise ValueError( + f"Tool definition '{def_name}' has multiple schemas, " + "which is not supported." + ) + all_defs[def_name] = def_schema + return all_defs + + +def _get_json_schema_from_tools( + tools: list[Tool], +) -> dict: + json_schema = { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "anyOf": [_get_tool_schema_from_tool(tool) for tool in tools], + }, + } + json_schema_defs = _get_tool_schema_defs(tools) + if json_schema_defs: + json_schema["$defs"] = json_schema_defs + return json_schema + + +def get_json_schema_from_tools( + tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam, + tools: list[Tool] | None, +) -> str | dict | None: + # tool_choice: "none" + if tool_choice in ("none", None) or tools is None: + return None + # tool_choice: Forced Function (Responses) + if (not isinstance(tool_choice, str)) and isinstance( + tool_choice, ToolChoiceFunction + ): + tool_name = tool_choice.name + tool_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)} + if tool_name not in tool_map: + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") + return tool_map[tool_name].parameters + # tool_choice: Forced Function (ChatCompletion) + if (not isinstance(tool_choice, str)) and isinstance( + tool_choice, ChatCompletionNamedToolChoiceParam + ): + tool_name = tool_choice.function.name + tool_map = { + tool.function.name: tool + for tool in tools + if isinstance(tool, ChatCompletionToolsParam) + } + if tool_name not in tool_map: + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") + return tool_map[tool_name].function.parameters + # tool_choice: "required" + if tool_choice == "required": + return _get_json_schema_from_tools(tools) + # tool_choice: "auto" + return None + + +# --------------------------------------------------------------------------- +# Shared utilities for pythonic-style tool call parsers 
+# (PythonicToolParser, Llama4PythonicToolParser, Olmo3PythonicToolParser) +# --------------------------------------------------------------------------- + + +class UnexpectedAstError(Exception): + """Raised when the AST structure does not match the expected + pythonic tool call format.""" + + pass + + +_JSON_NAME_LITERALS = { + "null": None, + "true": True, + "false": False, +} + + +def get_parameter_value(val: ast.expr) -> Any: + """Extract a Python literal value from an AST expression node. + + Handles constants, dicts, lists, and JSON-style name literals + (null, true, false) that some models produce instead of Python + literals (None, True, False). + + Raises: + UnexpectedAstError: If the AST node is not a supported literal type. + """ + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + logger.warning( + "Dict argument keys are not all literals: %s", + ast.dump(val), + ) + raise UnexpectedAstError("Dict tool call arguments must have literal keys") + return { + k.value: get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [get_parameter_value(v) for v in val.elts] + elif isinstance(val, ast.Name) and val.id in _JSON_NAME_LITERALS: + return _JSON_NAME_LITERALS[val.id] + else: + logger.warning( + "Unsupported AST node type in tool call arguments: %s", + ast.dump(val), + ) + raise UnexpectedAstError("Tool call arguments must be literals") + + +def handle_single_tool(call: ast.Call) -> ToolCall: + """Convert a single AST function call node into a ToolCall object. + + Raises: + UnexpectedAstError: If the call node does not have a simple + function name (e.g. it's an attribute access or subscript). 
+ """ + if not isinstance(call.func, ast.Name): + logger.warning( + "Tool call has non-simple function name: %s", + ast.dump(call.func), + ) + raise UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = get_parameter_value(keyword.value) + return ToolCall( + type="function", + function=FunctionCall( + name=function_name, + arguments=json.dumps(arguments, ensure_ascii=False), + ), + ) + + +def make_valid_python(text: str) -> tuple[str, str] | None: + """Attempt to close all open brackets/quotes to make partial Python valid. + + Used during streaming to parse incomplete tool call expressions by + appending the necessary closing characters. + + Returns: + A tuple of (completed_text, added_suffix) if the text can be + made valid, or None if the text is too incomplete to complete + meaningfully (e.g. mid-parameter-name or mid-dict-key). + + Raises: + UnexpectedAstError: If mismatched brackets or parentheses + are detected. 
+ """ + bracket_stack: list[str] = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise UnexpectedAstError("Mismatched parentheses") + elif char == "}": + if not bracket_stack or bracket_stack.pop() != "{": + raise UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + pass + else: + bracket_stack.append(char) + + text = text.rstrip() + if text.endswith("=") or text.endswith(":"): + return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[: text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[: text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None + if text.endswith(","): + text = text[:-1] + if ( + bracket_stack + and bracket_stack[-1] == "[" + and not text.endswith("[") + and not text.endswith(")") + ): + return None + + _CLOSING = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'} + added_text = "" + for char in reversed(bracket_stack): + added_text += _CLOSING[char] + + return text + added_text, added_text + + +def compute_tool_delta( + previously_sent_args: str, + new_call: ToolCall, + index: int, + withheld_suffix: str, +) -> DeltaToolCall | None: + """Compute the incremental delta between previously streamed arguments + and the current tool call state. 
+ + Returns: + A DeltaToolCall with only the new argument characters, or None + if there is no difference from what was previously sent. + """ + new_call_args = new_call.function.arguments + if withheld_suffix: + if not new_call_args.endswith(withheld_suffix): + msg = ( + f"Tool call arguments '{new_call_args}' do not end with " + f"expected withheld suffix '{withheld_suffix}'" + ) + logger.error(msg) + raise ValueError(msg) + new_call_args = new_call_args[: -len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall( + id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + ), + ) + + arg_diff = new_call_args[len(previously_sent_args) :] + return ( + DeltaToolCall( + id=None, + index=index, + function=DeltaFunctionCall(arguments=arg_diff), + ) + if arg_diff + else None + )