Files
vllm-deepseek-v32-mtp/deepseekv32_tool_parser.py

616 lines
24 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
DeepSeek-V3.2 Tool Call Parser — re-parse-and-diff version.
Adapted from the GLM-4 streaming fix to make the streaming path robust
against multi-token deltas produced by MTP speculative decoding.
Instead of maintaining incremental state that advances one token at a
time, the streaming path re-parses the *entire* current_text on every
call, finds all <DSMLinvoke> regions (complete and in-progress),
builds a JSON arguments string for each, and diffs against what was
previously sent. This makes the parser agnostic to how many tokens
arrive per step.
Key changes vs. the upstream buffer-until-complete parser:
1. _extract_content() handles partial tag overlaps so content text
is never swallowed or duplicated when a tag boundary lands inside
a multi-token chunk.
2. _extract_invoke_regions() finds both complete and incomplete
invoke blocks, enabling streaming of partial arguments.
3. _build_args_json_so_far() constructs the JSON arguments string
incrementally from complete + partial <DSMLparameter> tags.
4. _compute_args_diff() emits only the newly-added characters.
Drop-in replacement: same class name, same interface.
"""
import json
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
logger = init_logger(__name__)
def partial_tag_overlap(text: str, tag: str) -> int:
    """Return how many trailing characters of *text* could begin *tag*.

    The result is the length of the longest *proper* prefix of *tag*
    (at most ``len(tag) - 1`` characters) that is also a suffix of
    *text*.  For example, text ending in ``"<tool_"`` yields 6 for the
    tag ``"<tool_call>"``.  A text that ends with the whole tag is not a
    partial overlap; only shorter prefixes are considered.

    Returns 0 when no prefix of *tag* matches the end of *text*.
    """
    longest_candidate = min(len(text), len(tag) - 1)
    # Try the longest prefix first so the first hit is the maximum.
    matching_lengths = (
        size
        for size in range(longest_candidate, 0, -1)
        if text.endswith(tag[:size])
    )
    return next(matching_lengths, 0)
class DeepSeekV32ToolParser(ToolParser):
"""
Re-parse-and-diff tool parser for DeepSeek-V3.2 DSML format.
On every streaming call the parser re-parses ``current_text`` to
find ``<DSMLinvoke>`` regions, builds the JSON arguments string
for each tool call, and diffs against what was previously sent to
emit only new content. This is robust against multi-token deltas
from MTP / EAGLE speculative decoding.
Example tool call format::
<DSMLfunction_calls>
<DSMLinvoke name="get_weather">
<DSMLparameter name="location" string="true">杭州</DSMLparameter>
<DSMLparameter name="date" string="true">2024-01-16</DSMLparameter>
</DSMLinvoke>
</DSMLfunction_calls>
"""
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
    """Set up tag constants, compiled patterns and per-stream state.

    *tokenizer* and *tools* are forwarded unchanged to the base-class
    constructor.

    Raises ``ValueError`` when ``self.model_tokenizer`` is falsy after
    the base constructor has run.
    """
    super().__init__(tokenizer, tools)

    # ----- Tag constants -----
    # NOTE(review): the hosting page flagged "ambiguous Unicode
    # characters" in this file, yet these tag literals are plain ASCII
    # here.  Confirm them against the model's actual special tokens.
    self.tool_call_start_token: str = "<DSMLfunction_calls>"
    self.tool_call_end_token: str = "</DSMLfunction_calls>"
    self.invoke_end_token: str = "</DSMLinvoke>"
    self.param_end_token: str = "</DSMLparameter>"
    # Alias expected by ToolParser base / adjust_request
    self.tool_calls_start_token = self.tool_call_start_token

    # ----- Compiled regexes -----
    # The opening-tag patterns are shared between the "header only"
    # and the "complete block" variants, so spell each one out once.
    invoke_open = r'<DSMLinvoke\s+name="([^"]+)"\s*>'
    parameter_open = (
        r'<DSMLparameter\s+name="([^"]+)"\s+string="(true|false)"\s*>'
    )

    # A complete <DSMLfunction_calls>…</DSMLfunction_calls> region.
    self.tool_call_complete_regex = re.compile(
        r"<DSMLfunction_calls>(.*?)</DSMLfunction_calls>", re.DOTALL
    )
    # Opening tag of an invoke block; group 1 is the function name.
    self.invoke_start_regex = re.compile(invoke_open, re.DOTALL)
    # Complete invoke block; groups are (name, body).
    self.invoke_complete_regex = re.compile(
        invoke_open + r"(.*?)</DSMLinvoke>", re.DOTALL
    )
    # Complete parameter tag; groups are (name, string_attr, value).
    self.parameter_complete_regex = re.compile(
        parameter_open + r"(.*?)" + r"</DSMLparameter>", re.DOTALL
    )
    # Only the opening header of a parameter tag (for partial params);
    # groups are (name, string_attr).
    self.parameter_header_regex = re.compile(parameter_open, re.DOTALL)

    # ----- Streaming state (reset per request) -----
    # Index into current_text up to which content has been handled.
    self._sent_content_idx: int = 0
    # One generated call id per tool call, by tool index.
    self._tool_call_ids: list[str] = []
    # Argument JSON text already emitted, by tool index.
    self.streamed_args_for_tool: list[str] = []
    # Per-tool dict holding "name" and "arguments", by tool index.
    self.prev_tool_call_arr: list[dict[str, Any]] = []
    # Index of the most recent tool call seen; -1 before any.
    self.current_tool_id: int = -1

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )
    logger.debug("Successfully initialized %s", self.__class__.__name__)
# ------------------------------------------------------------------
# Request adjustment
# ------------------------------------------------------------------
def adjust_request(
    self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
    """Apply the base-class adjustment, then keep special tokens if needed.

    When the adjusted request carries tools and its ``tool_choice`` is
    not ``"none"``, ``skip_special_tokens`` is forced to ``False``.
    The adjusted request object is returned.
    """
    adjusted = super().adjust_request(request)
    if not adjusted.tools or adjusted.tool_choice == "none":
        return adjusted
    # Ensure DSML tokens are not stripped during decoding.
    adjusted.skip_special_tokens = False
    return adjusted
# ------------------------------------------------------------------
# Static / utility helpers
# ------------------------------------------------------------------
@staticmethod
def _tools_enabled(request: ChatCompletionRequest) -> bool:
    """Report whether tool calling is active for *request*.

    Active means the request has a truthy ``tools`` attribute and its
    ``tool_choice`` is not ``"none"``.  Missing attributes are treated
    as ``None``.  Any exception raised while inspecting the request is
    logged and reported as "not enabled".
    """
    try:
        configured_tools = getattr(request, "tools", None)
        choice = getattr(request, "tool_choice", None)
        if not configured_tools:
            return False
        return choice != "none"
    except Exception:
        logger.exception("Failed to determine if tools are enabled.")
        return False
def _generate_tool_call_id(self) -> str:
return f"call_{uuid.uuid4().hex[:24]}"
@staticmethod
def _json_escape_string_content(s: str) -> str:
"""JSON-escape a string value (without surrounding quotes)."""
if not s:
return ""
return json.dumps(s, ensure_ascii=False)[1:-1]
# ------------------------------------------------------------------
# Type conversion helpers
# ------------------------------------------------------------------
def _convert_param_value_checked(self, value: str, param_type: str) -> Any:
"""Convert a raw string value to the type indicated by *param_type*.
Raises on failure so the caller can try the next candidate type.
"""
if value.lower() == "null":
return None
param_type = param_type.lower()
if param_type in ("string", "str", "text"):
return value
elif param_type in ("integer", "int"):
return int(value)
elif param_type in ("number", "float"):
val = float(value)
return val if val != int(val) else int(val)
elif param_type in ("boolean", "bool"):
normed = value.strip().lower()
if normed not in ("false", "0", "true", "1"):
raise ValueError(f"Invalid boolean value: {value!r}")
return normed in ("true", "1")
elif param_type in ("object", "array"):
return json.loads(value)
else:
return json.loads(value)
def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
"""Try each candidate type in turn; fall back to the raw string."""
if not isinstance(param_type, list):
param_type = [param_type]
for current_type in param_type:
try:
return self._convert_param_value_checked(value, current_type)
except Exception:
continue
return value
def _get_param_schema_type(
self, func_name: str, param_name: str
) -> str | list[str]:
"""Look up the JSON-schema type for a parameter, defaulting to
``"string"``."""
if self.tools:
for tool in self.tools:
if (
hasattr(tool, "function")
and tool.function.name == func_name
and hasattr(tool.function, "parameters")
):
schema = tool.function.parameters
if isinstance(schema, dict) and "properties" in schema:
prop = schema["properties"].get(param_name, {})
if isinstance(prop, dict):
return prop.get("type", "string")
break
return "string"
def _convert_with_schema(
self, func_name: str, param_name: str, value: str
) -> Any:
"""Convert *value* using the tool schema for *func_name*.*param_name*."""
param_type = self._get_param_schema_type(func_name, param_name)
return self._convert_param_value(value, param_type)
def _is_string_type(self, func_name: str, param_name: str) -> bool:
"""Return True if the schema says this parameter is a string."""
ptype = self._get_param_schema_type(func_name, param_name)
if isinstance(ptype, list):
return "string" in ptype
return ptype in ("string", "str", "text")
# ------------------------------------------------------------------
# Non-streaming extraction (unchanged logic, shared helpers)
# ------------------------------------------------------------------
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract tool calls from complete model output (non-streaming).

    Every complete function-calls region is scanned for complete invoke
    blocks; each block's parameters are converted through the tool
    schema and serialised as the call's JSON ``arguments``.  The
    ``string`` attribute of a parameter tag is not consulted here.

    Returns the calls plus the text preceding the first function-calls
    tag as ``content`` (``None`` when the tag is at position 0).  When
    the output has no start tag, yields no calls, or parsing raises, the
    whole *model_output* is returned as plain content with
    ``tools_called=False``.  *request* is not read.
    """

    def as_plain_content() -> ExtractedToolCallInformation:
        # Shared "no tool calls" result: hand the text back untouched.
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    if self.tool_call_start_token not in model_output:
        return as_plain_content()

    try:
        tool_calls: list[ToolCall] = []
        for fc_block in self.tool_call_complete_regex.findall(model_output):
            for invoke_name, invoke_body in self.invoke_complete_regex.findall(
                fc_block
            ):
                # Raw name -> value mapping; a repeated name keeps its
                # last value.
                raw_params: dict[str, str] = {
                    pname: pval
                    for pname, _str_attr, pval in (
                        self.parameter_complete_regex.findall(invoke_body)
                    )
                }
                # Convert types via schema.
                converted: dict[str, Any] = {
                    pname: self._convert_with_schema(invoke_name, pname, pval)
                    for pname, pval in raw_params.items()
                }
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=invoke_name,
                            arguments=json.dumps(
                                converted, ensure_ascii=False
                            ),
                        ),
                    )
                )

        if not tool_calls:
            return as_plain_content()

        prefix_len = model_output.find(self.tool_call_start_token)
        leading_text = model_output[:prefix_len] if prefix_len > 0 else None
        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=tool_calls, content=leading_text
        )
    except Exception:
        logger.exception("Error extracting tool calls from complete output")
        return as_plain_content()
