Files
vllm-glm/glm4_moe_tool_parser.py
2026-04-09 04:28:22 +00:00

491 lines
19 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GLM-4/5 Tool Call Parser — fixed version.
Fixes applied over the upstream vLLM + sweetapi patch:
1. **func_detail_regex no longer requires a newline** between tool name and
first <arg_key>. The model's chat template instructs:
<tool_call>{name}<arg_key>…</arg_key><arg_value>…</arg_value>…</tool_call>
with NO mandatory newline, but the original regex used ``[^\\n]*\\n`` which
silently failed when the model omitted it.
2. **Zero-argument tool calls no longer crash** (TypeError on NoneType).
3. **extract_tool_calls uses the same robust extraction helpers** as the
streaming path, so both paths parse identically.
4. **_extract_tool_name_from_region** is more tolerant of whitespace /
formatting variants the model may produce.
Drop this file into your vLLM install as a --tool-parser-plugin, or replace
the built-in glm4_moe_tool_parser.py.
"""
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import partial_tag_overlap
# Module-level logger used by the parser below for debug traces, parse
# warnings, and logged exceptions.
logger = init_logger(__name__)
class Glm4MoeModelToolParser(ToolParser):
    """Tool parser for GLM-4/5 models with incremental string streaming.

    The model emits tool calls as::

        <tool_call>{name}<arg_key>k</arg_key><arg_value>v</arg_value>...</tool_call>

    On every streaming call the parser re-parses ``current_text`` to find
    ``<tool_call>`` regions, builds the JSON arguments string for each tool
    call, and diffs against what was previously sent to emit only new content.
    The diff is length-based (see ``_compute_args_diff``), so every arguments
    string produced for a given tool call must extend the previous one; the
    builders below only emit text that can no longer change.
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        super().__init__(tokenizer, tools)
        # Stateful streaming fields
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.arg_key_start: str = "<arg_key>"
        self.arg_key_end: str = "</arg_key>"
        self.arg_val_start: str = "<arg_value>"
        self.arg_val_end: str = "</arg_value>"
        # Alias used by extract_tool_calls to cut the leading plain content.
        self.tool_calls_start_token = self.tool_call_start_token

        # Match a whole <tool_call>...</tool_call> block; non-greedy so that
        # consecutive calls are matched separately.
        self.func_call_regex = re.compile(
            r"<tool_call>.*?</tool_call>", re.DOTALL
        )
        # Name, then optional whitespace (no newline required), then an
        # optional body starting at the first <arg_key>. An empty body still
        # matches, so zero-argument calls parse.
        self.func_detail_regex = re.compile(
            r"<tool_call>\s*"  # opening tag + optional whitespace
            r"([\w.\-]+)"  # group 1: tool/function name (word chars, dots, hyphens)
            r"\s*"  # optional whitespace / newline
            r"((?:<arg_key>.*)?)"  # group 2: everything from first <arg_key> onward (may be empty)
            r"\s*</tool_call>",  # closing tag
            re.DOTALL,
        )
        self.func_arg_regex = re.compile(
            r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
        )
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )
        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        # Finds <arg_key>...</arg_key> spans; used to locate the key that
        # owns a still-open <arg_value> in _build_args_json_so_far.
        self._arg_key_pattern = re.compile(
            re.escape(self.arg_key_start) + r"(.*?)" + re.escape(self.arg_key_end),
            re.DOTALL,
        )
        # Streaming state for the re-parse-and-diff approach:
        # index into current_text up to which plain content has been emitted,
        # and one stable tool-call id per tool-call region index.
        self._sent_content_idx: int = 0
        self._tool_call_ids: list[str] = []

    # ------------------------------------------------------------------
    # Static helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _deserialize(value: str) -> Any:
        """Best-effort conversion of a raw argument value to a JSON-able object.

        Tries JSON first, then a Python literal. Falls back to returning
        ``value`` unchanged when neither parses, or when the Python literal is
        not JSON-serializable (set, bytes, complex, ``...``, tuple-keyed dict),
        so callers can always pass the result to ``json.dumps``.
        """
        try:
            return json.loads(value)
        except (ValueError, RecursionError):
            # JSONDecodeError is a ValueError; deep nesting raises
            # RecursionError.
            pass
        try:
            result = ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval raises TypeError for e.g. unhashable dict keys
            # ("{[1]: 2}"), besides the usual ValueError/SyntaxError.
            return value
        try:
            json.dumps(result)
        except (TypeError, ValueError, RecursionError):
            return value
        return result

    @staticmethod
    def _json_escape_string_content(s: str) -> str:
        """JSON-escape string content (without surrounding quotes)."""
        if not s:
            return ""
        return json.dumps(s, ensure_ascii=False)[1:-1]

    @staticmethod
    def _is_string_type(
        tool_name: str,
        arg_name: str,
        tools: list[Tool] | None,
    ) -> bool:
        """Return True if ``tool_name``'s schema declares ``arg_name`` as a string.

        Missing tools, parameters, properties, or malformed (non-dict) schema
        fragments all yield False.
        """
        if tools is None:
            return False
        for tool in tools:
            if tool.function.name != tool_name:
                continue
            parameters = tool.function.parameters
            if not isinstance(parameters, dict):
                return False
            properties = parameters.get("properties", {})
            if not isinstance(properties, dict):
                return False
            arg_schema = properties.get(arg_name, {})
            if not isinstance(arg_schema, dict):
                return False
            return arg_schema.get("type", None) == "string"
        logger.debug("No tool named '%s'.", tool_name)
        return False

    @staticmethod
    def _tools_enabled(request: ChatCompletionRequest) -> bool:
        """Return True when the request carries tools and tool_choice != "none"."""
        try:
            tools = getattr(request, "tools", None)
            tool_choice = getattr(request, "tool_choice", None)
            return bool(tools) and tool_choice != "none"
        except Exception:
            logger.exception("Failed to determine if tools are enabled.")
            return False

    # ------------------------------------------------------------------
    # Request adjustment
    # ------------------------------------------------------------------
    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Keep special tokens in the output when tools are in use.

        The tool-call markers (``<tool_call>`` etc.) must survive detokenizing
        for the parser to find them.
        """
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            request.skip_special_tokens = False
        return request

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse all complete ``<tool_call>`` blocks from a finished output.

        Uses ``func_call_regex`` / ``func_detail_regex`` / ``func_arg_regex``.
        Argument values are stripped; values of non-string parameters are
        additionally deserialized via ``_deserialize``. (The streaming path
        keeps string values verbatim so that partial text can be streamed.)

        Returns tool calls plus the plain content preceding the first
        ``<tool_call>`` (``None`` if blank). If nothing parses, or any
        unexpected error occurs, the whole output is returned as content.
        """
        matched_tool_calls = self.func_call_regex.findall(model_output)
        logger.debug("model_output: %s", model_output)
        try:
            tool_calls: list[ToolCall] = []
            for match in matched_tool_calls:
                tc_detail = self.func_detail_regex.search(match)
                if not tc_detail:
                    logger.warning(
                        "Failed to parse tool call details from: %s", match
                    )
                    continue
                tc_name = tc_detail.group(1).strip()
                # Group 2 always participates; it is "" for zero-arg calls.
                tc_args_raw = tc_detail.group(2)
                pairs = self.func_arg_regex.findall(tc_args_raw)
                arg_dct: dict[str, Any] = {}
                for key, value in pairs:
                    arg_key = key.strip()
                    arg_val = value.strip()
                    if not self._is_string_type(tc_name, arg_key, self.tools):
                        arg_val = self._deserialize(arg_val)
                    logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val)
                    arg_dct[arg_key] = arg_val
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=tc_name,
                            arguments=json.dumps(arg_dct, ensure_ascii=False),
                        ),
                    )
                )
        except Exception:
            logger.exception("Failed to extract tool call spec")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
        if tool_calls:
            # A parsed call implies the start token is present, so find() >= 0.
            content: str | None = model_output[
                : model_output.find(self.tool_calls_start_token)
            ]
            if not content or not content.strip():
                content = None
            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=content
            )
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    # ------------------------------------------------------------------
    # Streaming helpers
    # ------------------------------------------------------------------
    def _extract_content(self, current_text: str) -> str | None:
        """Return plain content outside tool-call regions not yet emitted.

        Scans from ``self._sent_content_idx``, skips complete
        ``<tool_call>...</tool_call>`` regions, stops at an unterminated one,
        and holds back a trailing partial ``<tool_call>`` prefix. Advances
        ``self._sent_content_idx`` past everything consumed.
        """
        content_segments: list[str] = []
        pos = self._sent_content_idx
        while pos < len(current_text):
            start = current_text.find(self.tool_call_start_token, pos)
            if start == -1:
                tail = current_text[pos:]
                # Don't emit text that might be the start of "<tool_call>".
                overlap = partial_tag_overlap(tail, self.tool_call_start_token)
                sendable = tail[: len(tail) - overlap] if overlap else tail
                if sendable:
                    content_segments.append(sendable)
                pos = len(current_text) - overlap
                break
            if start > pos:
                content_segments.append(current_text[pos:start])
            end = current_text.find(self.tool_call_end_token, start)
            if end != -1:
                pos = end + len(self.tool_call_end_token)
            else:
                # Inside an unterminated tool call: resume here next time.
                pos = start
                break
        if content_segments:
            self._sent_content_idx = pos
            return "".join(content_segments)
        if pos > self._sent_content_idx:
            self._sent_content_idx = pos
        return None

    def _extract_tool_call_regions(self, text: str) -> list[tuple[str, bool]]:
        """Return ``(inner_text, is_complete)`` for every tool-call region.

        Only the last region can be incomplete (no ``</tool_call>`` yet); a
        trailing partial ``</tool_call>`` prefix is trimmed from its text.
        """
        results: list[tuple[str, bool]] = []
        pos = 0
        while True:
            start = text.find(self.tool_call_start_token, pos)
            if start == -1:
                break
            inner_start = start + len(self.tool_call_start_token)
            end = text.find(self.tool_call_end_token, inner_start)
            if end != -1:
                results.append((text[inner_start:end], True))
                pos = end + len(self.tool_call_end_token)
            else:
                raw = text[inner_start:]
                overlap = partial_tag_overlap(raw, self.tool_call_end_token)
                if overlap:
                    raw = raw[:-overlap]
                results.append((raw, False))
                break
        return results

    def _extract_tool_name_from_region(
        self, inner_text: str, is_complete: bool = True
    ) -> str | None:
        """Extract the tool name from the beginning of a tool-call region.

        The name is everything (after leading whitespace) before the first
        ``\\n`` or ``<arg_key>``. When neither delimiter is present the name
        may still be arriving token by token, so it is only returned when
        ``is_complete`` is True (a zero-argument call whose ``</tool_call>``
        has been seen) and it looks like a bare identifier.

        Args:
            inner_text: Text between ``<tool_call>`` and ``</tool_call>`` (or
                the end of the text for an in-flight region).
            is_complete: Whether the region's ``</tool_call>`` has arrived.
                Defaults to True, which matches the historical behavior.

        Returns:
            The tool name, or None if it cannot be determined yet.
        """
        # Strip leading whitespace — model may emit \n after <tool_call>
        stripped = inner_text.lstrip()
        if not stripped:
            return None
        nl = stripped.find("\n")
        ak = stripped.find(self.arg_key_start)
        candidates = [i for i in (nl, ak) if i != -1]
        if candidates:
            name = stripped[: min(candidates)].strip()
            return name if name else None
        if not is_complete:
            # No delimiter yet and no closing tag: the name may be truncated
            # ("get" of "get_weather"). Names are emitted only once, so wait.
            return None
        candidate_name = stripped.strip()
        if re.fullmatch(r'[\w.\-]+', candidate_name):
            return candidate_name
        return None

    def _build_args_json_so_far(
        self,
        tool_name: str,
        inner_text: str,
        is_complete: bool,
    ) -> str:
        """Build the JSON arguments string for one tool-call region.

        Completed ``<arg_key>/<arg_value>`` pairs are always serialized. A
        trailing value whose ``</arg_value>`` has not arrived yet is:

        * for a string-typed parameter: emitted as an unterminated JSON
          string, so its text streams incrementally;
        * for any other parameter: held back until complete. Its final form
          is ``json.dumps(_deserialize(value.strip()))``, which can differ
          from the raw partial text (whitespace, ``True`` vs ``true``,
          Python-literal quotes, spacing inside objects). Emitting raw text
          would break the prefix property ``_compute_args_diff`` relies on.

        The closing ``}`` is only appended once the region is complete.

        Returns:
            ``""`` when nothing can be emitted yet for an in-flight call,
            ``"{}"`` for a complete call without arguments, otherwise the
            (possibly unterminated) JSON object text.
        """
        pairs = self.func_arg_regex.findall(inner_text)
        parts: list[str] = []
        for key, value in pairs:
            key = key.strip()
            key_json = json.dumps(key, ensure_ascii=False)
            if self._is_string_type(tool_name, key, self.tools):
                # Verbatim (unstripped) so the completed value extends exactly
                # what was streamed while it was partial.
                val_json = json.dumps(value, ensure_ascii=False)
            else:
                val_json = json.dumps(
                    self._deserialize(value.strip()), ensure_ascii=False
                )
            parts.append(f"{key_json}: {val_json}")

        # Check for a partial (incomplete) arg value
        last_val_start = inner_text.rfind(self.arg_val_start)
        last_val_end = inner_text.rfind(self.arg_val_end)
        has_partial_value = last_val_start != -1 and (
            last_val_end == -1 or last_val_end < last_val_start
        )
        if has_partial_value:
            last_key_match = None
            for m in self._arg_key_pattern.finditer(inner_text[:last_val_start]):
                last_key_match = m
            if last_key_match:
                partial_key = last_key_match.group(1).strip()
                partial_content_start = last_val_start + len(self.arg_val_start)
                partial_content = inner_text[partial_content_start:]
                # Hold back a possible partial "</arg_value>" suffix.
                overlap = partial_tag_overlap(partial_content, self.arg_val_end)
                if overlap:
                    partial_content = partial_content[:-overlap]
                key_json = json.dumps(partial_key, ensure_ascii=False)
                is_string = self._is_string_type(
                    tool_name, partial_key, self.tools
                )
                if is_complete:
                    # Region closed without </arg_value>: finalize as-is.
                    if is_string:
                        val_json = json.dumps(partial_content, ensure_ascii=False)
                    else:
                        val_json = json.dumps(
                            self._deserialize(partial_content.strip()),
                            ensure_ascii=False,
                        )
                    parts.append(f"{key_json}: {val_json}")
                elif is_string:
                    escaped = self._json_escape_string_content(partial_content)
                    parts.append(f'{key_json}: "{escaped}')
                # else: non-string value still arriving; emit once complete.

        if not parts:
            return "{}" if is_complete else ""
        joined = "{" + ", ".join(parts)
        if is_complete:
            joined += "}"
        return joined

    def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
        """Return the unsent suffix of ``args_so_far`` for tool ``index``.

        Assumes ``args_so_far`` extends what was already streamed (see
        ``_build_args_json_so_far``). Records the new string as sent and in
        ``prev_tool_call_arr[index]["arguments"]``.
        """
        if not args_so_far or len(args_so_far) <= len(
            self.streamed_args_for_tool[index]
        ):
            return None
        diff = args_so_far[len(self.streamed_args_for_tool[index]) :]
        self.streamed_args_for_tool[index] = args_so_far
        # NOTE(review): this stores the JSON *string*, not a dict; confirm the
        # serving layer's end-of-stream "unstreamed args" check expects a str.
        self.prev_tool_call_arr[index]["arguments"] = args_so_far
        return diff

    def _ensure_tool_state_for(self, index: int) -> None:
        """Grow per-tool streaming state lists so ``index`` is valid."""
        while len(self._tool_call_ids) <= index:
            self._tool_call_ids.append(
                make_tool_call_id(id_type="random", func_name=None, idx=None)
            )
        while len(self.streamed_args_for_tool) <= index:
            self.streamed_args_for_tool.append("")
        while len(self.prev_tool_call_arr) <= index:
            self.prev_tool_call_arr.append({})

    # ------------------------------------------------------------------
    # Main streaming entry point
    # ------------------------------------------------------------------
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Emit new content and tool-call deltas by re-parsing ``current_text``.

        Each tool call's name is sent once (with its id) as soon as it is
        known to be complete; its arguments are then streamed as suffixes of
        the JSON built by ``_build_args_json_so_far``.
        """
        if not self._tools_enabled(request):
            return DeltaMessage(content=delta_text) if delta_text else None
        content = self._extract_content(current_text)
        regions = self._extract_tool_call_regions(current_text)
        tool_call_deltas: list[DeltaToolCall] = []
        for i, (inner_text, is_complete) in enumerate(regions):
            self._ensure_tool_state_for(i)
            tool_name = self._extract_tool_name_from_region(
                inner_text, is_complete=is_complete
            )
            if not tool_name:
                break
            # Emit tool name (once per tool call)
            if "name" not in self.prev_tool_call_arr[i]:
                self.prev_tool_call_arr[i]["name"] = tool_name
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        id=self._tool_call_ids[i],
                        type="function",
                        function=DeltaFunctionCall(
                            name=tool_name,
                            arguments="",
                        ).model_dump(exclude_none=True),
                    )
                )
            # Build args JSON so far, diff, emit
            args_so_far = self._build_args_json_so_far(
                tool_name, inner_text, is_complete
            )
            diff = self._compute_args_diff(i, args_so_far)
            if diff:
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        function=DeltaFunctionCall(arguments=diff).model_dump(
                            exclude_none=True
                        ),
                    )
                )
        if regions:
            self.current_tool_id = len(regions) - 1
        if content or tool_call_deltas:
            return DeltaMessage(
                content=content,
                tool_calls=tool_call_deltas,
            )
        return None