GLM-5.1 tool parser with incremental streaming support
This commit is contained in:
5
Dockerfile
Normal file
5
Dockerfile
Normal file
@@ -0,0 +1,5 @@
|
||||
ARG BASE_IMAGE=vllm/vllm-openai:glm51-cu130
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
COPY glm4_moe_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/glm4_moe_tool_parser.py
|
||||
COPY utils.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/utils.py
|
||||
62
Jenkinsfile
vendored
Normal file
62
Jenkinsfile
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
pipeline {
|
||||
agent any
|
||||
|
||||
environment {
|
||||
REGISTRY = 'atl.vultrcr.com/vllm'
|
||||
IMAGE_NAME = 'vllm-glm51-patched'
|
||||
BASE_IMAGE = 'vllm/vllm-openai:glm51-cu130'
|
||||
}
|
||||
|
||||
parameters {
|
||||
string(name: 'IMAGE_TAG', defaultValue: 'latest', description: 'Docker image tag')
|
||||
string(name: 'GIT_REPO', defaultValue: '', description: 'Git repository URL (optional, uses workspace if empty)')
|
||||
string(name: 'GIT_BRANCH', defaultValue: 'master', description: 'Git branch to build')
|
||||
}
|
||||
|
||||
stages {
|
||||
stage('Checkout') {
|
||||
steps {
|
||||
script {
|
||||
if (params.GIT_REPO) {
|
||||
git url: params.GIT_REPO, branch: params.GIT_BRANCH
|
||||
}
|
||||
// Otherwise use workspace already checked out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage('Build') {
|
||||
steps {
|
||||
script {
|
||||
docker.withRegistry("https://${REGISTRY}", 'ATL_VCR_VLLM') {
|
||||
sh """
|
||||
docker build \
|
||||
--build-arg BASE_IMAGE=${BASE_IMAGE} \
|
||||
-t ${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG} \
|
||||
.
|
||||
"""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage('Push') {
|
||||
steps {
|
||||
script {
|
||||
docker.withRegistry("https://${REGISTRY}", 'ATL_VCR_VLLM') {
|
||||
docker.image("${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG}").push()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
post {
|
||||
success {
|
||||
echo "✅ Image pushed: ${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG}"
|
||||
}
|
||||
failure {
|
||||
echo "❌ Build failed"
|
||||
}
|
||||
}
|
||||
}
|
||||
67
README.md
Normal file
67
README.md
Normal file
@@ -0,0 +1,67 @@
|
||||
# vLLM GLM Tool Parser Patch
|
||||
|
||||
## Purpose
|
||||
|
||||
Patches vLLM's GLM-4/GLM-5.1 tool parser to fix a streaming issue where long string parameters are buffered entirely before being emitted, causing multi-second delays.
|
||||
|
||||
## The Problem
|
||||
|
||||
GLM models emit tool calls in a special XML-like format:
|
||||
|
||||
```
|
||||
.tool_name
|
||||
param_nameparam_value
|
||||
```
|
||||
|
||||
The upstream parser (as of vLLM issue #32829) buffers string values until the closing tag arrives. For long strings (e.g., 4000+ characters of code), users see nothing until the entire value is complete — not true streaming.
|
||||
|
||||
## The Fix
|
||||
|
||||
`glm4_moe_tool_parser.py` implements incremental string streaming:
|
||||
|
||||
- Re-parses `` regions on each streaming call
|
||||
- Diffs against previously sent content
|
||||
- Emits only new characters as they arrive
|
||||
- String values now stream character-by-character
|
||||
|
||||
## Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `glm4_moe_tool_parser.py` | Fixed tool parser with incremental streaming |
|
||||
| `utils.py` | Utility functions for partial JSON/tag handling |
|
||||
| `Dockerfile` | Overlays patched files onto base image |
|
||||
| `Jenkinsfile` | CI/CD pipeline for building and pushing |
|
||||
|
||||
## Deployment
|
||||
|
||||
### Jenkins Pipeline
|
||||
|
||||
Build via Jenkins:
|
||||
|
||||
```bash
|
||||
curl -X POST "https://jenkins.sweetapi.com/job/vllm-glm-build/buildWithParameters" \
|
||||
-u "admin:TOKEN" \
|
||||
-d "IMAGE_TAG=latest"
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `IMAGE_TAG` - Docker image tag (default: `latest`)
|
||||
- `GIT_REPO` - Git repository URL (optional, uses workspace if empty)
|
||||
- `GIT_BRANCH` - Git branch to build (default: `master`)
|
||||
|
||||
### Manual Build
|
||||
|
||||
```bash
|
||||
docker build -t atl.vultrcr.com/vllm/vllm-glm51-patched:latest .
|
||||
docker push atl.vultrcr.com/vllm/vllm-glm51-patched:latest
|
||||
```
|
||||
|
||||
### Images
|
||||
|
||||
- Base: `vllm/vllm-openai:glm51-cu130`
|
||||
- Output: `atl.vultrcr.com/vllm/vllm-glm51-patched:<tag>`
|
||||
|
||||
## Related
|
||||
|
||||
- vLLM Issue #32829 (streaming long string parameters)
|
||||
483
glm4_moe_tool_parser.py
Normal file
483
glm4_moe_tool_parser.py
Normal file
@@ -0,0 +1,483 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
GLM-4 Tool Call Parser with incremental string streaming support.
|
||||
|
||||
This parser fixes the streaming issue reported in Issue #32829 where long string
|
||||
parameters (e.g., file content with 4000+ characters of code) are buffered until
|
||||
complete, causing multi-second delays before the user sees any content.
|
||||
|
||||
The fix streams string values incrementally as they arrive, providing a true
|
||||
streaming experience for long content.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
Tool,
|
||||
ToolParser,
|
||||
)
|
||||
from vllm.tool_parsers.utils import partial_tag_overlap
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Glm4MoeModelToolParser(ToolParser):
|
||||
"""Tool parser for GLM-4 models with incremental string streaming.
|
||||
|
||||
On every streaming call the parser re-parses ``current_text`` to find
|
||||
``<tool_call>`` regions, builds the JSON arguments string for each tool
|
||||
call, and diffs against what was previously sent to emit only new content.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
|
||||
super().__init__(tokenizer, tools)
|
||||
# Stateful streaming fields
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict[str, Any]] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
self.arg_key_start: str = "<arg_key>"
|
||||
self.arg_key_end: str = "</arg_key>"
|
||||
self.arg_val_start: str = "<arg_value>"
|
||||
self.arg_val_end: str = "</arg_value>"
|
||||
|
||||
self.tool_calls_start_token = self.tool_call_start_token
|
||||
|
||||
self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
|
||||
self.func_detail_regex = re.compile(
|
||||
r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL
|
||||
)
|
||||
self.func_arg_regex = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
|
||||
)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
# Pre-compiled pattern for finding the last <arg_key>...</arg_key>
|
||||
# before a partial <arg_value> (used in _build_args_json_so_far).
|
||||
self._arg_key_pattern = re.compile(
|
||||
re.escape(self.arg_key_start) + r"(.*?)" + re.escape(self.arg_key_end),
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Streaming state for re-parse-and-diff approach
|
||||
self._sent_content_idx: int = 0
|
||||
self._tool_call_ids: list[str] = []
|
||||
|
||||
@staticmethod
|
||||
def _deserialize(value: str) -> Any:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _json_escape_string_content(s: str) -> str:
|
||||
"""JSON-escape string content for incremental streaming.
|
||||
|
||||
This escapes the content that goes INSIDE a JSON string (between quotes),
|
||||
not including the surrounding quotes themselves.
|
||||
"""
|
||||
if not s:
|
||||
return ""
|
||||
return json.dumps(s, ensure_ascii=False)[1:-1]
|
||||
|
||||
@staticmethod
|
||||
def _is_string_type(
|
||||
tool_name: str,
|
||||
arg_name: str,
|
||||
tools: list[Tool] | None,
|
||||
) -> bool:
|
||||
if tools is None:
|
||||
return False
|
||||
for tool in tools:
|
||||
if tool.function.name != tool_name:
|
||||
continue
|
||||
if tool.function.parameters is None:
|
||||
return False
|
||||
arg_type = (
|
||||
tool.function.parameters.get("properties", {})
|
||||
.get(arg_name, {})
|
||||
.get("type", None)
|
||||
)
|
||||
return arg_type == "string"
|
||||
logger.debug("No tool named '%s'.", tool_name)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _tools_enabled(request: ChatCompletionRequest) -> bool:
|
||||
"""Return whether tool parsing should be applied for this request."""
|
||||
try:
|
||||
tools = getattr(request, "tools", None)
|
||||
tool_choice = getattr(request, "tool_choice", None)
|
||||
return bool(tools) and tool_choice != "none"
|
||||
except Exception:
|
||||
logger.exception("Failed to determine if tools are enabled.")
|
||||
return False
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest | ResponsesRequest
|
||||
) -> ChatCompletionRequest | ResponsesRequest:
|
||||
"""Adjust request parameters for tool call token handling."""
|
||||
request = super().adjust_request(request)
|
||||
if request.tools and request.tool_choice != "none":
|
||||
# Ensure tool call tokens (<tool_call>, </tool_call>) are not skipped
|
||||
# during decoding. Even though they are not marked as special tokens,
|
||||
# setting skip_special_tokens=False ensures proper handling in
|
||||
# transformers 5.x where decoding behavior may have changed.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
matched_tool_calls = self.func_call_regex.findall(model_output)
|
||||
logger.debug("model_output: %s", model_output)
|
||||
try:
|
||||
tool_calls: list[ToolCall] = []
|
||||
for match in matched_tool_calls:
|
||||
tc_detail = self.func_detail_regex.search(match)
|
||||
if not tc_detail:
|
||||
logger.warning(
|
||||
"Failed to parse tool call details from: %s",
|
||||
match,
|
||||
)
|
||||
continue
|
||||
tc_name = tc_detail.group(1).strip()
|
||||
tc_args = tc_detail.group(2)
|
||||
pairs = self.func_arg_regex.findall(tc_args) if tc_args else []
|
||||
arg_dct: dict[str, Any] = {}
|
||||
for key, value in pairs:
|
||||
arg_key = key.strip()
|
||||
arg_val = value.strip()
|
||||
if not self._is_string_type(tc_name, arg_key, self.tools):
|
||||
arg_val = self._deserialize(arg_val)
|
||||
logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val)
|
||||
arg_dct[arg_key] = arg_val
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tc_name,
|
||||
arguments=json.dumps(arg_dct, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to extract tool call spec")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
else:
|
||||
if len(tool_calls) > 0:
|
||||
content: str | None = model_output[
|
||||
: model_output.find(self.tool_calls_start_token)
|
||||
]
|
||||
# Normalize empty/whitespace-only content to None
|
||||
if not content or not content.strip():
|
||||
content = None
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True, tool_calls=tool_calls, content=content
|
||||
)
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
def _extract_content(self, current_text: str) -> str | None:
|
||||
"""Return unsent non-tool-call text, or None.
|
||||
|
||||
Collects all text outside ``<tool_call>...</tool_call>`` regions,
|
||||
including text between consecutive tool calls. Holds back any
|
||||
suffix that could be a partial ``<tool_call>`` tag.
|
||||
"""
|
||||
# Build the "sendable index" — the furthest point we can send
|
||||
# content up to. We scan through the text collecting segments
|
||||
# that are outside tool-call regions.
|
||||
content_segments: list[str] = []
|
||||
pos = self._sent_content_idx
|
||||
|
||||
while pos < len(current_text):
|
||||
start = current_text.find(self.tool_call_start_token, pos)
|
||||
if start == -1:
|
||||
# No more tool calls — send up to (len - partial-tag overlap)
|
||||
tail = current_text[pos:]
|
||||
overlap = partial_tag_overlap(tail, self.tool_call_start_token)
|
||||
sendable = tail[: len(tail) - overlap] if overlap else tail
|
||||
if sendable:
|
||||
content_segments.append(sendable)
|
||||
pos = len(current_text) - overlap
|
||||
break
|
||||
|
||||
# Text before this <tool_call>
|
||||
if start > pos:
|
||||
content_segments.append(current_text[pos:start])
|
||||
|
||||
# Skip past the </tool_call> (or to end if incomplete)
|
||||
end = current_text.find(self.tool_call_end_token, start)
|
||||
if end != -1:
|
||||
pos = end + len(self.tool_call_end_token)
|
||||
else:
|
||||
# Incomplete tool call — nothing more to send
|
||||
pos = start
|
||||
break
|
||||
|
||||
if content_segments:
|
||||
self._sent_content_idx = pos
|
||||
return "".join(content_segments)
|
||||
# Even if no content, advance past completed tool-call regions
|
||||
if pos > self._sent_content_idx:
|
||||
self._sent_content_idx = pos
|
||||
return None
|
||||
|
||||
def _extract_tool_call_regions(self, text: str) -> list[tuple[str, bool]]:
|
||||
"""Extract ``(inner_text, is_complete)`` for each ``<tool_call>`` region."""
|
||||
results: list[tuple[str, bool]] = []
|
||||
pos = 0
|
||||
while True:
|
||||
start = text.find(self.tool_call_start_token, pos)
|
||||
if start == -1:
|
||||
break
|
||||
inner_start = start + len(self.tool_call_start_token)
|
||||
end = text.find(self.tool_call_end_token, inner_start)
|
||||
if end != -1:
|
||||
results.append((text[inner_start:end], True))
|
||||
pos = end + len(self.tool_call_end_token)
|
||||
else:
|
||||
# Incomplete tool call — strip partial </tool_call> suffix
|
||||
raw = text[inner_start:]
|
||||
overlap = partial_tag_overlap(raw, self.tool_call_end_token)
|
||||
if overlap:
|
||||
raw = raw[:-overlap]
|
||||
results.append((raw, False))
|
||||
break
|
||||
return results
|
||||
|
||||
def _extract_tool_name_from_region(self, inner_text: str) -> str | None:
|
||||
"""Extract the tool name from the beginning of a tool-call region.
|
||||
|
||||
The name is everything before the first ``\\n`` or ``<arg_key>``.
|
||||
Returns ``None`` if the name hasn't fully arrived yet.
|
||||
"""
|
||||
nl = inner_text.find("\n")
|
||||
ak = inner_text.find(self.arg_key_start)
|
||||
candidates = [i for i in [nl, ak] if i != -1]
|
||||
if not candidates:
|
||||
return None
|
||||
cut = min(candidates)
|
||||
name = inner_text[:cut].strip()
|
||||
return name if name else None
|
||||
|
||||
def _build_args_json_so_far(
|
||||
self,
|
||||
tool_name: str,
|
||||
inner_text: str,
|
||||
is_complete: bool,
|
||||
) -> str:
|
||||
"""Build the JSON arguments string from the XML pairs seen so far.
|
||||
|
||||
For complete ``<arg_key>/<arg_value>`` pairs the value is fully
|
||||
formatted. For the last argument whose ``<arg_value>`` has been
|
||||
opened but not closed, the partial string content is included
|
||||
(JSON-escaped, with an opening ``"`` but no closing ``"``).
|
||||
|
||||
The closing ``}`` is only appended when ``is_complete`` is True
|
||||
(i.e. the ``</tool_call>`` tag has arrived).
|
||||
"""
|
||||
# Find all complete arg pairs
|
||||
pairs = self.func_arg_regex.findall(inner_text)
|
||||
|
||||
parts: list[str] = []
|
||||
for key, value in pairs:
|
||||
key = key.strip()
|
||||
key_json = json.dumps(key, ensure_ascii=False)
|
||||
if self._is_string_type(tool_name, key, self.tools):
|
||||
# Don't strip string values — whitespace is significant
|
||||
# and must match the partial-value path for diffing.
|
||||
val_json = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
val_json = json.dumps(
|
||||
self._deserialize(value.strip()), ensure_ascii=False
|
||||
)
|
||||
parts.append(f"{key_json}: {val_json}")
|
||||
|
||||
# Check for a partial (incomplete) arg value
|
||||
# Find the last <arg_value> that isn't closed
|
||||
last_val_start = inner_text.rfind(self.arg_val_start)
|
||||
last_val_end = inner_text.rfind(self.arg_val_end)
|
||||
has_partial_value = last_val_start != -1 and (
|
||||
last_val_end == -1 or last_val_end < last_val_start
|
||||
)
|
||||
|
||||
if has_partial_value:
|
||||
# Find the key for this partial value
|
||||
# Look for the last <arg_key>...</arg_key> before this <arg_value>
|
||||
last_key_match = None
|
||||
for m in self._arg_key_pattern.finditer(inner_text[:last_val_start]):
|
||||
last_key_match = m
|
||||
|
||||
if last_key_match:
|
||||
partial_key = last_key_match.group(1).strip()
|
||||
partial_content_start = last_val_start + len(self.arg_val_start)
|
||||
partial_content = inner_text[partial_content_start:]
|
||||
|
||||
# Hold back any partial </arg_value> suffix
|
||||
overlap = partial_tag_overlap(partial_content, self.arg_val_end)
|
||||
if overlap:
|
||||
partial_content = partial_content[:-overlap]
|
||||
|
||||
key_json = json.dumps(partial_key, ensure_ascii=False)
|
||||
if is_complete:
|
||||
# Tool call finished but </arg_value> is missing
|
||||
# (malformed output). Treat partial as complete value
|
||||
# so the diff naturally closes any open quotes.
|
||||
if self._is_string_type(tool_name, partial_key, self.tools):
|
||||
val_json = json.dumps(partial_content, ensure_ascii=False)
|
||||
else:
|
||||
val_json = json.dumps(
|
||||
self._deserialize(partial_content.strip()),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
parts.append(f"{key_json}: {val_json}")
|
||||
elif self._is_string_type(tool_name, partial_key, self.tools):
|
||||
escaped = self._json_escape_string_content(partial_content)
|
||||
# Open quote but no close — more content may arrive
|
||||
parts.append(f'{key_json}: "{escaped}')
|
||||
else:
|
||||
# Non-string partial: include raw content, no wrapping
|
||||
parts.append(f"{key_json}: {partial_content}")
|
||||
|
||||
if not parts:
|
||||
return "{}" if is_complete else ""
|
||||
|
||||
joined = "{" + ", ".join(parts)
|
||||
if is_complete:
|
||||
joined += "}"
|
||||
return joined
|
||||
|
||||
def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
|
||||
"""Return new argument text not yet sent for tool *index*, or None."""
|
||||
if not args_so_far or len(args_so_far) <= len(
|
||||
self.streamed_args_for_tool[index]
|
||||
):
|
||||
return None
|
||||
diff = args_so_far[len(self.streamed_args_for_tool[index]) :]
|
||||
self.streamed_args_for_tool[index] = args_so_far
|
||||
self.prev_tool_call_arr[index]["arguments"] = args_so_far
|
||||
return diff
|
||||
|
||||
def _ensure_tool_state_for(self, index: int) -> None:
|
||||
"""Grow state arrays so that *index* is valid."""
|
||||
while len(self._tool_call_ids) <= index:
|
||||
self._tool_call_ids.append(
|
||||
make_tool_call_id(id_type="random", func_name=None, idx=None)
|
||||
)
|
||||
while len(self.streamed_args_for_tool) <= index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
while len(self.prev_tool_call_arr) <= index:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
if not self._tools_enabled(request):
|
||||
return DeltaMessage(content=delta_text) if delta_text else None
|
||||
|
||||
content = self._extract_content(current_text)
|
||||
regions = self._extract_tool_call_regions(current_text)
|
||||
tool_call_deltas: list[DeltaToolCall] = []
|
||||
|
||||
for i, (inner_text, is_complete) in enumerate(regions):
|
||||
self._ensure_tool_state_for(i)
|
||||
|
||||
# Extract tool name
|
||||
tool_name = self._extract_tool_name_from_region(inner_text)
|
||||
if not tool_name:
|
||||
break
|
||||
|
||||
# Emit tool name (once per tool call)
|
||||
if "name" not in self.prev_tool_call_arr[i]:
|
||||
self.prev_tool_call_arr[i]["name"] = tool_name
|
||||
tool_call_deltas.append(
|
||||
DeltaToolCall(
|
||||
index=i,
|
||||
id=self._tool_call_ids[i],
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_name,
|
||||
arguments="",
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
)
|
||||
|
||||
# Build args JSON so far, diff, emit
|
||||
args_so_far = self._build_args_json_so_far(
|
||||
tool_name, inner_text, is_complete
|
||||
)
|
||||
diff = self._compute_args_diff(i, args_so_far)
|
||||
if diff:
|
||||
tool_call_deltas.append(
|
||||
DeltaToolCall(
|
||||
index=i,
|
||||
function=DeltaFunctionCall(arguments=diff).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Update current_tool_id for serving layer compatibility
|
||||
if regions:
|
||||
self.current_tool_id = len(regions) - 1
|
||||
|
||||
if content or tool_call_deltas:
|
||||
return DeltaMessage(
|
||||
content=content,
|
||||
tool_calls=tool_call_deltas,
|
||||
)
|
||||
return None
|
||||
438
utils.py
Normal file
438
utils.py
Normal file
@@ -0,0 +1,438 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import partial_json_parser
|
||||
from openai.types.responses import (
|
||||
FunctionTool,
|
||||
ToolChoiceFunction,
|
||||
)
|
||||
from openai.types.responses.tool import Tool as ResponsesTool
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaToolCall,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def partial_tag_overlap(text: str, tag: str) -> int:
|
||||
"""Length of the longest prefix of *tag* that matches a suffix of *text*.
|
||||
|
||||
E.g. text ending in ``"<tool_"`` returns 6 when tag is ``"<tool_call>"``.
|
||||
Returns 0 when there is no overlap.
|
||||
"""
|
||||
max_check = min(len(tag) - 1, len(text))
|
||||
for k in range(max_check, 0, -1):
|
||||
if text.endswith(tag[:k]):
|
||||
return k
|
||||
return 0
|
||||
|
||||
|
||||
def find_common_prefix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common prefix that is shared between two strings, if there is one.
|
||||
Order of arguments is NOT important.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely.
|
||||
|
||||
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
|
||||
'{"fruit": "ap'
|
||||
"""
|
||||
prefix = ""
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(0, min_length):
|
||||
if s1[i] == s2[i]:
|
||||
prefix += s1[i]
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def find_common_suffix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common suffix shared between two strings, if there is one. Order of
|
||||
arguments is NOT important.
|
||||
Stops when the suffix ends OR it hits an alphanumeric character
|
||||
|
||||
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
|
||||
"""
|
||||
suffix = ""
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(1, min_length + 1):
|
||||
if s1[-i] == s2[-i] and not s1[-i].isalnum():
|
||||
suffix = s1[-i] + suffix
|
||||
else:
|
||||
break
|
||||
return suffix
|
||||
|
||||
|
||||
def extract_intermediate_diff(curr: str, old: str) -> str:
|
||||
"""
|
||||
Given two strings, extract the difference in the middle between two strings
|
||||
that are known to have a common prefix and/or suffix.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely. The order of arguments IS
|
||||
important - the new version of the partially-parsed JSON must be the first
|
||||
argument, and the secnod argument must be from the previous generation.
|
||||
|
||||
What it returns, is tokens that should be streamed to the client.
|
||||
|
||||
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
|
||||
-> 'ple'
|
||||
|
||||
"""
|
||||
suffix = find_common_suffix(curr, old)
|
||||
|
||||
old = old[::-1].replace(suffix[::-1], "", 1)[::-1]
|
||||
prefix = find_common_prefix(curr, old)
|
||||
diff = curr
|
||||
if len(suffix):
|
||||
diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1]
|
||||
|
||||
if len(prefix):
|
||||
# replace the prefix only once in case it's mirrored
|
||||
diff = diff.replace(prefix, "", 1)
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
# partial_json_parser doesn't support extra data and
|
||||
# JSONDecoder.raw_decode doesn't support partial JSON
|
||||
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
raise
|
||||
|
||||
|
||||
def is_complete_json(input_str: str) -> bool:
|
||||
try:
|
||||
json.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
def consume_space(i: int, s: str) -> int:
|
||||
while i < len(s) and s[i].isspace():
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _extract_tool_info(
|
||||
tool: Tool,
|
||||
) -> tuple[str, dict[str, Any] | None]:
|
||||
if isinstance(tool, FunctionTool):
|
||||
return tool.name, tool.parameters
|
||||
elif isinstance(tool, ChatCompletionToolsParam):
|
||||
return tool.function.name, tool.function.parameters
|
||||
else:
|
||||
raise TypeError(f"Unsupported tool type: {type(tool)}")
|
||||
|
||||
|
||||
def _get_tool_schema_from_tool(tool: Tool) -> dict:
|
||||
name, params = _extract_tool_info(tool)
|
||||
params = params if params else {"type": "object", "properties": {}}
|
||||
return {
|
||||
"properties": {
|
||||
"name": {"type": "string", "enum": [name]},
|
||||
"parameters": params,
|
||||
},
|
||||
"required": ["name", "parameters"],
|
||||
}
|
||||
|
||||
|
||||
def _get_tool_schema_defs(
|
||||
tools: list[Tool],
|
||||
) -> dict:
|
||||
all_defs: dict[str, dict[str, Any]] = {}
|
||||
for tool in tools:
|
||||
_, params = _extract_tool_info(tool)
|
||||
if params is None:
|
||||
continue
|
||||
defs = params.pop("$defs", {})
|
||||
for def_name, def_schema in defs.items():
|
||||
if def_name in all_defs and all_defs[def_name] != def_schema:
|
||||
raise ValueError(
|
||||
f"Tool definition '{def_name}' has multiple schemas, "
|
||||
"which is not supported."
|
||||
)
|
||||
all_defs[def_name] = def_schema
|
||||
return all_defs
|
||||
|
||||
|
||||
def _get_json_schema_from_tools(
|
||||
tools: list[Tool],
|
||||
) -> dict:
|
||||
json_schema = {
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"anyOf": [_get_tool_schema_from_tool(tool) for tool in tools],
|
||||
},
|
||||
}
|
||||
json_schema_defs = _get_tool_schema_defs(tools)
|
||||
if json_schema_defs:
|
||||
json_schema["$defs"] = json_schema_defs
|
||||
return json_schema
|
||||
|
||||
|
||||
def get_json_schema_from_tools(
|
||||
tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
|
||||
tools: list[Tool] | None,
|
||||
) -> str | dict | None:
|
||||
# tool_choice: "none"
|
||||
if tool_choice in ("none", None) or tools is None:
|
||||
return None
|
||||
# tool_choice: Forced Function (Responses)
|
||||
if (not isinstance(tool_choice, str)) and isinstance(
|
||||
tool_choice, ToolChoiceFunction
|
||||
):
|
||||
tool_name = tool_choice.name
|
||||
tool_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
|
||||
if tool_name not in tool_map:
|
||||
raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
|
||||
return tool_map[tool_name].parameters
|
||||
# tool_choice: Forced Function (ChatCompletion)
|
||||
if (not isinstance(tool_choice, str)) and isinstance(
|
||||
tool_choice, ChatCompletionNamedToolChoiceParam
|
||||
):
|
||||
tool_name = tool_choice.function.name
|
||||
tool_map = {
|
||||
tool.function.name: tool
|
||||
for tool in tools
|
||||
if isinstance(tool, ChatCompletionToolsParam)
|
||||
}
|
||||
if tool_name not in tool_map:
|
||||
raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
|
||||
return tool_map[tool_name].function.parameters
|
||||
# tool_choice: "required"
|
||||
if tool_choice == "required":
|
||||
return _get_json_schema_from_tools(tools)
|
||||
# tool_choice: "auto"
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared utilities for pythonic-style tool call parsers
|
||||
# (PythonicToolParser, Llama4PythonicToolParser, Olmo3PythonicToolParser)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class UnexpectedAstError(Exception):
|
||||
"""Raised when the AST structure does not match the expected
|
||||
pythonic tool call format."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
_JSON_NAME_LITERALS = {
|
||||
"null": None,
|
||||
"true": True,
|
||||
"false": False,
|
||||
}
|
||||
|
||||
|
||||
def get_parameter_value(val: ast.expr) -> Any:
|
||||
"""Extract a Python literal value from an AST expression node.
|
||||
|
||||
Handles constants, dicts, lists, and JSON-style name literals
|
||||
(null, true, false) that some models produce instead of Python
|
||||
literals (None, True, False).
|
||||
|
||||
Raises:
|
||||
UnexpectedAstError: If the AST node is not a supported literal type.
|
||||
"""
|
||||
if isinstance(val, ast.Constant):
|
||||
return val.value
|
||||
elif isinstance(val, ast.Dict):
|
||||
if not all(isinstance(k, ast.Constant) for k in val.keys):
|
||||
logger.warning(
|
||||
"Dict argument keys are not all literals: %s",
|
||||
ast.dump(val),
|
||||
)
|
||||
raise UnexpectedAstError("Dict tool call arguments must have literal keys")
|
||||
return {
|
||||
k.value: get_parameter_value(v) # type: ignore
|
||||
for k, v in zip(val.keys, val.values)
|
||||
}
|
||||
elif isinstance(val, ast.List):
|
||||
return [get_parameter_value(v) for v in val.elts]
|
||||
elif isinstance(val, ast.Name) and val.id in _JSON_NAME_LITERALS:
|
||||
return _JSON_NAME_LITERALS[val.id]
|
||||
else:
|
||||
logger.warning(
|
||||
"Unsupported AST node type in tool call arguments: %s",
|
||||
ast.dump(val),
|
||||
)
|
||||
raise UnexpectedAstError("Tool call arguments must be literals")
|
||||
|
||||
|
||||
def handle_single_tool(call: ast.Call) -> ToolCall:
|
||||
"""Convert a single AST function call node into a ToolCall object.
|
||||
|
||||
Raises:
|
||||
UnexpectedAstError: If the call node does not have a simple
|
||||
function name (e.g. it's an attribute access or subscript).
|
||||
"""
|
||||
if not isinstance(call.func, ast.Name):
|
||||
logger.warning(
|
||||
"Tool call has non-simple function name: %s",
|
||||
ast.dump(call.func),
|
||||
)
|
||||
raise UnexpectedAstError("Invalid tool call name")
|
||||
function_name = call.func.id
|
||||
arguments = {}
|
||||
for keyword in call.keywords:
|
||||
arguments[keyword.arg] = get_parameter_value(keyword.value)
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=json.dumps(arguments, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def make_valid_python(text: str) -> tuple[str, str] | None:
|
||||
"""Attempt to close all open brackets/quotes to make partial Python valid.
|
||||
|
||||
Used during streaming to parse incomplete tool call expressions by
|
||||
appending the necessary closing characters.
|
||||
|
||||
Returns:
|
||||
A tuple of (completed_text, added_suffix) if the text can be
|
||||
made valid, or None if the text is too incomplete to complete
|
||||
meaningfully (e.g. mid-parameter-name or mid-dict-key).
|
||||
|
||||
Raises:
|
||||
UnexpectedAstError: If mismatched brackets or parentheses
|
||||
are detected.
|
||||
"""
|
||||
bracket_stack: list[str] = []
|
||||
for index, char in enumerate(text):
|
||||
if char in {"[", "(", "{"}:
|
||||
bracket_stack.append(char)
|
||||
elif char == "]":
|
||||
if not bracket_stack or bracket_stack.pop() != "[":
|
||||
raise UnexpectedAstError("Mismatched square brackets")
|
||||
elif char == ")":
|
||||
if not bracket_stack or bracket_stack.pop() != "(":
|
||||
raise UnexpectedAstError("Mismatched parentheses")
|
||||
elif char == "}":
|
||||
if not bracket_stack or bracket_stack.pop() != "{":
|
||||
raise UnexpectedAstError("Mismatched curly braces")
|
||||
elif char in {"'", '"'}:
|
||||
if bracket_stack and bracket_stack[-1] == char:
|
||||
if index > 0 and text[index - 1] == "\\":
|
||||
pass
|
||||
else:
|
||||
bracket_stack.pop()
|
||||
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
|
||||
pass
|
||||
else:
|
||||
bracket_stack.append(char)
|
||||
|
||||
text = text.rstrip()
|
||||
if text.endswith("=") or text.endswith(":"):
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "{":
|
||||
trailing_dict_text = text[: text.rfind("{")]
|
||||
num_keys = trailing_dict_text.count(":")
|
||||
num_values = trailing_dict_text.count(",")
|
||||
if num_keys <= num_values:
|
||||
return None
|
||||
if bracket_stack and bracket_stack[-1] == "(":
|
||||
trailing_params_text = text[: text.rfind("(")]
|
||||
num_full_param_names = trailing_params_text.count("=")
|
||||
num_full_param_values = trailing_params_text.count(",")
|
||||
if num_full_param_names <= num_full_param_values:
|
||||
return None
|
||||
if text.endswith(","):
|
||||
text = text[:-1]
|
||||
if (
|
||||
bracket_stack
|
||||
and bracket_stack[-1] == "["
|
||||
and not text.endswith("[")
|
||||
and not text.endswith(")")
|
||||
):
|
||||
return None
|
||||
|
||||
_CLOSING = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'}
|
||||
added_text = ""
|
||||
for char in reversed(bracket_stack):
|
||||
added_text += _CLOSING[char]
|
||||
|
||||
return text + added_text, added_text
|
||||
|
||||
|
||||
def compute_tool_delta(
|
||||
previously_sent_args: str,
|
||||
new_call: ToolCall,
|
||||
index: int,
|
||||
withheld_suffix: str,
|
||||
) -> DeltaToolCall | None:
|
||||
"""Compute the incremental delta between previously streamed arguments
|
||||
and the current tool call state.
|
||||
|
||||
Returns:
|
||||
A DeltaToolCall with only the new argument characters, or None
|
||||
if there is no difference from what was previously sent.
|
||||
"""
|
||||
new_call_args = new_call.function.arguments
|
||||
if withheld_suffix:
|
||||
if not new_call_args.endswith(withheld_suffix):
|
||||
msg = (
|
||||
f"Tool call arguments '{new_call_args}' do not end with "
|
||||
f"expected withheld suffix '{withheld_suffix}'"
|
||||
)
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
new_call_args = new_call_args[: -len(withheld_suffix)]
|
||||
if not previously_sent_args:
|
||||
return DeltaToolCall(
|
||||
id=new_call.id,
|
||||
type="function",
|
||||
index=index,
|
||||
function=DeltaFunctionCall(
|
||||
name=new_call.function.name,
|
||||
arguments=new_call_args,
|
||||
),
|
||||
)
|
||||
|
||||
arg_diff = new_call_args[len(previously_sent_args) :]
|
||||
return (
|
||||
DeltaToolCall(
|
||||
id=None,
|
||||
index=index,
|
||||
function=DeltaFunctionCall(arguments=arg_diff),
|
||||
)
|
||||
if arg_diff
|
||||
else None
|
||||
)
|
||||
Reference in New Issue
Block a user