616 lines
24 KiB
Python
616 lines
24 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
"""
|
||
DeepSeek-V3.2 Tool Call Parser — re-parse-and-diff version.
|
||
|
||
Adapted from the GLM-4 streaming fix to make the streaming path robust
|
||
against multi-token deltas produced by MTP speculative decoding.
|
||
|
||
Instead of maintaining incremental state that advances one token at a
|
||
time, the streaming path re-parses the *entire* current_text on every
|
||
call, finds all <|DSML|invoke> regions (complete and in-progress),
|
||
builds a JSON arguments string for each, and diffs against what was
|
||
previously sent. This makes the parser agnostic to how many tokens
|
||
arrive per step.
|
||
|
||
Key changes vs. the upstream buffer-until-complete parser:
|
||
1. _extract_content() handles partial tag overlaps so content text
|
||
is never swallowed or duplicated when a tag boundary lands inside
|
||
a multi-token chunk.
|
||
2. _extract_invoke_regions() finds both complete and incomplete
|
||
invoke blocks, enabling streaming of partial arguments.
|
||
3. _build_args_json_so_far() constructs the JSON arguments string
|
||
incrementally from complete + partial <|DSML|parameter> tags.
|
||
4. _compute_args_diff() emits only the newly-added characters.
|
||
|
||
Drop-in replacement: same class name, same interface.
|
||
"""
|
||
|
||
import json
|
||
import uuid
|
||
from collections.abc import Sequence
|
||
from typing import Any
|
||
|
||
import regex as re
|
||
|
||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||
ChatCompletionRequest,
|
||
)
|
||
from vllm.entrypoints.openai.engine.protocol import (
|
||
DeltaFunctionCall,
|
||
DeltaMessage,
|
||
DeltaToolCall,
|
||
ExtractedToolCallInformation,
|
||
FunctionCall,
|
||
ToolCall,
|
||
)
|
||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||
from vllm.logger import init_logger
|
||
from vllm.tokenizers import TokenizerLike
|
||
from vllm.tool_parsers.abstract_tool_parser import (
|
||
Tool,
|
||
ToolParser,
|
||
)
|
||
logger = init_logger(__name__)
|
||
|
||
|
||
def partial_tag_overlap(text: str, tag: str) -> int:
    """Return how many leading characters of *tag* end *text*.

    Only proper prefixes of *tag* are considered, so a *text* that ends
    with the whole tag still yields 0 unless a shorter prefix also
    matches. The longest matching prefix wins.

    E.g. text ending in ``"<tool_"`` returns 6 when tag is ``"<tool_call>"``.
    Returns 0 when there is no overlap.
    """
    # A prefix can be neither longer than the text nor the full tag.
    longest = min(len(text), len(tag) - 1)
    candidates = (
        size for size in range(longest, 0, -1) if text.endswith(tag[:size])
    )
    return next(candidates, 0)
