# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GLM-4 Tool Call Parser with incremental string streaming support.

This parser fixes the streaming issue reported in Issue #32829 where long
string parameters (e.g., file content with 4000+ characters of code) are
buffered until complete, causing multi-second delays before the user sees
any content. The fix streams string values incrementally as they arrive,
providing a true streaming experience for long content.
"""

import ast
import json
from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    Tool,
    ToolParser,
)
from vllm.tool_parsers.utils import partial_tag_overlap

logger = init_logger(__name__)


class Glm4MoeModelToolParser(ToolParser):
    """Tool parser for GLM-4 models with incremental string streaming.

    Tool calls are expected in the following markup::

        <tool_call>function_name
        <arg_key>key</arg_key>
        <arg_value>value</arg_value>
        </tool_call>

    On every streaming call the parser re-parses ``current_text`` to find
    ``<tool_call>...</tool_call>`` regions, builds the JSON arguments string
    for each tool call, and diffs against what was previously sent to emit
    only new content.
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        super().__init__(tokenizer, tools)
        # Stateful streaming fields
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        # Markup of the GLM-4 tool-call format. None of these may be empty:
        # the scanners below call str.find() on them, and an empty tag
        # matches at every position, so their loops would never advance.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.arg_key_start: str = "<arg_key>"
        self.arg_key_end: str = "</arg_key>"
        self.arg_val_start: str = "<arg_value>"
        self.arg_val_end: str = "</arg_value>"

        self.tool_calls_start_token = self.tool_call_start_token

        self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
        # No longer used by this class (extract_tool_calls shares the
        # streaming name extraction); retained in case external code uses it.
        self.func_detail_regex = re.compile(
            r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL
        )
        self.func_arg_regex = re.compile(
            r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        # Pre-compiled pattern for finding the last <arg_key>...</arg_key>
        # before a partial <arg_value> (used in _build_args_json_so_far).
        self._arg_key_pattern = re.compile(
            re.escape(self.arg_key_start) + r"(.*?)" + re.escape(self.arg_key_end),
            re.DOTALL,
        )

        # Streaming state for re-parse-and-diff approach
        self._sent_content_idx: int = 0
        self._tool_call_ids: list[str] = []

    @staticmethod
    def _deserialize(value: str) -> Any:
        """Best-effort conversion of a raw argument value to a Python object.

        Tries JSON first, then a Python literal (e.g. ``True``, ``None``,
        single-quoted strings). Returns *value* unchanged when neither parses
        or when the parsed object could not be embedded in JSON.
        """
        try:
            return json.loads(value)
        except (ValueError, RecursionError):
            # JSONDecodeError is a ValueError; pathologically nested input
            # can also exhaust the recursion limit.
            pass
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval raises TypeError for e.g. unhashable dict keys
            # ("{[1]: 2}") and Recursion/MemoryError for deep nesting. The
            # input is model output, so none of these may escape.
            return value
        try:
            json.dumps(parsed)
        except (TypeError, ValueError, RecursionError):
            # Valid Python literals such as sets, bytes or complex numbers
            # cannot be serialized into the JSON arguments string.
            return value
        return parsed

    @staticmethod
    def _json_escape_string_content(s: str) -> str:
        """JSON-escape string content for incremental streaming.

        This escapes the content that goes INSIDE a JSON string (between
        quotes), not including the surrounding quotes themselves.
        """
        if not s:
            return ""
        return json.dumps(s, ensure_ascii=False)[1:-1]

    @staticmethod
    def _is_string_type(
        tool_name: str,
        arg_name: str,
        tools: list[Tool] | None,
    ) -> bool:
        """Return True if *arg_name* of *tool_name* is declared ``"string"``.

        Unknown tools, tools without parameters, undeclared arguments and
        malformed (non-dict) argument schemas all count as non-string.
        """
        if tools is None:
            return False
        for tool in tools:
            if tool.function.name != tool_name:
                continue
            if tool.function.parameters is None:
                return False
            # "properties" may be present but null in user-supplied schemas.
            properties = tool.function.parameters.get("properties") or {}
            arg_schema = properties.get(arg_name)
            return isinstance(arg_schema, dict) and arg_schema.get("type") == "string"
        logger.debug("No tool named '%s'.", tool_name)
        return False

    @staticmethod
    def _tools_enabled(request: ChatCompletionRequest) -> bool:
        """Return whether tool parsing should be applied for this request."""
        try:
            tools = getattr(request, "tools", None)
            tool_choice = getattr(request, "tool_choice", None)
            return bool(tools) and tool_choice != "none"
        except Exception:
            logger.exception("Failed to determine if tools are enabled.")
            return False

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Adjust request parameters for tool call token handling."""
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # Ensure tool call tokens (<tool_call>, </tool_call>) are not
            # skipped during decoding. Even though they are not marked as
            # special tokens, setting skip_special_tokens=False ensures proper
            # handling in transformers 5.x where decoding behavior may have
            # changed.
            request.skip_special_tokens = False
        return request

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all complete tool calls from a finished *model_output*.

        Returns ``tools_called=True`` with the parsed calls and the text that
        precedes the first ``<tool_call>`` as ``content`` (``None`` when it is
        empty or whitespace-only). If no call parses, or parsing raises, the
        whole output is returned unchanged as content.
        """
        matched_tool_calls = self.func_call_regex.findall(model_output)
        logger.debug("model_output: %s", model_output)
        start_len = len(self.tool_call_start_token)
        end_len = len(self.tool_call_end_token)
        try:
            tool_calls: list[ToolCall] = []
            for match in matched_tool_calls:
                # Strip the enclosing tags and reuse the streaming name
                # extraction so both paths accept the same layouts: name
                # followed by "\n", directly by <arg_key>, or on its own.
                inner_text = match[start_len : len(match) - end_len]
                tc_name = self._extract_tool_name_from_region(
                    inner_text, is_complete=True
                )
                if not tc_name:
                    logger.warning(
                        "Failed to parse tool call details from: %s",
                        match,
                    )
                    continue

                arg_dct: dict[str, Any] = {}
                for key, value in self.func_arg_regex.findall(inner_text):
                    arg_key = key.strip()
                    # NOTE: string values are stripped here, whereas the
                    # streaming path preserves their surrounding whitespace.
                    arg_val = value.strip()
                    if not self._is_string_type(tc_name, arg_key, self.tools):
                        arg_val = self._deserialize(arg_val)
                    logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val)
                    arg_dct[arg_key] = arg_val

                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=tc_name,
                            arguments=json.dumps(arg_dct, ensure_ascii=False),
                        ),
                    )
                )
        except Exception:
            logger.exception("Failed to extract tool call spec")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        if not tool_calls:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        content: str | None = model_output[
            : model_output.find(self.tool_calls_start_token)
        ]
        # Normalize empty/whitespace-only content to None
        if not content or not content.strip():
            content = None
        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=tool_calls, content=content
        )

    def _extract_content(self, current_text: str) -> str | None:
        """Return unsent non-tool-call text, or None.

        Collects all text outside ``<tool_call>...</tool_call>`` regions,
        including text between consecutive tool calls. Holds back any suffix
        that could be a partial ``<tool_call>`` tag. Advances
        ``self._sent_content_idx`` past everything consumed.
        """
        # Build the "sendable index" — the furthest point we can send
        # content up to. We scan through the text collecting segments
        # that are outside tool-call regions.
        content_segments: list[str] = []
        pos = self._sent_content_idx

        while pos < len(current_text):
            start = current_text.find(self.tool_call_start_token, pos)
            if start == -1:
                # No more tool calls — send up to (len - partial-tag overlap)
                tail = current_text[pos:]
                overlap = partial_tag_overlap(tail, self.tool_call_start_token)
                sendable = tail[: len(tail) - overlap] if overlap else tail
                if sendable:
                    content_segments.append(sendable)
                pos = len(current_text) - overlap
                break

            # Text before this <tool_call>
            if start > pos:
                content_segments.append(current_text[pos:start])

            # Skip past the </tool_call> (or stop if the call is incomplete)
            end = current_text.find(self.tool_call_end_token, start)
            if end != -1:
                pos = end + len(self.tool_call_end_token)
            else:
                # Incomplete tool call — nothing more to send
                pos = start
                break

        if content_segments:
            self._sent_content_idx = pos
            return "".join(content_segments)
        # Even if no content, advance past completed tool-call regions
        if pos > self._sent_content_idx:
            self._sent_content_idx = pos
        return None

    def _extract_tool_call_regions(self, text: str) -> list[tuple[str, bool]]:
        """Extract ``(inner_text, is_complete)`` for each tool-call region.

        A region is the text between ``<tool_call>`` and ``</tool_call>``.
        At most the last region is incomplete; a trailing partial
        ``</tool_call>`` tag is stripped from it.
        """
        results: list[tuple[str, bool]] = []
        pos = 0
        while True:
            start = text.find(self.tool_call_start_token, pos)
            if start == -1:
                break
            inner_start = start + len(self.tool_call_start_token)
            end = text.find(self.tool_call_end_token, inner_start)
            if end != -1:
                results.append((text[inner_start:end], True))
                pos = end + len(self.tool_call_end_token)
            else:
                # Incomplete tool call — strip partial suffix
                raw = text[inner_start:]
                overlap = partial_tag_overlap(raw, self.tool_call_end_token)
                if overlap:
                    raw = raw[:-overlap]
                results.append((raw, False))
                break
        return results

    def _extract_tool_name_from_region(
        self, inner_text: str, is_complete: bool = False
    ) -> str | None:
        """Extract the tool name from the beginning of a tool-call region.

        Leading whitespace is ignored; the name is everything before the
        first ``\\n`` or ``<arg_key>``. When the region is complete and has
        neither (a no-argument call such as ``<tool_call>f</tool_call>``),
        the whole stripped region is the name.

        Returns ``None`` if the name hasn't fully arrived yet or is empty.
        """
        text = inner_text.lstrip()
        cut_points = [
            idx
            for idx in (text.find("\n"), text.find(self.arg_key_start))
            if idx != -1
        ]
        if cut_points:
            name = text[: min(cut_points)].strip()
        elif is_complete:
            name = text.strip()
        else:
            return None
        return name or None

    def _format_arg_pair(self, tool_name: str, key: str, value: str) -> str:
        """Return ``"key": <json value>`` for a fully received argument."""
        key_json = json.dumps(key, ensure_ascii=False)
        if self._is_string_type(tool_name, key, self.tools):
            # Don't strip string values — whitespace is significant
            # and must match the partial-value path for diffing.
            val_json = json.dumps(value, ensure_ascii=False)
        else:
            val_json = json.dumps(self._deserialize(value.strip()), ensure_ascii=False)
        return f"{key_json}: {val_json}"

    def _build_args_json_so_far(
        self,
        tool_name: str,
        inner_text: str,
        is_complete: bool,
    ) -> str:
        """Build the JSON arguments string from the XML pairs seen so far.

        For complete ``<arg_key>``/``<arg_value>`` pairs the value is fully
        formatted. For the last argument whose ``<arg_value>`` has been opened
        but not closed:

        * string-typed values are included incrementally (JSON-escaped, with
          an opening ``"`` but no closing ``"``);
        * all other values are held back until ``</arg_value>`` arrives.
          Their final JSON form (``true`` for ``True``, a quoted string for
          an untyped parameter, re-spaced JSON, ...) is only known once the
          value is complete, and text already streamed cannot be retracted.

        The closing ``}`` is only appended when ``is_complete`` is True (i.e.
        the ``</tool_call>`` tag has arrived). Results for a growing
        *inner_text* are meant to extend one another, because
        ``_compute_args_diff`` emits only the new suffix.
        """
        parts = [
            self._format_arg_pair(tool_name, key.strip(), value)
            for key, value in self.func_arg_regex.findall(inner_text)
        ]

        # A partial value exists when the last <arg_value> has no matching
        # </arg_value> after it (rfind returns -1 when none exists at all).
        last_val_start = inner_text.rfind(self.arg_val_start)
        last_val_end = inner_text.rfind(self.arg_val_end)
        if last_val_start != -1 and last_val_end < last_val_start:
            # The key of the partial value is the last complete
            # <arg_key>...</arg_key> preceding that <arg_value>.
            last_key_match = None
            for m in self._arg_key_pattern.finditer(inner_text[:last_val_start]):
                last_key_match = m
            if last_key_match:
                partial_key = last_key_match.group(1).strip()
                partial_content = inner_text[last_val_start + len(self.arg_val_start) :]
                # Hold back any partial </arg_value> suffix
                overlap = partial_tag_overlap(partial_content, self.arg_val_end)
                if overlap:
                    partial_content = partial_content[:-overlap]

                if is_complete:
                    # Tool call finished but </arg_value> is missing
                    # (malformed output). Treat partial as complete value
                    # so the diff naturally closes any open quotes.
                    parts.append(
                        self._format_arg_pair(tool_name, partial_key, partial_content)
                    )
                elif self._is_string_type(tool_name, partial_key, self.tools):
                    key_json = json.dumps(partial_key, ensure_ascii=False)
                    escaped = self._json_escape_string_content(partial_content)
                    # Open quote but no close — more content may arrive
                    parts.append(f'{key_json}: "{escaped}')
                # else: non-string value still arriving — buffered (see above).

        if not parts:
            return "{}" if is_complete else ""
        joined = "{" + ", ".join(parts)
        if is_complete:
            joined += "}"
        return joined

    def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
        """Return new argument text not yet sent for tool *index*, or None.

        Relies on *args_so_far* extending the previously streamed text; on
        success the stored streamed/previous arguments are updated.
        """
        if not args_so_far or len(args_so_far) <= len(
            self.streamed_args_for_tool[index]
        ):
            return None
        diff = args_so_far[len(self.streamed_args_for_tool[index]) :]
        self.streamed_args_for_tool[index] = args_so_far
        self.prev_tool_call_arr[index]["arguments"] = args_so_far
        return diff

    def _ensure_tool_state_for(self, index: int) -> None:
        """Grow state arrays so that *index* is valid."""
        while len(self._tool_call_ids) <= index:
            self._tool_call_ids.append(
                make_tool_call_id(id_type="random", func_name=None, idx=None)
            )
        while len(self.streamed_args_for_tool) <= index:
            self.streamed_args_for_tool.append("")
        while len(self.prev_tool_call_arr) <= index:
            self.prev_tool_call_arr.append({})

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Return the next streaming delta for *current_text*, or None.

        Re-parses the full ``current_text`` on each call. It emits unsent
        plain content, each tool's name once, and only the not-yet-sent
        suffix of each tool's JSON arguments.
        """
        if not self._tools_enabled(request):
            return DeltaMessage(content=delta_text) if delta_text else None

        content = self._extract_content(current_text)
        regions = self._extract_tool_call_regions(current_text)
        tool_call_deltas: list[DeltaToolCall] = []

        for i, (inner_text, is_complete) in enumerate(regions):
            self._ensure_tool_state_for(i)

            # Extract tool name; if it hasn't fully arrived, later regions
            # must wait too so tool-call indices stay in order.
            tool_name = self._extract_tool_name_from_region(inner_text, is_complete)
            if not tool_name:
                break

            # Emit tool name (once per tool call)
            if "name" not in self.prev_tool_call_arr[i]:
                self.prev_tool_call_arr[i]["name"] = tool_name
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        id=self._tool_call_ids[i],
                        type="function",
                        function=DeltaFunctionCall(
                            name=tool_name,
                            arguments="",
                        ).model_dump(exclude_none=True),
                    )
                )

            # Build args JSON so far, diff, emit
            args_so_far = self._build_args_json_so_far(
                tool_name, inner_text, is_complete
            )
            diff = self._compute_args_diff(i, args_so_far)
            if diff:
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        function=DeltaFunctionCall(arguments=diff).model_dump(
                            exclude_none=True
                        ),
                    )
                )

        # Update current_tool_id for serving layer compatibility
        if regions:
            self.current_tool_id = len(regions) - 1

        if content or tool_call_deltas:
            return DeltaMessage(
                content=content,
                tool_calls=tool_call_deltas,
            )
        return None