# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GLM-4/5 Tool Call Parser — fixed version.

Fixes applied over the upstream vLLM + sweetapi patch:

1. **func_detail_regex no longer requires a newline** between the tool name
   and the first ``<arg_key>``.  The model's chat template instructs:
       <tool_call>{name}<arg_key>...</arg_key><arg_value>...</arg_value></tool_call>
   with NO mandatory newline, but the original regex used ``[^\\n]*\\n`` which
   silently failed when the model omitted it.
2. **Zero-argument tool calls no longer crash** (TypeError on NoneType).
3. **extract_tool_calls is intended to accept the same input variants** as
   the streaming path (optional whitespace after the name, zero-argument
   calls).  Note that it parses with the regexes below, while the streaming
   path uses the string-scanning helpers.
4. **_extract_tool_name_from_region** is more tolerant of whitespace /
   formatting variants the model may produce, and does not report a bare
   name for a region that is still being streamed.

Drop this file into your vLLM install as a --tool-parser-plugin, or replace
the built-in glm4_moe_tool_parser.py.
"""

import ast
import json
from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    Tool,
    ToolParser,
)
from vllm.tool_parsers.utils import partial_tag_overlap

logger = init_logger(__name__)

# A bare tool name: word characters, dots and hyphens only.  Compiled once
# because the streaming path may test it on every delta.
_BARE_TOOL_NAME_RE = re.compile(r"[\w.\-]+")


class Glm4MoeModelToolParser(ToolParser):
    """Tool parser for GLM-4/5 models with incremental string streaming.

    On every streaming call the parser re-parses ``current_text`` to find
    ``<tool_call>`` regions, builds the JSON arguments string for each tool
    call, and diffs against what was previously sent to emit only new
    content.
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        """Set up tag literals, regexes and per-stream state.

        Raises:
            ValueError: if the base class did not end up with a tokenizer.
        """
        super().__init__(tokenizer, tools)
        # Stateful streaming fields
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        # Tag literals emitted by the model.  These must be non-empty:
        # ``str.find("")`` matches at every position, which would make all of
        # the scanning helpers below meaningless.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.arg_key_start: str = "<arg_key>"
        self.arg_key_end: str = "</arg_key>"
        self.arg_val_start: str = "<arg_value>"
        self.arg_val_end: str = "</arg_value>"

        self.tool_calls_start_token = self.tool_call_start_token

        # ---- FIXED regexes ------------------------------------------------
        # Match the whole <tool_call>...</tool_call> block (unchanged).
        self.func_call_regex = re.compile(
            r"<tool_call>.*?</tool_call>", re.DOTALL
        )

        # FIX 1: The original regex required a literal \n between tool name
        # and the body.  The model often omits it.  We now accept any
        # whitespace (including none) before the first <arg_key>, and we
        # make the body group optional so zero-argument calls don't fail.
        self.func_detail_regex = re.compile(
            r"<tool_call>\s*"  # opening tag + optional whitespace
            r"([\w.\-]+)"  # group 1: tool/function name
            r"\s*"  # optional whitespace / newline
            r"((?:<arg_key>.*)?)"  # group 2: first <arg_key> onward (may be empty)
            r"\s*</tool_call>",  # closing tag
            re.DOTALL,
        )

        self.func_arg_regex = re.compile(
            r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
            re.DOTALL,
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        # Pre-compiled pattern for finding the last <arg_key>...</arg_key>
        # before a partial <arg_value> (used in _build_args_json_so_far).
        self._arg_key_pattern = re.compile(
            re.escape(self.arg_key_start) + r"(.*?)" + re.escape(self.arg_key_end),
            re.DOTALL,
        )

        # Streaming state for re-parse-and-diff approach
        self._sent_content_idx: int = 0
        self._tool_call_ids: list[str] = []

    # ------------------------------------------------------------------
    # Static helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _deserialize(value: str) -> Any:
        """Best-effort conversion of an argument string to a Python value.

        Tries JSON first, then a Python literal; returns ``value`` unchanged
        when neither parser accepts it.
        """
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            pass

        try:
            return ast.literal_eval(value)
        except (ValueError, SyntaxError):
            pass

        return value

    @staticmethod
    def _json_escape_string_content(s: str) -> str:
        """JSON-escape string content (without surrounding quotes)."""
        if not s:
            return ""
        return json.dumps(s, ensure_ascii=False)[1:-1]

    @staticmethod
    def _is_string_type(
        tool_name: str,
        arg_name: str,
        tools: list[Tool] | None,
    ) -> bool:
        """Return True when ``arg_name`` of ``tool_name`` is declared ``"string"``.

        Returns False when no tools are given, the tool is unknown, the tool
        has no parameter schema, or the argument has any other / no type.
        """
        if tools is None:
            return False
        for tool in tools:
            if tool.function.name != tool_name:
                continue
            if tool.function.parameters is None:
                return False
            arg_type = (
                tool.function.parameters.get("properties", {})
                .get(arg_name, {})
                .get("type", None)
            )
            return arg_type == "string"
        logger.debug("No tool named '%s'.", tool_name)
        return False

    @staticmethod
    def _tools_enabled(request: ChatCompletionRequest) -> bool:
        """Return True when the request carries tools and tool_choice is not "none"."""
        try:
            tools = getattr(request, "tools", None)
            tool_choice = getattr(request, "tool_choice", None)
            return bool(tools) and tool_choice != "none"
        except Exception:
            logger.exception("Failed to determine if tools are enabled.")
            return False

    # ------------------------------------------------------------------
    # Request adjustment
    # ------------------------------------------------------------------

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Disable special-token skipping when tool calling is active.

        The tag literals this parser scans for have to survive
        detokenization, so ``skip_special_tokens`` is turned off whenever the
        request has tools and ``tool_choice`` is not ``"none"``.
        """
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            request.skip_special_tokens = False
        return request

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse every complete tool-call block out of ``model_output``.

        Returns an ``ExtractedToolCallInformation`` with ``tools_called=True``
        and the text preceding the first opening tag as ``content`` (``None``
        when that text is empty or whitespace).  When no call is found, or
        parsing raises, the whole output is returned as plain content.
        """
        matched_tool_calls = self.func_call_regex.findall(model_output)
        logger.debug("model_output: %s", model_output)

        try:
            tool_calls: list[ToolCall] = []
            for match in matched_tool_calls:
                tc_detail = self.func_detail_regex.search(match)
                if not tc_detail:
                    logger.warning(
                        "Failed to parse tool call details from: %s", match
                    )
                    continue

                tc_name = tc_detail.group(1).strip()
                tc_args_raw = tc_detail.group(2) or ""  # FIX 2: default to ""

                pairs = self.func_arg_regex.findall(tc_args_raw) if tc_args_raw else []
                arg_dct: dict[str, Any] = {}
                for key, value in pairs:
                    arg_key = key.strip()
                    arg_val = value.strip()
                    # Values declared as strings stay verbatim; everything
                    # else is decoded so numbers, booleans and objects come
                    # out typed.
                    if not self._is_string_type(tc_name, arg_key, self.tools):
                        arg_val = self._deserialize(arg_val)
                    logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val)
                    arg_dct[arg_key] = arg_val

                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=tc_name,
                            arguments=json.dumps(arg_dct, ensure_ascii=False),
                        ),
                    )
                )
        except Exception:
            logger.exception("Failed to extract tool call spec")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        if tool_calls:
            content: str | None = model_output[
                : model_output.find(self.tool_calls_start_token)
            ]
            if not content or not content.strip():
                content = None
            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=content
            )

        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    # ------------------------------------------------------------------
    # Streaming helpers
    # ------------------------------------------------------------------

    def _extract_content(self, current_text: str) -> str | None:
        """Return plain-text content not yet sent, skipping tool-call regions.

        Scanning resumes at ``self._sent_content_idx`` and that index is
        advanced past whatever is returned or skipped.  A trailing fragment
        that could still turn into an opening tag is held back.
        """
        content_segments: list[str] = []
        pos = self._sent_content_idx

        while pos < len(current_text):
            start = current_text.find(self.tool_call_start_token, pos)
            if start == -1:
                # No further tool call: emit the tail, minus any suffix that
                # overlaps the beginning of the opening tag.
                # partial_tag_overlap is used here as a suffix length.
                tail = current_text[pos:]
                overlap = partial_tag_overlap(tail, self.tool_call_start_token)
                sendable = tail[: len(tail) - overlap] if overlap else tail
                if sendable:
                    content_segments.append(sendable)
                pos = len(current_text) - overlap
                break

            if start > pos:
                content_segments.append(current_text[pos:start])

            end = current_text.find(self.tool_call_end_token, start)
            if end != -1:
                pos = end + len(self.tool_call_end_token)
            else:
                # Unterminated tool call: park on its opening tag and wait.
                pos = start
                break

        if content_segments:
            self._sent_content_idx = pos
            return "".join(content_segments)
        if pos > self._sent_content_idx:
            self._sent_content_idx = pos
        return None

    def _extract_tool_call_regions(self, text: str) -> list[tuple[str, bool]]:
        """Split ``text`` into ``(inner_text, is_complete)`` tool-call regions.

        ``inner_text`` is what lies between the opening and closing tags.  At
        most one region, the last, can be incomplete; a trailing fragment of
        the closing tag is trimmed from it.
        """
        results: list[tuple[str, bool]] = []
        pos = 0
        while True:
            start = text.find(self.tool_call_start_token, pos)
            if start == -1:
                break
            inner_start = start + len(self.tool_call_start_token)
            end = text.find(self.tool_call_end_token, inner_start)
            if end != -1:
                results.append((text[inner_start:end], True))
                pos = end + len(self.tool_call_end_token)
            else:
                raw = text[inner_start:]
                overlap = partial_tag_overlap(raw, self.tool_call_end_token)
                if overlap:
                    raw = raw[:-overlap]
                results.append((raw, False))
                break
        return results

    def _extract_tool_name_from_region(
        self, inner_text: str, is_complete: bool = True
    ) -> str | None:
        """Extract the tool name from the beginning of a tool-call region.

        The name is everything before the first ``\\n`` or ``<arg_key>``.
        When neither delimiter is present the whole region is accepted as the
        name only if the region is complete (a zero-argument call) and
        consists solely of word characters, dots and hyphens.

        Args:
            inner_text: text between the opening and closing tags.
            is_complete: whether the closing tag has been seen.  Pass
                ``False`` while streaming so that a half-arrived name is not
                reported: the streaming caller emits the name once and cannot
                correct it afterwards.

        Returns:
            The tool name, or ``None`` if it is absent or not yet known.
        """
        # Strip leading whitespace — model may emit \n after <tool_call>
        stripped = inner_text.lstrip()
        if not stripped:
            return None

        nl = stripped.find("\n")
        ak = stripped.find(self.arg_key_start)
        candidates = [i for i in (nl, ak) if i != -1]

        if candidates:
            name = stripped[: min(candidates)].strip()
            return name if name else None

        # No delimiter yet.  While the region is still open the name may be
        # truncated, so wait for a delimiter or the closing tag.
        if not is_complete:
            return None

        candidate_name = stripped.strip()
        if _BARE_TOOL_NAME_RE.fullmatch(candidate_name):
            return candidate_name
        return None

    def _build_args_json_so_far(
        self,
        tool_name: str,
        inner_text: str,
        is_complete: bool,
    ) -> str:
        """Build the JSON arguments string for everything seen so far.

        Finished key/value pairs are rendered in full.  A value whose closing
        tag has not arrived is rendered as an open fragment so callers can
        diff successive results.  The closing brace is added only when
        ``is_complete`` is true.

        Returns:
            ``"{}"`` for a complete call without arguments, ``""`` for an
            incomplete one without arguments, otherwise the (possibly
            unterminated) JSON object text.
        """
        pairs = self.func_arg_regex.findall(inner_text)

        parts: list[str] = []
        for key, value in pairs:
            key = key.strip()
            key_json = json.dumps(key, ensure_ascii=False)
            if self._is_string_type(tool_name, key, self.tools):
                val_json = json.dumps(value, ensure_ascii=False)
            else:
                val_json = json.dumps(
                    self._deserialize(value.strip()), ensure_ascii=False
                )
            parts.append(f"{key_json}: {val_json}")

        # Check for a partial (incomplete) arg value
        last_val_start = inner_text.rfind(self.arg_val_start)
        last_val_end = inner_text.rfind(self.arg_val_end)
        has_partial_value = last_val_start != -1 and (
            last_val_end == -1 or last_val_end < last_val_start
        )

        if has_partial_value:
            # The key owning the open value is the last complete key tag
            # before it.
            last_key_match = None
            for m in self._arg_key_pattern.finditer(inner_text[:last_val_start]):
                last_key_match = m

            if last_key_match:
                partial_key = last_key_match.group(1).strip()
                partial_content_start = last_val_start + len(self.arg_val_start)
                partial_content = inner_text[partial_content_start:]

                # Drop a trailing fragment of the value's closing tag.
                overlap = partial_tag_overlap(partial_content, self.arg_val_end)
                if overlap:
                    partial_content = partial_content[:-overlap]

                key_json = json.dumps(partial_key, ensure_ascii=False)
                if is_complete:
                    # The call closed without closing the value: treat what
                    # arrived as the final value.
                    if self._is_string_type(tool_name, partial_key, self.tools):
                        val_json = json.dumps(partial_content, ensure_ascii=False)
                    else:
                        val_json = json.dumps(
                            self._deserialize(partial_content.strip()),
                            ensure_ascii=False,
                        )
                    parts.append(f"{key_json}: {val_json}")
                elif self._is_string_type(tool_name, partial_key, self.tools):
                    # Open JSON string: opening quote, escaped body, no
                    # closing quote yet.
                    escaped = self._json_escape_string_content(partial_content)
                    parts.append(f'{key_json}: "{escaped}')
                else:
                    # NOTE(review): the raw text is streamed as-is here but
                    # re-serialised through json.dumps once the value
                    # closes; if the two spellings differ (spacing, Python
                    # vs JSON literals) the length-based diff may not line
                    # up — confirm against real model output.
                    parts.append(f"{key_json}: {partial_content}")

        if not parts:
            return "{}" if is_complete else ""

        joined = "{" + ", ".join(parts)
        if is_complete:
            joined += "}"
        return joined

    def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
        """Return the not-yet-streamed suffix of ``args_so_far`` for a tool.

        Returns ``None`` when ``args_so_far`` is empty or not longer than what
        was already streamed.  Otherwise records ``args_so_far`` as streamed
        and returns the text beyond the previously streamed length.
        """
        if not args_so_far or len(args_so_far) <= len(
            self.streamed_args_for_tool[index]
        ):
            return None

        diff = args_so_far[len(self.streamed_args_for_tool[index]) :]
        self.streamed_args_for_tool[index] = args_so_far
        self.prev_tool_call_arr[index]["arguments"] = args_so_far
        return diff

    def _ensure_tool_state_for(self, index: int) -> None:
        """Grow the per-tool state lists so that ``index`` is valid in each."""
        while len(self._tool_call_ids) <= index:
            self._tool_call_ids.append(
                make_tool_call_id(id_type="random", func_name=None, idx=None)
            )
        while len(self.streamed_args_for_tool) <= index:
            self.streamed_args_for_tool.append("")
        while len(self.prev_tool_call_arr) <= index:
            self.prev_tool_call_arr.append({})

    # ------------------------------------------------------------------
    # Main streaming entry point
    # ------------------------------------------------------------------

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Produce the next streaming delta from the full text so far.

        Only ``current_text``, ``delta_text`` and ``request`` are read.  The
        full text is re-parsed on every call and compared with what has
        already been sent.

        Returns:
            A ``DeltaMessage`` carrying new content and/or tool-call deltas,
            or ``None`` when there is nothing new to send.
        """
        if not self._tools_enabled(request):
            return DeltaMessage(content=delta_text) if delta_text else None

        content = self._extract_content(current_text)
        regions = self._extract_tool_call_regions(current_text)
        tool_call_deltas: list[DeltaToolCall] = []

        for i, (inner_text, is_complete) in enumerate(regions):
            self._ensure_tool_state_for(i)

            # Pass is_complete so that a name still being streamed is not
            # emitted early; the name is sent only once per tool call.
            tool_name = self._extract_tool_name_from_region(
                inner_text, is_complete
            )
            if not tool_name:
                break

            # Emit tool name (once per tool call)
            if "name" not in self.prev_tool_call_arr[i]:
                self.prev_tool_call_arr[i]["name"] = tool_name
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        id=self._tool_call_ids[i],
                        type="function",
                        function=DeltaFunctionCall(
                            name=tool_name,
                            arguments="",
                        ).model_dump(exclude_none=True),
                    )
                )

            # Build args JSON so far, diff, emit
            args_so_far = self._build_args_json_so_far(
                tool_name, inner_text, is_complete
            )
            diff = self._compute_args_diff(i, args_so_far)
            if diff:
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        function=DeltaFunctionCall(arguments=diff).model_dump(
                            exclude_none=True
                        ),
                    )
                )

        if regions:
            self.current_tool_id = len(regions) - 1

        if content or tool_call_deltas:
            return DeltaMessage(
                content=content,
                tool_calls=tool_call_deltas,
            )

        return None