|
||
|
||
|
||
class DeepSeekV32ToolParser(ToolParser):
|
||
"""
|
||
Re-parse-and-diff tool parser for DeepSeek-V3.2 DSML format.
|
||
|
||
On every streaming call the parser re-parses ``current_text`` to
|
||
find ``<|DSML|invoke>`` regions, builds the JSON arguments string
|
||
for each tool call, and diffs against what was previously sent to
|
||
emit only new content. This is robust against multi-token deltas
|
||
from MTP / EAGLE speculative decoding.
|
||
|
||
Example tool call format::
|
||
|
||
<|DSML|function_calls>
|
||
<|DSML|invoke name="get_weather">
|
||
<|DSML|parameter name="location" string="true">杭州</|DSML|parameter>
|
||
<|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
|
||
</|DSML|invoke>
|
||
</|DSML|function_calls>
|
||
"""
|
||
|
||
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
    """Set up DSML tag constants, compiled patterns and streaming state.

    Args:
        tokenizer: Tokenizer forwarded to the ``ToolParser`` initializer.
        tools: Optional tool definitions forwarded to the base class.

    Raises:
        ValueError: If ``self.model_tokenizer`` is falsy once the base
            initializer has run.
    """
    super().__init__(tokenizer, tools)

    # ----- Tag constants -----
    self.tool_call_start_token: str = "<|DSML|function_calls>"
    self.tool_call_end_token: str = "</|DSML|function_calls>"
    self.invoke_end_token: str = "</|DSML|invoke>"
    self.param_end_token: str = "</|DSML|parameter>"

    # Alias expected by ToolParser base / adjust_request
    self.tool_calls_start_token = self.tool_call_start_token

    # ----- Compiled regexes -----
    # Every literal tag fragment goes through re.escape() so the bar
    # characters inside the DSML markers are matched literally. Written
    # unescaped they act as regex alternation and the patterns match
    # fragments such as a lone "<" instead of the whole tag.
    fc_open = re.escape(self.tool_call_start_token)
    fc_close = re.escape(self.tool_call_end_token)
    invoke_head = re.escape("<|DSML|invoke") + r'\s+name="([^"]+)"\s*>'
    invoke_close = re.escape(self.invoke_end_token)
    param_head = (
        re.escape("<|DSML|parameter")
        + r'\s+name="([^"]+)"\s+string="(true|false)"\s*>'
    )
    param_close = re.escape(self.param_end_token)

    # Matches a complete <|DSML|function_calls>…</|DSML|function_calls>
    self.tool_call_complete_regex = re.compile(
        fc_open + r"(.*?)" + fc_close, re.DOTALL
    )
    # Opening tag of an invoke block — captures the function name.
    self.invoke_start_regex = re.compile(invoke_head, re.DOTALL)
    # Complete invoke block — captures (name, body).
    self.invoke_complete_regex = re.compile(
        invoke_head + r"(.*?)" + invoke_close, re.DOTALL
    )
    # Complete parameter tag — captures (name, string_attr, value).
    self.parameter_complete_regex = re.compile(
        param_head + r"(.*?)" + param_close, re.DOTALL
    )
    # Just the opening header of a parameter tag (for partial params).
    self.parameter_header_regex = re.compile(param_head, re.DOTALL)

    # ----- Streaming state (reset per request) -----
    self._sent_content_idx: int = 0
    self._tool_call_ids: list[str] = []
    self.streamed_args_for_tool: list[str] = []
    self.prev_tool_call_arr: list[dict[str, Any]] = []
    self.current_tool_id: int = -1

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )

    logger.debug(
        "Successfully initialized %s", self.__class__.__name__
    )
