Files
vllm-deepseek-v32-mtp/deepseekv32_tool_parser.py

616 lines
24 KiB
Python
Raw Normal View History

2026-04-13 23:42:31 +00:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DeepSeek-V3.2 Tool Call Parser — re-parse-and-diff version.
Adapted from the GLM-4 streaming fix to make the streaming path robust
against multi-token deltas produced by MTP speculative decoding.
Instead of maintaining incremental state that advances one token at a
time, the streaming path re-parses the *entire* current_text on every
call, finds all <DSMLinvoke> regions (complete and in-progress),
builds a JSON arguments string for each, and diffs against what was
previously sent. This makes the parser agnostic to how many tokens
arrive per step.
Key changes vs. the upstream buffer-until-complete parser:
1. _extract_content() handles partial tag overlaps so content text
is never swallowed or duplicated when a tag boundary lands inside
a multi-token chunk.
2. _extract_invoke_regions() finds both complete and incomplete
invoke blocks, enabling streaming of partial arguments.
3. _build_args_json_so_far() constructs the JSON arguments string
incrementally from complete + partial <DSMLparameter> tags.
4. _compute_args_diff() emits only the newly-added characters.
Drop-in replacement: same class name, same interface.
"""
import json
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
# Module-level logger from vLLM's ``init_logger``, named after this module;
# used below for debug/warning/exception messages.
logger = init_logger(__name__)
def partial_tag_overlap(text: str, tag: str) -> int:
    """Return how many trailing characters of *text* could begin *tag*.

    The result is the length of the longest prefix of *tag* that *text*
    ends with, considering only prefixes strictly shorter than *tag*;
    0 means no such overlap exists.

    Example: for a ``text`` ending in ``"<tool_"`` and
    ``tag == "<tool_call>"`` the result is 6.
    """
    longest = min(len(tag) - 1, len(text))
    return next(
        (size for size in range(longest, 0, -1) if text.endswith(tag[:size])),
        0,
    )
2026-04-13 23:42:31 +00:00
class DeepSeekV32ToolParser(ToolParser):
    """
    Re-parse-and-diff tool parser for DeepSeek-V3.2 DSML format.

    On every streaming call the parser re-parses ``current_text`` to
    find ``<DSMLinvoke>`` regions, builds the JSON arguments string
    for each tool call, and diffs against what was previously sent to
    emit only new content. This is robust against multi-token deltas
    from MTP / EAGLE speculative decoding.

    Parameter values follow the DSML ``string`` attribute in both the
    streaming and the non-streaming path: ``string="true"`` values are
    kept verbatim as JSON strings, any other value is converted using
    the tool's JSON schema (see ``_convert_with_schema``).

    Example tool call format::

        <DSMLfunction_calls>
        <DSMLinvoke name="get_weather">
        <DSMLparameter name="location" string="true">杭州</DSMLparameter>
        <DSMLparameter name="date" string="true">2024-01-16</DSMLparameter>
        </DSMLinvoke>
        </DSMLfunction_calls>
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        super().__init__(tokenizer, tools)

        # ----- Tag constants -----
        # NOTE(review): upstream DeepSeek-V3.2 DSML tags may be spelled with
        # full-width bar delimiters (e.g. "<｜DSML｜function_calls>"); confirm
        # these literals match the text the tokenizer actually decodes.
        self.tool_call_start_token: str = "<DSMLfunction_calls>"
        self.tool_call_end_token: str = "</DSMLfunction_calls>"
        self.invoke_end_token: str = "</DSMLinvoke>"
        self.param_end_token: str = "</DSMLparameter>"
        # Alias expected by ToolParser base / adjust_request
        self.tool_calls_start_token = self.tool_call_start_token

        # ----- Compiled regexes -----
        # Matches a complete <DSMLfunction_calls>…</DSMLfunction_calls>
        self.tool_call_complete_regex = re.compile(
            r"<DSMLfunction_calls>(.*?)</DSMLfunction_calls>", re.DOTALL
        )
        # Opening tag of an invoke block — captures the function name.
        self.invoke_start_regex = re.compile(
            r'<DSMLinvoke\s+name="([^"]+)"\s*>', re.DOTALL
        )
        # Complete invoke block.
        self.invoke_complete_regex = re.compile(
            r'<DSMLinvoke\s+name="([^"]+)"\s*>(.*?)</DSMLinvoke>',
            re.DOTALL,
        )
        # Complete parameter tag — captures (name, string_attr, value).
        self.parameter_complete_regex = re.compile(
            r'<DSMLparameter\s+name="([^"]+)"\s+string="(true|false)"\s*>'
            r"(.*?)"
            r"</DSMLparameter>",
            re.DOTALL,
        )
        # Just the opening header of a parameter tag (for partial params).
        self.parameter_header_regex = re.compile(
            r'<DSMLparameter\s+name="([^"]+)"\s+string="(true|false)"\s*>',
            re.DOTALL,
        )

        # ----- Streaming state (reset per request) -----
        # Index into current_text up to which content has been emitted.
        self._sent_content_idx: int = 0
        # One generated id per tool call, indexed by tool-call position.
        self._tool_call_ids: list[str] = []
        # Arguments text already sent to the client, per tool call.
        self.streamed_args_for_tool: list[str] = []
        # Per tool call: {"name": ..., "arguments": <JSON text sent so far>}.
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        self.current_tool_id: int = -1

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )
        logger.debug("Successfully initialized %s", self.__class__.__name__)

    # ------------------------------------------------------------------
    # Request adjustment
    # ------------------------------------------------------------------
    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Apply base adjustments, then keep special tokens when tools are on.

        The DSML tags must survive detokenization for this parser to see
        them, so ``skip_special_tokens`` is disabled whenever tools are
        present and ``tool_choice`` is not ``"none"``.
        """
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # Ensure DSML tokens are not stripped during decoding.
            request.skip_special_tokens = False
        return request

    # ------------------------------------------------------------------
    # Static / utility helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _tools_enabled(request: ChatCompletionRequest) -> bool:
        """Check whether tool calling is active for this request."""
        try:
            tools = getattr(request, "tools", None)
            tool_choice = getattr(request, "tool_choice", None)
            return bool(tools) and tool_choice != "none"
        except Exception:
            logger.exception("Failed to determine if tools are enabled.")
            return False

    def _generate_tool_call_id(self) -> str:
        """Return a fresh ``call_``-prefixed id with 24 random hex chars."""
        return f"call_{uuid.uuid4().hex[:24]}"

    @staticmethod
    def _json_escape_string_content(s: str) -> str:
        """JSON-escape a string value (without surrounding quotes)."""
        if not s:
            return ""
        return json.dumps(s, ensure_ascii=False)[1:-1]

    @staticmethod
    def _json_member(key: str, value: Any) -> str:
        """Render one ``"key": value`` JSON object member (no braces)."""
        key_json = json.dumps(key, ensure_ascii=False)
        value_json = json.dumps(value, ensure_ascii=False)
        return f"{key_json}: {value_json}"

    # ------------------------------------------------------------------
    # Type conversion helpers
    # ------------------------------------------------------------------
    def _convert_param_value_checked(self, value: str, param_type: str) -> Any:
        """Convert a raw string value to the type indicated by *param_type*.

        ``"null"`` (any case) becomes ``None`` regardless of type. Unknown
        types, ``"object"`` and ``"array"`` are parsed as JSON.

        Raises on failure so the caller can try the next candidate type.
        """
        if value.lower() == "null":
            return None
        param_type = param_type.lower()
        if param_type in ("string", "str", "text"):
            return value
        if param_type in ("integer", "int"):
            return int(value)
        if param_type in ("number", "float"):
            val = float(value)
            # Integral numbers are returned as int so they serialize as 3,
            # not 3.0.
            return val if val != int(val) else int(val)
        if param_type in ("boolean", "bool"):
            normed = value.strip().lower()
            if normed not in ("false", "0", "true", "1"):
                raise ValueError(f"Invalid boolean value: {value!r}")
            return normed in ("true", "1")
        # "object", "array" and any other type: parse as JSON.
        return json.loads(value)

    def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
        """Try each candidate type in turn; fall back to the raw string."""
        if not isinstance(param_type, list):
            param_type = [param_type]
        for current_type in param_type:
            try:
                return self._convert_param_value_checked(value, current_type)
            except Exception:
                continue
        return value

    def _get_param_schema_type(
        self, func_name: str, param_name: str
    ) -> str | list[str]:
        """Look up the JSON-schema type for a parameter, defaulting to
        ``"string"``."""
        if self.tools:
            for tool in self.tools:
                if (
                    hasattr(tool, "function")
                    and tool.function.name == func_name
                    and hasattr(tool.function, "parameters")
                ):
                    schema = tool.function.parameters
                    if isinstance(schema, dict) and "properties" in schema:
                        prop = schema["properties"].get(param_name, {})
                        if isinstance(prop, dict):
                            return prop.get("type", "string")
                    break
        return "string"

    def _convert_with_schema(
        self, func_name: str, param_name: str, value: str
    ) -> Any:
        """Convert *value* using the tool schema for *func_name*.*param_name*."""
        param_type = self._get_param_schema_type(func_name, param_name)
        return self._convert_param_value(value, param_type)

    def _is_string_type(self, func_name: str, param_name: str) -> bool:
        """Return True if the schema says this parameter is a string."""
        ptype = self._get_param_schema_type(func_name, param_name)
        if isinstance(ptype, list):
            return "string" in ptype
        return ptype in ("string", "str", "text")

    def _resolve_param_value(
        self, func_name: str, param_name: str, string_attr: str, raw_value: str
    ) -> Any:
        """Turn a raw DSML parameter body into its final Python value.

        ``string="true"`` values are taken verbatim; anything else goes
        through ``_convert_with_schema``. Shared by the streaming and
        non-streaming paths so both yield identical arguments for the
        same model output.
        """
        if string_attr == "true":
            return raw_value
        return self._convert_with_schema(func_name, param_name, raw_value)

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from complete model output (non-streaming).

        Only fully closed ``<DSMLfunction_calls>`` blocks, ``<DSMLinvoke>``
        blocks and ``<DSMLparameter>`` tags are considered. Content is the
        text before the first function-calls tag (``None`` if empty). Any
        parsing error falls back to returning the whole output as content.
        """
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
        try:
            tool_calls: list[ToolCall] = []
            for fc_block in self.tool_call_complete_regex.findall(model_output):
                for invoke_name, invoke_body in self.invoke_complete_regex.findall(
                    fc_block
                ):
                    # A repeated parameter name keeps its first position but
                    # takes the last value (plain dict assignment).
                    arguments: dict[str, Any] = {}
                    for pname, string_attr, pval in (
                        self.parameter_complete_regex.findall(invoke_body)
                    ):
                        arguments[pname] = self._resolve_param_value(
                            invoke_name, pname, string_attr, pval
                        )
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=invoke_name,
                                arguments=json.dumps(
                                    arguments, ensure_ascii=False
                                ),
                            ),
                        )
                    )

            if not tool_calls:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            first_idx = model_output.find(self.tool_call_start_token)
            content = model_output[:first_idx] if first_idx > 0 else None
            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=content
            )
        except Exception:
            logger.exception("Error extracting tool calls from complete output")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    # ------------------------------------------------------------------
    # Streaming helpers — re-parse-and-diff
    # ------------------------------------------------------------------
    def _reset_streaming_state(self) -> None:
        """Clear all per-request streaming state."""
        self._sent_content_idx = 0
        self._tool_call_ids.clear()
        self.streamed_args_for_tool.clear()
        self.prev_tool_call_arr.clear()
        self.current_tool_id = -1

    def _extract_content(self, current_text: str) -> str | None:
        """Return any non-tool-call text that hasn't been sent yet.

        Walks *current_text* from ``_sent_content_idx``, collecting text
        outside ``<DSMLfunction_calls>`` regions. Uses
        ``partial_tag_overlap`` to avoid emitting bytes that might turn
        out to be the start of the function-calls tag once the next
        chunk arrives. Advances ``_sent_content_idx`` past everything
        consumed (content and closed tool-call regions).
        """
        segments: list[str] = []
        pos = self._sent_content_idx
        while pos < len(current_text):
            start = current_text.find(self.tool_call_start_token, pos)
            if start == -1:
                # No (more) tool-call regions — send the tail, minus any
                # suffix that could be the beginning of the tag.
                tail = current_text[pos:]
                overlap = partial_tag_overlap(tail, self.tool_call_start_token)
                sendable = tail[: len(tail) - overlap]
                if sendable:
                    segments.append(sendable)
                pos = len(current_text) - overlap
                break

            # Text between the cursor and the tag start is content.
            if start > pos:
                segments.append(current_text[pos:start])

            end = current_text.find(self.tool_call_end_token, start)
            if end == -1:
                # Region still open — park cursor at its start and stop.
                pos = start
                break
            # Skip past the closed tool-call region.
            pos = end + len(self.tool_call_end_token)

        if segments:
            self._sent_content_idx = pos
            return "".join(segments)
        if pos > self._sent_content_idx:
            self._sent_content_idx = pos
        return None

    def _split_invokes(self, region: str) -> list[tuple[str, str, bool]]:
        """Split one function-calls region body into invoke tuples.

        Returns ``(func_name, inner_text, is_complete)`` per invoke whose
        opening tag is fully present. An unclosed invoke is returned last
        with ``is_complete=False`` and any trailing bytes that might begin
        ``</DSMLinvoke>`` trimmed off.
        """
        found: list[tuple[str, str, bool]] = []
        pos = 0
        while pos < len(region):
            inv_match = self.invoke_start_regex.search(region, pos)
            if not inv_match:
                break
            func_name = inv_match.group(1)
            body_start = inv_match.end()
            inv_end = region.find(self.invoke_end_token, body_start)
            if inv_end == -1:
                # Incomplete — still being generated.
                body = region[body_start:]
                overlap = partial_tag_overlap(body, self.invoke_end_token)
                if overlap:
                    body = body[:-overlap]
                found.append((func_name, body, False))
                break
            found.append((func_name, region[body_start:inv_end], True))
            pos = inv_end + len(self.invoke_end_token)
        return found

    def _extract_invoke_regions(
        self, text: str
    ) -> list[tuple[str, str, bool]]:
        """Find all invoke blocks inside every function_calls region.

        Returns a list of ``(func_name, inner_text, is_complete)``
        tuples in order of appearance across all ``<DSMLfunction_calls>``
        blocks, so the list index is a stable tool-call index while
        *text* grows. *inner_text* is everything between the invoke open
        tag and the close tag (or the end of available text for the last,
        potentially incomplete, invoke).
        """
        results: list[tuple[str, str, bool]] = []
        search_from = 0
        while True:
            fc_start = text.find(self.tool_call_start_token, search_from)
            if fc_start == -1:
                break
            region_start = fc_start + len(self.tool_call_start_token)
            fc_end = text.find(self.tool_call_end_token, region_start)
            if fc_end == -1:
                # Still-open block: parse what is available and stop.
                results.extend(self._split_invokes(text[region_start:]))
                break
            results.extend(self._split_invokes(text[region_start:fc_end]))
            search_from = fc_end + len(self.tool_call_end_token)
        return results

    def _build_args_json_so_far(
        self,
        func_name: str,
        inner_text: str,
        is_complete: bool,
    ) -> str:
        """Build a JSON arguments string from the parameters found so far.

        Closed ``<DSMLparameter>`` tags are rendered with their final
        values. A trailing still-open parameter is rendered only when its
        final JSON text is guaranteed to start with what is emitted now:

        * if the invoke is already closed, its value is treated as final;
        * a ``string="true"`` value is streamed as an open JSON string
          (no closing quote) — its final value is that same string, so
          later calls only append;
        * any other value is withheld until its closing tag arrives,
          because schema conversion can rewrite its text (``1.0`` -> ``1``,
          ``null`` -> ``null`` instead of a string, objects re-serialized).

        This keeps each result a prefix-extension of the previous one,
        which ``_compute_args_diff`` relies on. Returns ``""`` while an
        open invoke has nothing renderable yet, ``"{}"`` for a closed
        invoke without parameters.
        """
        parts: list[str] = [
            self._json_member(
                name,
                self._resolve_param_value(func_name, name, string_attr, value),
            )
            for name, string_attr, value in self.parameter_complete_regex.findall(
                inner_text
            )
        ]

        # A parameter is still open when its header comes after the last
        # closing tag (rfind gives -1 when there is no closing tag at all).
        last_open = inner_text.rfind("<DSMLparameter")
        last_close = inner_text.rfind(self.param_end_token)
        if last_open != -1 and last_close < last_open:
            partial_text = inner_text[last_open:]
            header = self.parameter_header_regex.search(partial_text)
            if header:
                name, string_attr = header.group(1), header.group(2)
                value = partial_text[header.end():]
                # Hold back bytes that might start </DSMLparameter>.
                overlap = partial_tag_overlap(value, self.param_end_token)
                if overlap:
                    value = value[:-overlap]
                if is_complete:
                    # Invoke is closed — treat whatever we have as final.
                    parts.append(
                        self._json_member(
                            name,
                            self._resolve_param_value(
                                func_name, name, string_attr, value
                            ),
                        )
                    )
                elif string_attr == "true":
                    key_json = json.dumps(name, ensure_ascii=False)
                    escaped = self._json_escape_string_content(value)
                    parts.append(f'{key_json}: "{escaped}')

        if not parts:
            return "{}" if is_complete else ""
        joined = "{" + ", ".join(parts)
        return joined + "}" if is_complete else joined

    def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
        """Return only the characters in *args_so_far* that haven't been
        sent yet, or ``None`` if there's nothing new.

        Text already streamed cannot be retracted, so if *args_so_far* is
        not an extension of what was sent, nothing is emitted and a
        warning is logged instead of sending a mismatched suffix.
        """
        prev = self.streamed_args_for_tool[index]
        if not args_so_far or args_so_far == prev:
            return None
        if not args_so_far.startswith(prev):
            logger.warning(
                "Arguments for tool call %d diverged from the already "
                "streamed prefix (sent %d chars, rebuilt %d chars); "
                "suppressing delta.",
                index,
                len(prev),
                len(args_so_far),
            )
            return None
        diff = args_so_far[len(prev):]
        self.streamed_args_for_tool[index] = args_so_far
        # NOTE(review): stored as a JSON *string*; confirm the serving layer
        # does not json.dumps() it again when checking for unstreamed args.
        self.prev_tool_call_arr[index]["arguments"] = args_so_far
        return diff

    def _ensure_tool_state_for(self, index: int) -> None:
        """Grow the streaming-state arrays so *index* is valid."""
        while len(self._tool_call_ids) <= index:
            self._tool_call_ids.append(self._generate_tool_call_id())
        while len(self.streamed_args_for_tool) <= index:
            self.streamed_args_for_tool.append("")
        while len(self.prev_tool_call_arr) <= index:
            self.prev_tool_call_arr.append({})

    # ------------------------------------------------------------------
    # Main streaming entry point
    # ------------------------------------------------------------------
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming output using re-parse-and-diff.

        On every call we:
        1. Re-scan *current_text* for content outside tool-call regions.
        2. Find all ``<DSMLinvoke>`` regions (complete + partial).
        3. Build JSON args for each, diff against previous, emit deltas.

        Each tool-call index appears at most once per returned message:
        its first entry carries id, type and name together with any
        argument text available in the same step.

        Because the entire text is re-parsed each time, the result is
        correct regardless of how many tokens arrived in this step.
        """
        # First chunk of a new stream — reset state.
        if not previous_text:
            self._reset_streaming_state()

        # If tools aren't enabled, just forward content.
        if not self._tools_enabled(request):
            return DeltaMessage(content=delta_text) if delta_text else None

        # 1. Extract any content outside tool-call regions.
        content = self._extract_content(current_text)

        # 2. Find all invoke regions and diff their arguments.
        regions = self._extract_invoke_regions(current_text)
        tool_call_deltas: list[DeltaToolCall] = []
        for index, (func_name, inner_text, is_complete) in enumerate(regions):
            self._ensure_tool_state_for(index)
            state = self.prev_tool_call_arr[index]
            first_sighting = "name" not in state
            if first_sighting:
                state["name"] = func_name

            args_so_far = self._build_args_json_so_far(
                func_name, inner_text, is_complete
            )
            diff = self._compute_args_diff(index, args_so_far)

            if first_sighting:
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=index,
                        id=self._tool_call_ids[index],
                        type="function",
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments=diff or "",
                        ),
                    )
                )
            elif diff:
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=index,
                        function=DeltaFunctionCall(arguments=diff),
                    )
                )

        if regions:
            self.current_tool_id = len(regions) - 1

        # 3. Return a delta if we have content or tool-call updates.
        if content or tool_call_deltas:
            return DeltaMessage(
                content=content,
                tool_calls=tool_call_deltas,
            )

        # Empty delta with token ids means EOS or closing tag — return
        # non-None so the serving framework can finalize finish_reason.
        if not delta_text and delta_token_ids and self.prev_tool_call_arr:
            return DeltaMessage(content="")
        return None