# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DeepSeek-V3.2 Tool Call Parser — re-parse-and-diff version.

Adapted from the GLM-4 streaming fix to make the streaming path robust
against multi-token deltas produced by MTP speculative decoding.

Instead of maintaining incremental state that advances one token at a
time, the streaming path re-parses the *entire* current_text on every
call, finds all <|DSML|invoke> regions (complete and in-progress),
builds a JSON arguments string for each, and diffs against what was
previously sent. This makes the parser agnostic to how many tokens
arrive per step.

Key changes vs. the upstream buffer-until-complete parser:

1. _extract_content() handles partial tag overlaps so content text is
   never swallowed or duplicated when a tag boundary lands inside a
   multi-token chunk.
2. _extract_invoke_regions() finds both complete and incomplete invoke
   blocks, enabling streaming of partial arguments.
3. _build_args_json_so_far() constructs the JSON arguments string
   incrementally from complete + partial <|DSML|parameter> tags.
4. _compute_args_diff() emits only the newly-added characters.

Drop-in replacement: same class name, same interface.
"""

import json
import uuid
from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    Tool,
    ToolParser,
)

logger = init_logger(__name__)


def partial_tag_overlap(text: str, tag: str) -> int:
    """Length of the longest proper prefix of *tag* that is a suffix of *text*.

    Only prefixes strictly shorter than *tag* are considered, so a text
    that ends with the whole tag is not reported as an overlap by this
    function.

    E.g. ``partial_tag_overlap("abc<to", "<tool>")`` returns ``3`` because
    the text ends with ``"<to"``, the first three characters of the tag.
    Returns 0 when there is no overlap (including when either argument is
    empty).
    """
    max_check = min(len(tag) - 1, len(text))
    # Try the longest candidate first so the first hit is the answer.
    for k in range(max_check, 0, -1):
        if text.endswith(tag[:k]):
            return k
    return 0


class DeepSeekV32ToolParser(ToolParser):
    """
    Re-parse-and-diff tool parser for DeepSeek-V3.2 DSML format.

    On every streaming call the parser re-parses ``current_text`` to find
    ``<|DSML|invoke>`` regions, builds the JSON arguments string for each
    tool call, and diffs against what was previously sent to emit only new
    content. This is robust against multi-token deltas from MTP / EAGLE
    speculative decoding.

    Example tool call format::

        <|DSML|function_calls>
        <|DSML|invoke name="get_weather">
        <|DSML|parameter name="location" string="true">杭州</|DSML|parameter>
        <|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
        </|DSML|invoke>
        </|DSML|function_calls>
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        """Set up tag constants, compiled patterns and streaming state.

        Raises:
            ValueError: if the base class did not end up with a tokenizer.
        """
        super().__init__(tokenizer, tools)

        # ----- Tag constants -----
        # NOTE(review): these spellings must match the DSML special tokens
        # exactly as the tokenizer decodes them — confirm the bar character
        # against the model's vocabulary.
        self.tool_call_start_token: str = "<|DSML|function_calls>"
        self.tool_call_end_token: str = "</|DSML|function_calls>"
        self.invoke_end_token: str = "</|DSML|invoke>"
        self.param_end_token: str = "</|DSML|parameter>"
        # Opening tags carry attributes, so only their fixed prefix is a
        # constant; the rest is matched by the regexes below.
        self._invoke_open_prefix: str = "<|DSML|invoke"
        self._param_open_prefix: str = "<|DSML|parameter"

        # Alias expected by ToolParser base / adjust_request
        self.tool_calls_start_token = self.tool_call_start_token

        # ----- Compiled regexes -----
        # Every literal tag is passed through re.escape: the tags contain
        # "|", which would otherwise be read as regex alternation.
        invoke_header = (
            re.escape(self._invoke_open_prefix) + r'\s+name="([^"]+)"\s*>'
        )
        param_header = (
            re.escape(self._param_open_prefix)
            + r'\s+name="([^"]+)"\s+string="(true|false)"\s*>'
        )

        # Matches a complete function_calls block — captures its body.
        self.tool_call_complete_regex = re.compile(
            re.escape(self.tool_call_start_token)
            + r"(.*?)"
            + re.escape(self.tool_call_end_token),
            re.DOTALL,
        )

        # Opening tag of an invoke block — captures the function name.
        self.invoke_start_regex = re.compile(invoke_header, re.DOTALL)

        # Complete invoke block — captures (name, body).
        self.invoke_complete_regex = re.compile(
            invoke_header + r"(.*?)" + re.escape(self.invoke_end_token),
            re.DOTALL,
        )

        # Complete parameter tag — captures (name, string_attr, value).
        self.parameter_complete_regex = re.compile(
            param_header + r"(.*?)" + re.escape(self.param_end_token),
            re.DOTALL,
        )

        # Just the opening header of a parameter tag (for partial params).
        self.parameter_header_regex = re.compile(param_header, re.DOTALL)

        # ----- Streaming state (reset per request) -----
        # Index into current_text up to which content has been handled.
        self._sent_content_idx: int = 0
        # One generated call id per tool-call index.
        self._tool_call_ids: list[str] = []
        # Arguments JSON already emitted, per tool-call index.
        self.streamed_args_for_tool: list[str] = []
        # Per tool-call dict holding "name" and "arguments" seen so far.
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        # Index of the last invoke region seen, -1 before any.
        self.current_tool_id: int = -1

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        logger.debug("Successfully initialized %s", self.__class__.__name__)

    # ------------------------------------------------------------------
    # Request adjustment
    # ------------------------------------------------------------------

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Apply the base adjustment, then keep special tokens when tools are on."""
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # Ensure DSML tokens are not stripped during decoding.
            request.skip_special_tokens = False
        return request

    # ------------------------------------------------------------------
    # Static / utility helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _tools_enabled(request: ChatCompletionRequest) -> bool:
        """Check whether tool calling is active for this request.

        True when the request has a non-empty ``tools`` attribute and its
        ``tool_choice`` is not ``"none"``. Any error while inspecting the
        request is logged and treated as "not enabled".
        """
        try:
            tools = getattr(request, "tools", None)
            tool_choice = getattr(request, "tool_choice", None)
            return bool(tools) and tool_choice != "none"
        except Exception:
            logger.exception("Failed to determine if tools are enabled.")
            return False

    def _generate_tool_call_id(self) -> str:
        """Return a fresh ``call_``-prefixed id with 24 random hex digits."""
        return f"call_{uuid.uuid4().hex[:24]}"

    @staticmethod
    def _json_escape_string_content(s: str) -> str:
        """JSON-escape a string value (without surrounding quotes)."""
        if not s:
            return ""
        # json.dumps adds the enclosing quotes; strip them off again.
        return json.dumps(s, ensure_ascii=False)[1:-1]

    # ------------------------------------------------------------------
    # Type conversion helpers
    # ------------------------------------------------------------------

    def _convert_param_value_checked(self, value: str, param_type: str) -> Any:
        """Convert a raw string value to the type indicated by *param_type*.

        A value spelling ``null`` (any case) becomes ``None`` regardless of
        the requested type. Unknown types are parsed as JSON.

        Raises on failure so the caller can try the next candidate type.
        """
        if value.lower() == "null":
            return None

        param_type = param_type.lower()
        if param_type in ("string", "str", "text"):
            return value
        if param_type in ("integer", "int"):
            return int(value)
        if param_type in ("number", "float"):
            val = float(value)
            # Collapse whole floats to int. int() raises for inf/nan, which
            # deliberately sends those values to the caller's fallback.
            return val if val != int(val) else int(val)
        if param_type in ("boolean", "bool"):
            normed = value.strip().lower()
            if normed not in ("false", "0", "true", "1"):
                raise ValueError(f"Invalid boolean value: {value!r}")
            return normed in ("true", "1")
        # "object", "array" and any unrecognised type: parse as JSON.
        return json.loads(value)

    def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
        """Try each candidate type in turn; fall back to the raw string."""
        if not isinstance(param_type, list):
            param_type = [param_type]
        for current_type in param_type:
            try:
                return self._convert_param_value_checked(value, current_type)
            except Exception:
                # Best-effort: any failure just means "try the next type".
                continue
        return value

    def _get_param_schema_type(
        self, func_name: str, param_name: str
    ) -> str | list[str]:
        """Look up the JSON-schema type for a parameter, defaulting to
        ``"string"``.

        Only the first tool whose function name equals *func_name* is
        consulted.
        """
        if self.tools:
            for tool in self.tools:
                if (
                    hasattr(tool, "function")
                    and tool.function.name == func_name
                    and hasattr(tool.function, "parameters")
                ):
                    schema = tool.function.parameters
                    if isinstance(schema, dict) and "properties" in schema:
                        prop = schema["properties"].get(param_name, {})
                        if isinstance(prop, dict):
                            return prop.get("type", "string")
                    break
        return "string"

    def _convert_with_schema(
        self, func_name: str, param_name: str, value: str
    ) -> Any:
        """Convert *value* using the tool schema for *func_name*.*param_name*."""
        param_type = self._get_param_schema_type(func_name, param_name)
        return self._convert_param_value(value, param_type)

    def _is_string_type(self, func_name: str, param_name: str) -> bool:
        """Return True if the schema says this parameter is a string.

        A list-valued schema type counts as string when it contains
        ``"string"``.
        """
        ptype = self._get_param_schema_type(func_name, param_name)
        if isinstance(ptype, list):
            return "string" in ptype
        return ptype in ("string", "str", "text")

    def _format_param_json(
        self, func_name: str, param_name: str, string_attr: str, value: str
    ) -> str:
        """Render one finished parameter as a ``"key": value`` JSON fragment.

        With ``string="true"`` the raw text is emitted as a JSON string;
        otherwise it is converted through the tool schema first.
        """
        key_json = json.dumps(param_name, ensure_ascii=False)
        if string_attr == "true":
            val_json = json.dumps(value, ensure_ascii=False)
        else:
            converted = self._convert_with_schema(func_name, param_name, value)
            val_json = json.dumps(converted, ensure_ascii=False)
        return f"{key_json}: {val_json}"

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from complete model output (non-streaming).

        Returns the output untouched as content when no start tag is
        present, when no complete invoke block is found, or when parsing
        raises. Otherwise returns the tool calls, with any text before the
        first start tag as content.
        """
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            tool_calls: list[ToolCall] = []

            for fc_block in self.tool_call_complete_regex.findall(model_output):
                for invoke_name, invoke_body in self.invoke_complete_regex.findall(
                    fc_block
                ):
                    # Parse all parameters in this invoke; a repeated name
                    # keeps its first position but takes the last value.
                    raw_params: dict[str, str] = {
                        pname: pval
                        for pname, _str_attr, pval in (
                            self.parameter_complete_regex.findall(invoke_body)
                        )
                    }

                    # Convert types via schema.
                    converted: dict[str, Any] = {
                        pname: self._convert_with_schema(invoke_name, pname, pval)
                        for pname, pval in raw_params.items()
                    }

                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=invoke_name,
                                arguments=json.dumps(
                                    converted, ensure_ascii=False
                                ),
                            ),
                        )
                    )

            if not tool_calls:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            first_idx = model_output.find(self.tool_call_start_token)
            content = model_output[:first_idx] if first_idx > 0 else None

            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=content
            )

        except Exception:
            logger.exception("Error extracting tool calls from complete output")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    # ------------------------------------------------------------------
    # Streaming helpers — re-parse-and-diff
    # ------------------------------------------------------------------

    def _reset_streaming_state(self) -> None:
        """Return all per-stream bookkeeping to its initial values."""
        self._sent_content_idx = 0
        self._tool_call_ids.clear()
        self.streamed_args_for_tool.clear()
        self.prev_tool_call_arr.clear()
        self.current_tool_id = -1

    def _extract_content(self, current_text: str) -> str | None:
        """Return any non-tool-call text that hasn't been sent yet.

        Walks *current_text* from ``_sent_content_idx``, collecting text
        outside ``<|DSML|function_calls>`` regions. Uses
        ``partial_tag_overlap`` to avoid emitting characters that might turn
        out to be the start of the function-calls tag once the next chunk
        arrives. Advances ``_sent_content_idx`` past everything handled and
        returns ``None`` when there is nothing new to send.
        """
        content_segments: list[str] = []
        pos = self._sent_content_idx

        while pos < len(current_text):
            start = current_text.find(self.tool_call_start_token, pos)
            if start == -1:
                # No (more) tool-call regions — send the tail, minus
                # any suffix that could be the beginning of the tag.
                tail = current_text[pos:]
                overlap = partial_tag_overlap(tail, self.tool_call_start_token)
                sendable = tail[: len(tail) - overlap] if overlap else tail
                if sendable:
                    content_segments.append(sendable)
                pos = len(current_text) - overlap
                break

            # Text between previous position and the tag start is content.
            if start > pos:
                content_segments.append(current_text[pos:start])

            # Skip past the tool-call region.
            end = current_text.find(self.tool_call_end_token, start)
            if end != -1:
                pos = end + len(self.tool_call_end_token)
            else:
                # Region still open — park cursor at start, stop.
                pos = start
                break

        if content_segments:
            self._sent_content_idx = pos
            return "".join(content_segments)

        # Nothing to emit, but a closed region may still have been skipped.
        if pos > self._sent_content_idx:
            self._sent_content_idx = pos
        return None

    def _extract_invoke_regions(
        self, text: str
    ) -> list[tuple[str, str, bool]]:
        """Find all invoke blocks inside the first function_calls region.

        Returns a list of ``(func_name, inner_text, is_complete)`` tuples.
        *inner_text* is everything between the invoke open tag and the close
        tag (or the end of available text for the last, potentially
        incomplete, invoke). For an incomplete invoke, a trailing fragment
        that could be the start of the invoke close tag is dropped.
        """
        results: list[tuple[str, str, bool]] = []

        fc_start = text.find(self.tool_call_start_token)
        if fc_start == -1:
            return results

        region_start = fc_start + len(self.tool_call_start_token)
        fc_end = text.find(self.tool_call_end_token, region_start)
        region = text[region_start:fc_end] if fc_end != -1 else text[region_start:]

        pos = 0
        while pos < len(region):
            inv_match = self.invoke_start_regex.search(region, pos)
            if not inv_match:
                break

            func_name = inv_match.group(1)
            body_start = inv_match.end()

            inv_end_pos = region.find(self.invoke_end_token, body_start)
            if inv_end_pos != -1:
                # Complete invoke block.
                body = region[body_start:inv_end_pos]
                results.append((func_name, body, True))
                pos = inv_end_pos + len(self.invoke_end_token)
            else:
                # Incomplete — still being generated.
                body = region[body_start:]
                overlap = partial_tag_overlap(body, self.invoke_end_token)
                if overlap:
                    body = body[:-overlap]
                results.append((func_name, body, False))
                break

        return results

    def _build_args_json_so_far(
        self,
        func_name: str,
        inner_text: str,
        is_complete: bool,
    ) -> str:
        """Build a JSON arguments string from the parameters found so far.

        Handles both fully-closed ``<|DSML|parameter>`` tags and the single
        trailing partial parameter whose value is still being streamed.

        The result is meant to grow only by appending as *inner_text*
        grows, so the caller can diff by prefix. A trailing partial value
        is therefore exposed only for ``string="true"`` parameters, whose
        final form is the same text plus a closing quote. Any other partial
        value is withheld until its tag closes, because schema conversion
        and ``json.dumps`` may re-spell it (``[1,2]`` becomes ``[1, 2]``,
        ``1.0`` becomes ``1``).

        Returns ``""`` while an open invoke has nothing to show yet, and
        ``"{}"`` for a closed invoke without parameters.
        """
        # ---- Collect all fully-closed parameters ----
        parts: list[str] = [
            self._format_param_json(func_name, param_name, string_attr, param_value)
            for param_name, string_attr, param_value in (
                self.parameter_complete_regex.findall(inner_text)
            )
        ]

        # ---- Handle a trailing partial parameter ----
        last_param_open = inner_text.rfind(self._param_open_prefix)
        last_param_close = inner_text.rfind(self.param_end_token)
        has_partial = last_param_open != -1 and (
            last_param_close == -1 or last_param_close < last_param_open
        )

        if has_partial:
            partial_text = inner_text[last_param_open:]
            header_match = self.parameter_header_regex.search(partial_text)
            # No match means the opening tag itself is still incomplete.
            if header_match:
                param_name = header_match.group(1)
                string_attr = header_match.group(2)
                partial_value = partial_text[header_match.end():]

                # Strip any characters that might be the beginning of the
                # closing tag.
                overlap = partial_tag_overlap(partial_value, self.param_end_token)
                if overlap:
                    partial_value = partial_value[:-overlap]

                if is_complete:
                    # Invoke is closed — treat whatever we have as final.
                    parts.append(
                        self._format_param_json(
                            func_name, param_name, string_attr, partial_value
                        )
                    )
                elif string_attr == "true":
                    # Stream as an open JSON string (no closing quote).
                    key_json = json.dumps(param_name, ensure_ascii=False)
                    escaped = self._json_escape_string_content(partial_value)
                    parts.append(f'{key_json}: "{escaped}')
                # else: converted value — withheld until the tag closes.

        # ---- Assemble ----
        if not parts:
            return "{}" if is_complete else ""

        joined = "{" + ", ".join(parts)
        if is_complete:
            joined += "}"
        return joined

    def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
        """Return only the characters in *args_so_far* that haven't been
        sent yet, or ``None`` if there's nothing new.

        On success the full string is recorded in
        ``streamed_args_for_tool[index]`` and
        ``prev_tool_call_arr[index]["arguments"]``.
        """
        prev = self.streamed_args_for_tool[index]
        if not args_so_far or len(args_so_far) <= len(prev):
            return None
        if not args_so_far.startswith(prev):
            # Already-sent characters cannot be retracted; slicing by
            # length here would emit text that does not line up with them.
            logger.debug(
                "Tool call %d: arguments diverged from streamed prefix; "
                "suppressing delta.",
                index,
            )
            return None

        diff = args_so_far[len(prev):]
        self.streamed_args_for_tool[index] = args_so_far
        self.prev_tool_call_arr[index]["arguments"] = args_so_far
        return diff

    def _ensure_tool_state_for(self, index: int) -> None:
        """Grow the streaming-state arrays so *index* is valid."""
        while len(self._tool_call_ids) <= index:
            self._tool_call_ids.append(self._generate_tool_call_id())
        while len(self.streamed_args_for_tool) <= index:
            self.streamed_args_for_tool.append("")
        while len(self.prev_tool_call_arr) <= index:
            self.prev_tool_call_arr.append({})

    # ------------------------------------------------------------------
    # Main streaming entry point
    # ------------------------------------------------------------------

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming output using re-parse-and-diff.

        On every call we:

        1. Re-scan *current_text* for content outside tool-call regions.
        2. Find all ``<|DSML|invoke>`` regions (complete + partial).
        3. Build JSON args for each, diff against previous, emit deltas.

        Because the entire text is re-parsed each time, the result does not
        depend on how many tokens arrived in this step.

        Returns ``None`` when there is nothing to emit for this step.
        """
        # First chunk of a new stream — reset state.
        if not previous_text:
            self._reset_streaming_state()

        # If tools aren't enabled, just forward content.
        if not self._tools_enabled(request):
            return DeltaMessage(content=delta_text) if delta_text else None

        # 1. Extract any content outside tool-call regions.
        content = self._extract_content(current_text)

        # 2. Find all invoke regions.
        regions = self._extract_invoke_regions(current_text)
        tool_call_deltas: list[DeltaToolCall] = []

        for i, (func_name, inner_text, is_complete) in enumerate(regions):
            self._ensure_tool_state_for(i)

            # Emit the tool name (once per tool call).
            if "name" not in self.prev_tool_call_arr[i]:
                self.prev_tool_call_arr[i]["name"] = func_name
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        id=self._tool_call_ids[i],
                        type="function",
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments="",
                        ),
                    )
                )

            # Build the JSON args so far and emit the diff.
            args_so_far = self._build_args_json_so_far(
                func_name, inner_text, is_complete
            )
            diff = self._compute_args_diff(i, args_so_far)
            if diff:
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        function=DeltaFunctionCall(arguments=diff),
                    )
                )

        if regions:
            self.current_tool_id = len(regions) - 1

        # 3. Return a delta if we have content or tool-call updates.
        if content or tool_call_deltas:
            return DeltaMessage(
                content=content,
                tool_calls=tool_call_deltas,
            )

        # Empty delta with token ids means EOS or closing tag — return
        # non-None so the serving framework can finalize finish_reason.
        if not delta_text and delta_token_ids and self.prev_tool_call_arr:
            return DeltaMessage(content="")

        return None