|
||
|
||
# ------------------------------------------------------------------
|
||
# Request adjustment
|
||
# ------------------------------------------------------------------
|
||
|
||
def adjust_request(
    self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
    """Apply the base-class adjustment, then keep special tokens if needed.

    When the request carries tools and its ``tool_choice`` is not
    ``"none"``, ``skip_special_tokens`` is switched off so the DSML
    tokens are not stripped during decoding.

    Returns:
        The request object returned by the base class, possibly with
        ``skip_special_tokens`` set to ``False``.
    """
    adjusted = super().adjust_request(request)
    wants_tool_calls = bool(adjusted.tools) and adjusted.tool_choice != "none"
    if wants_tool_calls:
        adjusted.skip_special_tokens = False
    return adjusted
|
||
|
||
# ------------------------------------------------------------------
|
||
# Static / utility helpers
|
||
# ------------------------------------------------------------------
|
||
|
||
@staticmethod
def _tools_enabled(request: ChatCompletionRequest) -> bool:
    """Tell whether tool calling is active for *request*.

    Active means the request has a truthy ``tools`` attribute and its
    ``tool_choice`` is not ``"none"``. Missing attributes are treated
    as ``None``. Any exception raised while inspecting the request is
    logged and reported as "not enabled".
    """
    try:
        declared_tools = getattr(request, "tools", None)
        choice = getattr(request, "tool_choice", None)
        if not declared_tools:
            return False
        return choice != "none"
    except Exception:
        logger.exception("Failed to determine if tools are enabled.")
        return False
|
||
|
||
def _generate_tool_call_id(self) -> str:
    """Return a fresh id: ``call_`` plus 24 hex digits of a random UUID."""
    random_hex = uuid.uuid4().hex
    return "call_" + random_hex[:24]
|
||
|
||
@staticmethod
def _json_escape_string_content(s: str) -> str:
    """Return *s* escaped as JSON string content, without the quotes.

    Non-ASCII characters are kept as-is (``ensure_ascii=False``). A
    falsy *s* yields the empty string.
    """
    if s:
        quoted = json.dumps(s, ensure_ascii=False)
        # Drop the opening and closing double quote added by json.dumps.
        return quoted[1:-1]
    return ""
|
||
|
||
# ------------------------------------------------------------------
|
||
# Type conversion helpers
|
||
# ------------------------------------------------------------------
|
||
|
||
def _convert_param_value_checked(self, value: str, param_type: str) -> Any:
    """Convert the raw text *value* according to *param_type*.

    The literal ``null`` (any letter case) maps to ``None`` whatever
    the type. Type names are compared case-insensitively:

    * ``string`` / ``str`` / ``text`` — returned unchanged.
    * ``integer`` / ``int`` — parsed with ``int``.
    * ``number`` / ``float`` — parsed with ``float``; whole values are
      returned as ``int``.
    * ``boolean`` / ``bool`` — ``true``/``1`` and ``false``/``0`` after
      stripping whitespace and lowercasing.
    * anything else (including ``object`` and ``array``) — parsed as JSON.

    Raises on failure so the caller can try the next candidate type.
    """
    if value.lower() == "null":
        return None

    kind = param_type.lower()

    if kind in ("string", "str", "text"):
        return value

    if kind in ("integer", "int"):
        return int(value)

    if kind in ("number", "float"):
        as_float = float(value)
        as_int = int(as_float)
        return as_float if as_float != as_int else as_int

    if kind in ("boolean", "bool"):
        token = value.strip().lower()
        if token in ("true", "1"):
            return True
        if token in ("false", "0"):
            return False
        raise ValueError(f"Invalid boolean value: {value!r}")

    # "object", "array" and every unrecognised type are decoded as JSON.
    return json.loads(value)
|
||
|
||
def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
    """Convert *value* using the first candidate type that succeeds.

    *param_type* may be a single type name or a list of them. Each
    candidate is passed to ``_convert_param_value_checked`` in order;
    a candidate that raises is skipped. If none succeeds the raw
    string is returned unchanged.
    """
    candidates = param_type if isinstance(param_type, list) else [param_type]
    for candidate in candidates:
        try:
            converted = self._convert_param_value_checked(value, candidate)
        except Exception:
            # This candidate does not fit; move on to the next one.
            pass
        else:
            return converted
    return value
|
||
|
||
def _get_param_schema_type(
    self, func_name: str, param_name: str
) -> str | list[str]:
    """Return the declared ``type`` of *param_name* for tool *func_name*.

    Only the first tool whose ``function.name`` equals *func_name* (and
    which has a ``function.parameters`` attribute) is consulted. The
    value of ``properties[param_name]["type"]`` is returned as found.
    ``"string"`` is returned whenever the tool, the ``properties``
    mapping, the parameter entry or its ``type`` key is missing.
    """
    for tool in self.tools or ():
        is_target = (
            hasattr(tool, "function")
            and tool.function.name == func_name
            and hasattr(tool.function, "parameters")
        )
        if not is_target:
            continue

        schema = tool.function.parameters
        if isinstance(schema, dict) and "properties" in schema:
            declared = schema["properties"].get(param_name, {})
            if isinstance(declared, dict):
                return declared.get("type", "string")
        # The matching tool gave no usable answer; do not look further.
        return "string"

    return "string"
|
||
|
||
def _convert_with_schema(
    self, func_name: str, param_name: str, value: str
) -> Any:
    """Convert *value* to the type the schema declares for the parameter.

    Looks up the type with ``_get_param_schema_type`` and delegates the
    conversion to ``_convert_param_value``.
    """
    return self._convert_param_value(
        value, self._get_param_schema_type(func_name, param_name)
    )
|
||
|
||
def _is_string_type(self, func_name: str, param_name: str) -> bool:
    """Return True if the schema says this parameter is a string.

    The declared type comes from ``_get_param_schema_type`` and may be
    a single name or a list of names. A name counts as a string type
    when, ignoring letter case, it is ``string``, ``str`` or ``text``.
    These are the same names ``_convert_param_value_checked`` converts
    as plain strings, so both helpers agree on values such as
    ``"String"`` or ``["str", "null"]``. For a list, one string-typed
    entry is enough. Non-``str`` entries never match.
    """
    declared = self._get_param_schema_type(func_name, param_name)
    candidates = declared if isinstance(declared, list) else [declared]
    return any(
        isinstance(name, str) and name.lower() in ("string", "str", "text")
        for name in candidates
    )
|
||
|
||
# ------------------------------------------------------------------
|
||
# Non-streaming extraction (unchanged logic, shared helpers)
|
||
# ------------------------------------------------------------------
|
||
|
||
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract tool calls from complete model output (non-streaming).

    Every closed ``function_calls`` block is scanned for closed invoke
    blocks, and each invoke's closed parameters are converted through
    the tool schema and serialized as the call's JSON arguments. When a
    parameter name repeats inside one invoke, the last value wins.

    Returns:
        ``tools_called=True`` with the parsed calls and any text that
        precedes the first ``function_calls`` tag (``None`` when there
        is none). If the start tag is absent, no call could be parsed,
        or parsing raised, ``tools_called=False`` is returned with the
        whole *model_output* as content.

    *request* is not used by this method.
    """

    def no_tool_calls() -> ExtractedToolCallInformation:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    opening_at = model_output.find(self.tool_call_start_token)
    if opening_at == -1:
        return no_tool_calls()

    try:
        parsed_calls: list[ToolCall] = []

        for block_body in self.tool_call_complete_regex.findall(model_output):
            for call_name, call_body in self.invoke_complete_regex.findall(
                block_body
            ):
                # NOTE(review): the string="true|false" attribute is
                # ignored here, while the streaming path keeps
                # string="true" values as raw strings — confirm that the
                # two paths are meant to differ.
                raw_values = {
                    name: text
                    for name, _flag, text in (
                        self.parameter_complete_regex.findall(call_body)
                    )
                }
                typed_values = {
                    name: self._convert_with_schema(call_name, name, text)
                    for name, text in raw_values.items()
                }
                parsed_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=call_name,
                            arguments=json.dumps(
                                typed_values, ensure_ascii=False
                            ),
                        ),
                    )
                )

        if not parsed_calls:
            return no_tool_calls()

        leading_text = model_output[:opening_at] if opening_at > 0 else None
        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=parsed_calls, content=leading_text
        )

    except Exception:
        logger.exception("Error extracting tool calls from complete output")
        return no_tool_calls()
|
||
|
||
# ------------------------------------------------------------------
|
||
# Streaming helpers — re-parse-and-diff
|
||
# ------------------------------------------------------------------
|
||
|
||
def _reset_streaming_state(self) -> None:
    """Return all per-stream bookkeeping to its initial values.

    The three state lists are emptied in place, so existing references
    to them stay valid.
    """
    for per_tool_list in (
        self._tool_call_ids,
        self.streamed_args_for_tool,
        self.prev_tool_call_arr,
    ):
        per_tool_list.clear()
    self._sent_content_idx = 0
    self.current_tool_id = -1
|
||
|
||
def _extract_content(self, current_text: str) -> str | None:
    """Return the not-yet-sent text that lies outside tool-call regions.

    Scanning resumes at ``_sent_content_idx``. Closed
    ``function_calls`` regions are skipped; an open region stops the
    scan with the cursor parked on its start tag. When no start tag
    follows, the tail is emitted except for a trailing suffix that
    could still grow into the start tag (see ``partial_tag_overlap``).

    ``_sent_content_idx`` is advanced to the final cursor position.

    Returns:
        The concatenated new content, or ``None`` if there is none.
    """
    open_tag = self.tool_call_start_token
    close_tag = self.tool_call_end_token
    text_len = len(current_text)
    cursor = self._sent_content_idx
    pieces: list[str] = []

    while cursor < text_len:
        tag_at = current_text.find(open_tag, cursor)

        if tag_at == -1:
            # Plain tail: hold back anything that may be a partial tag.
            held_back = partial_tag_overlap(current_text[cursor:], open_tag)
            safe_end = text_len - held_back
            if safe_end > cursor:
                pieces.append(current_text[cursor:safe_end])
            cursor = safe_end
            break

        if tag_at > cursor:
            pieces.append(current_text[cursor:tag_at])

        close_at = current_text.find(close_tag, tag_at)
        if close_at == -1:
            # Region still open — wait here until it is closed.
            cursor = tag_at
            break
        cursor = close_at + len(close_tag)

    # The cursor never moves backwards, so it is always the new index.
    self._sent_content_idx = cursor
    return "".join(pieces) if pieces else None
|
||
|
||
def _extract_invoke_regions(
    self, text: str
) -> list[tuple[str, str, bool]]:
    """Find all invoke blocks inside every function_calls region.

    All ``function_calls`` regions in *text* are scanned in order, not
    only the first one. This matches ``_extract_content`` and the
    non-streaming path, which both handle any number of regions, so
    invokes in a later region are no longer dropped while streaming.

    Returns a list of ``(func_name, inner_text, is_complete)``
    tuples. *inner_text* is everything between the invoke open
    tag and the close tag (or the end of the available region text for
    an invoke that has not been closed yet). For an unclosed invoke, a
    trailing fragment that could be the start of the invoke end tag is
    stripped from *inner_text*.
    """
    results: list[tuple[str, str, bool]] = []
    open_tag = self.tool_call_start_token
    close_tag = self.tool_call_end_token
    invoke_close = self.invoke_end_token

    cursor = 0
    while True:
        fc_start = text.find(open_tag, cursor)
        if fc_start == -1:
            break

        region_start = fc_start + len(open_tag)
        fc_end = text.find(close_tag, region_start)
        region = text[region_start:] if fc_end == -1 else text[region_start:fc_end]

        pos = 0
        while pos < len(region):
            inv_match = self.invoke_start_regex.search(region, pos)
            if not inv_match:
                break

            func_name = inv_match.group(1)
            body_start = inv_match.end()
            inv_end_pos = region.find(invoke_close, body_start)

            if inv_end_pos == -1:
                # Incomplete — still being generated. Nothing after an
                # unclosed invoke can be parsed within this region.
                body = region[body_start:]
                overlap = partial_tag_overlap(body, invoke_close)
                if overlap:
                    body = body[:-overlap]
                results.append((func_name, body, False))
                break

            # Complete invoke block.
            results.append((func_name, region[body_start:inv_end_pos], True))
            pos = inv_end_pos + len(invoke_close)

        if fc_end == -1:
            # The region is still open, so nothing can follow it yet.
            break
        cursor = fc_end + len(close_tag)

    return results
|
||
|
||
def _build_args_json_so_far(
    self,
    func_name: str,
    inner_text: str,
    is_complete: bool,
) -> str:
    """Build a JSON arguments string from the parameters found so far.

    Handles both fully-closed ``<|DSML|parameter>`` tags and the
    single trailing partial parameter whose value is still being
    streamed.

    ``_compute_args_diff`` emits ``args[len(previous):]``, so every
    string returned for a tool call must extend the previous one. A
    trailing partial value is therefore emitted only when its final
    serialization is certain to start with what is emitted now:

    * ``string="true"`` values are always finalized as a JSON string,
      so they are streamed as an open string (no closing quote).
    * Otherwise the value is finalized through the schema conversion.
      It is streamed as an open string only if the first declared
      schema type is a string type and the text seen so far cannot
      still turn out to be ``null`` (which converts to ``None``).
    * Every other partial value (numbers, booleans, objects, union
      types tried before string) is held back until its closing tag
      arrives, because conversion may change its text (``1.0`` -> ``1``).

    Args:
        func_name: Name of the invoked function, used for schema lookup.
        inner_text: Text between the invoke open tag and its end.
        is_complete: Whether the invoke block has been closed. If so,
            a dangling unclosed parameter is finalized as-is and the
            result is terminated with ``}``.

    Returns:
        The JSON text so far: ``""`` when nothing can be emitted yet,
        ``"{}"`` for a closed invoke without parameters.
    """

    def render_final(name: str, attr: str, raw: str) -> str:
        # Final "key: value" text of a parameter whose value is known.
        key_json = json.dumps(name, ensure_ascii=False)
        value = (
            raw
            if attr == "true"
            else self._convert_with_schema(func_name, name, raw)
        )
        return f"{key_json}: {json.dumps(value, ensure_ascii=False)}"

    def render_partial(name: str, attr: str, raw: str) -> str | None:
        # "key: value" text for the trailing unclosed parameter, or
        # None when nothing can safely be emitted for it yet.
        schema_converted = attr != "true"
        if not is_complete and schema_converted:
            declared = self._get_param_schema_type(func_name, name)
            primary = (
                declared[0]
                if isinstance(declared, list) and declared
                else declared
            )
            if not (
                isinstance(primary, str)
                and primary.lower() in ("string", "str", "text")
            ):
                return None

        # Strip any bytes that might be the beginning of the
        # closing </|DSML|parameter> tag.
        overlap = partial_tag_overlap(raw, self.param_end_token)
        if overlap:
            raw = raw[:-overlap]

        if is_complete:
            # Invoke is closed — treat whatever we have as final.
            return render_final(name, attr, raw)

        if schema_converted and "null".startswith(raw.lower()):
            # Could still become the literal "null", which is
            # finalized as JSON null and not as a string.
            return None

        # Stream as an open JSON string (no closing quote).
        key_json = json.dumps(name, ensure_ascii=False)
        return f'{key_json}: "{self._json_escape_string_content(raw)}'

    # ---- Collect all fully-closed parameters ----
    parts: list[str] = [
        render_final(name, attr, raw)
        for name, attr, raw in self.parameter_complete_regex.findall(inner_text)
    ]

    # ---- Handle a trailing partial parameter ----
    last_param_open = inner_text.rfind("<|DSML|parameter")
    last_param_close = inner_text.rfind(self.param_end_token)
    if last_param_open != -1 and last_param_close < last_param_open:
        partial_text = inner_text[last_param_open:]
        header_match = self.parameter_header_regex.search(partial_text)
        if header_match:
            fragment = render_partial(
                header_match.group(1),
                header_match.group(2),
                partial_text[header_match.end():],
            )
            if fragment is not None:
                parts.append(fragment)

    # ---- Assemble ----
    if not parts:
        return "{}" if is_complete else ""

    joined = "{" + ", ".join(parts)
    if is_complete:
        joined += "}"
    return joined
|
||
|
||
def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
    """Return the part of *args_so_far* beyond what was already sent.

    The comparison is by length only: everything after the first
    ``len(previously sent)`` characters is considered new. When there
    is something new, both ``streamed_args_for_tool[index]`` and
    ``prev_tool_call_arr[index]["arguments"]`` are set to *args_so_far*.

    Returns:
        The new characters, or ``None`` if *args_so_far* is empty or
        not longer than what was sent before.
    """
    already_sent = len(self.streamed_args_for_tool[index])
    fresh = args_so_far[already_sent:]
    if not fresh:
        return None

    self.streamed_args_for_tool[index] = args_so_far
    self.prev_tool_call_arr[index]["arguments"] = args_so_far
    return fresh
|
||
|
||
def _ensure_tool_state_for(self, index: int) -> None:
    """Pad the per-tool state lists so that *index* can be used.

    Each list is extended independently up to ``index + 1`` entries:
    new call ids come from ``_generate_tool_call_id``, streamed
    arguments start as ``""`` and previous-call records as fresh empty
    dicts. Lists that are already long enough are left untouched.
    """
    wanted = index + 1

    for _ in range(wanted - len(self._tool_call_ids)):
        self._tool_call_ids.append(self._generate_tool_call_id())

    missing_args = wanted - len(self.streamed_args_for_tool)
    self.streamed_args_for_tool.extend("" for _ in range(missing_args))

    missing_records = wanted - len(self.prev_tool_call_arr)
    self.prev_tool_call_arr.extend({} for _ in range(missing_records))
|
||
|
||
# ------------------------------------------------------------------
|
||
# Main streaming entry point
|
||
# ------------------------------------------------------------------
|
||
|
||
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Produce the next streaming delta by re-parsing *current_text*.

    Steps performed on every call:
      1. Reset the streaming state when *previous_text* is empty
         (first chunk of a stream).
      2. If tools are not enabled for *request*, forward *delta_text*
         as plain content.
      3. Collect unsent content outside tool-call regions.
      4. Find all invoke regions (complete + partial); for each one
         emit its name once, then the not-yet-sent part of its JSON
         arguments.

    Because the whole text is re-parsed each time, the result does not
    depend on how many tokens arrived in this step.

    Returns:
        A ``DeltaMessage`` with content and/or tool-call deltas; an
        empty-content ``DeltaMessage`` when *delta_text* is empty but
        token ids arrived and at least one tool call was seen; otherwise
        ``None``.
    """
    # A stream's first chunk has no previous text.
    if not previous_text:
        self._reset_streaming_state()

    # Without tools there is nothing to parse: pass the text through.
    if not self._tools_enabled(request):
        if delta_text:
            return DeltaMessage(content=delta_text)
        return None

    plain_text = self._extract_content(current_text)
    invokes = self._extract_invoke_regions(current_text)

    updates: list[DeltaToolCall] = []
    for slot, (call_name, call_body, closed) in enumerate(invokes):
        self._ensure_tool_state_for(slot)
        slot_state = self.prev_tool_call_arr[slot]

        # The header delta (id, type, name) goes out once per tool call.
        if "name" not in slot_state:
            slot_state["name"] = call_name
            header = DeltaFunctionCall(name=call_name, arguments="")
            updates.append(
                DeltaToolCall(
                    index=slot,
                    id=self._tool_call_ids[slot],
                    type="function",
                    function=header,
                )
            )

        # Only the newly-available argument characters are sent.
        args_text = self._build_args_json_so_far(call_name, call_body, closed)
        new_args = self._compute_args_diff(slot, args_text)
        if new_args:
            updates.append(
                DeltaToolCall(
                    index=slot,
                    function=DeltaFunctionCall(arguments=new_args),
                )
            )

    if invokes:
        self.current_tool_id = len(invokes) - 1

    if plain_text or updates:
        return DeltaMessage(content=plain_text, tool_calls=updates)

    # Empty delta with token ids means EOS or closing tag — return
    # non-None so the serving framework can finalize finish_reason.
    closing_step = (
        not delta_text and delta_token_ids and self.prev_tool_call_arr
    )
    if closing_step:
        return DeltaMessage(content="")

    return None