# ------------------------------------------------------------------
# Streaming helpers — re-parse-and-diff
# ------------------------------------------------------------------
def _reset_streaming_state(self) -> None:
self._sent_content_idx = 0
self._tool_call_ids.clear()
self.streamed_args_for_tool.clear()
self.prev_tool_call_arr.clear()
self.current_tool_id = -1
def _extract_content(self, current_text: str) -> str | None:
"""Return any non-tool-call text that hasn't been sent yet.
Walks *current_text* from ``_sent_content_idx``, collecting text
outside ``<DSMLfunction_calls>`` regions. Uses
``partial_tag_overlap`` to avoid emitting bytes that might turn
out to be the start of the function-calls tag once the next
chunk arrives.
"""
content_segments: list[str] = []
pos = self._sent_content_idx
while pos < len(current_text):
start = current_text.find(self.tool_call_start_token, pos)
if start == -1:
# No (more) tool-call regions — send the tail, minus
# any suffix that could be the beginning of the tag.
tail = current_text[pos:]
overlap = partial_tag_overlap(tail, self.tool_call_start_token)
sendable = tail[: len(tail) - overlap] if overlap else tail
if sendable:
content_segments.append(sendable)
pos = len(current_text) - overlap
break
# Text between previous position and the tag start is content.
if start > pos:
content_segments.append(current_text[pos:start])
# Skip past the tool-call region.
end = current_text.find(self.tool_call_end_token, start)
if end != -1:
pos = end + len(self.tool_call_end_token)
else:
# Region still open — park cursor at start, stop.
pos = start
break
if content_segments:
self._sent_content_idx = pos
return "".join(content_segments)
if pos > self._sent_content_idx:
self._sent_content_idx = pos
return None
def _extract_invoke_regions(
self, text: str
) -> list[tuple[str, str, bool]]:
"""Find all invoke blocks inside the function_calls region.
Returns a list of ``(func_name, inner_text, is_complete)``
tuples. *inner_text* is everything between the invoke open
tag and the close tag (or the end of available text for the
last, potentially incomplete, invoke).
"""
results: list[tuple[str, str, bool]] = []
fc_start = text.find(self.tool_call_start_token)
if fc_start == -1:
return results
region_start = fc_start + len(self.tool_call_start_token)
fc_end = text.find(self.tool_call_end_token, region_start)
region = text[region_start:fc_end] if fc_end != -1 else text[region_start:]
pos = 0
while pos < len(region):
inv_match = self.invoke_start_regex.search(region, pos)
if not inv_match:
break
func_name = inv_match.group(1)
body_start = inv_match.end()
inv_end_pos = region.find(self.invoke_end_token, body_start)
if inv_end_pos != -1:
# Complete invoke block.
body = region[body_start:inv_end_pos]
results.append((func_name, body, True))
pos = inv_end_pos + len(self.invoke_end_token)
else:
# Incomplete — still being generated.
body = region[body_start:]
overlap = partial_tag_overlap(body, self.invoke_end_token)
if overlap:
body = body[:-overlap]
results.append((func_name, body, False))
break
return results
def _build_args_json_so_far(
self,
func_name: str,
inner_text: str,
is_complete: bool,
) -> str:
"""Build a JSON arguments string from the parameters found so far.
Handles both fully-closed ``<DSMLparameter>`` tags and the
single trailing partial parameter whose value is still being
streamed.
"""
# ---- Collect all fully-closed parameters ----
complete_params = self.parameter_complete_regex.findall(inner_text)
parts: list[str] = []
for param_name, string_attr, param_value in complete_params:
key_json = json.dumps(param_name, ensure_ascii=False)
if string_attr == "true":
val_json = json.dumps(param_value, ensure_ascii=False)
else:
converted = self._convert_with_schema(
func_name, param_name, param_value
)
val_json = json.dumps(converted, ensure_ascii=False)
parts.append(f"{key_json}: {val_json}")
# ---- Handle a trailing partial parameter ----
last_param_open = inner_text.rfind("<DSMLparameter")
last_param_close = inner_text.rfind(self.param_end_token)
has_partial = last_param_open != -1 and (
last_param_close == -1 or last_param_close < last_param_open
)
if has_partial:
partial_text = inner_text[last_param_open:]
header_match = self.parameter_header_regex.search(partial_text)
if header_match:
param_name = header_match.group(1)
string_attr = header_match.group(2)
partial_value = partial_text[header_match.end():]
# Strip any bytes that might be the beginning of the
# closing </DSMLparameter> tag.
overlap = partial_tag_overlap(
partial_value, self.param_end_token
)
if overlap:
partial_value = partial_value[:-overlap]
key_json = json.dumps(param_name, ensure_ascii=False)
if is_complete:
# Invoke is closed — treat whatever we have as final.
if string_attr == "true":
val_json = json.dumps(
partial_value, ensure_ascii=False
)
else:
converted = self._convert_with_schema(
func_name, param_name, partial_value
)
val_json = json.dumps(converted, ensure_ascii=False)
parts.append(f"{key_json}: {val_json}")
elif string_attr == "true" or self._is_string_type(
func_name, param_name
):
# Stream as an open JSON string (no closing quote).
escaped = self._json_escape_string_content(partial_value)
parts.append(f'{key_json}: "{escaped}')
else:
# Non-string — emit raw partial value.
parts.append(f"{key_json}: {partial_value}")
# ---- Assemble ----
if not parts:
return "{}" if is_complete else ""
joined = "{" + ", ".join(parts)
if is_complete:
joined += "}"
return joined
def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
"""Return only the characters in *args_so_far* that haven't been
sent yet, or ``None`` if there's nothing new."""
prev = self.streamed_args_for_tool[index]
if not args_so_far or len(args_so_far) <= len(prev):
return None
diff = args_so_far[len(prev):]
self.streamed_args_for_tool[index] = args_so_far
self.prev_tool_call_arr[index]["arguments"] = args_so_far
return diff
def _ensure_tool_state_for(self, index: int) -> None:
"""Grow the streaming-state arrays so *index* is valid."""
while len(self._tool_call_ids) <= index:
self._tool_call_ids.append(self._generate_tool_call_id())
while len(self.streamed_args_for_tool) <= index:
self.streamed_args_for_tool.append("")
while len(self.prev_tool_call_arr) <= index:
self.prev_tool_call_arr.append({})
# ------------------------------------------------------------------
# Main streaming entry point
# ------------------------------------------------------------------
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Produce the next streaming delta by re-parsing *current_text*.

    Steps performed on every call:

    1. Streaming state is reset when *previous_text* is empty.
    2. If tools are not enabled for *request*, *delta_text* is forwarded
       as plain content (``None`` when it is empty).
    3. Unsent content outside tool-call regions is collected.
    4. For each invoke region found, the tool name is emitted once and
       then whatever part of its JSON arguments has not been sent yet.

    Because the whole text is re-parsed, the outcome does not depend on
    how many tokens arrived in this step.  *previous_token_ids* and
    *current_token_ids* are not read.

    Returns a ``DeltaMessage`` when there is content or a tool-call
    update; an empty-content ``DeltaMessage`` when *delta_text* is empty
    but token ids arrived and a tool call has been seen; else ``None``.
    """
    # First chunk of a new stream: start from clean state.
    if not previous_text:
        self._reset_streaming_state()

    # Without tools there is nothing to parse; pass the text through.
    if not self._tools_enabled(request):
        if delta_text:
            return DeltaMessage(content=delta_text)
        return None

    content = self._extract_content(current_text)
    regions = self._extract_invoke_regions(current_text)

    deltas: list[DeltaToolCall] = []
    for index, (func_name, inner_text, is_complete) in enumerate(regions):
        self._ensure_tool_state_for(index)
        tool_state = self.prev_tool_call_arr[index]

        # The name (with id and type) is announced once per tool call.
        if "name" not in tool_state:
            tool_state["name"] = func_name
            announcement = DeltaToolCall(
                index=index,
                id=self._tool_call_ids[index],
                type="function",
                function=DeltaFunctionCall(
                    name=func_name,
                    arguments="",
                ),
            )
            deltas.append(announcement)

        # Rebuild the arguments text and send only what is new.
        new_args = self._compute_args_diff(
            index,
            self._build_args_json_so_far(func_name, inner_text, is_complete),
        )
        if new_args:
            deltas.append(
                DeltaToolCall(
                    index=index,
                    function=DeltaFunctionCall(arguments=new_args),
                )
            )

    if regions:
        self.current_tool_id = len(regions) - 1

    if content or deltas:
        return DeltaMessage(
            content=content,
            tool_calls=deltas,
        )

    # Empty delta with token ids means EOS or closing tag: return
    # non-None so the serving framework can finalize finish_reason.
    if not delta_text and delta_token_ids and self.prev_tool_call_arr:
        return DeltaMessage(content="")
    return None