# 2026-04-13 23:42:31 +00:00
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
"""
|
|
|
|
|
|
DeepSeek-V3.2 Tool Call Parser — re-parse-and-diff version.
|
|
|
|
|
|
|
|
|
|
|
|
Adapted from the GLM-4 streaming fix to make the streaming path robust
|
|
|
|
|
|
against multi-token deltas produced by MTP speculative decoding.
|
|
|
|
|
|
|
|
|
|
|
|
Instead of maintaining incremental state that advances one token at a
|
|
|
|
|
|
time, the streaming path re-parses the *entire* current_text on every
|
|
|
|
|
|
call, finds all <|DSML|invoke> regions (complete and in-progress),
|
|
|
|
|
|
builds a JSON arguments string for each, and diffs against what was
|
|
|
|
|
|
previously sent. This makes the parser agnostic to how many tokens
|
|
|
|
|
|
arrive per step.
|
|
|
|
|
|
|
|
|
|
|
|
Key changes vs. the upstream buffer-until-complete parser:
|
|
|
|
|
|
1. _extract_content() handles partial tag overlaps so content text
|
|
|
|
|
|
is never swallowed or duplicated when a tag boundary lands inside
|
|
|
|
|
|
a multi-token chunk.
|
|
|
|
|
|
2. _extract_invoke_regions() finds both complete and incomplete
|
|
|
|
|
|
invoke blocks, enabling streaming of partial arguments.
|
|
|
|
|
|
3. _build_args_json_so_far() constructs the JSON arguments string
|
|
|
|
|
|
incrementally from complete + partial <|DSML|parameter> tags.
|
|
|
|
|
|
4. _compute_args_diff() emits only the newly-added characters.
|
|
|
|
|
|
|
|
|
|
|
|
Drop-in replacement: same class name, same interface.
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import json
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
from collections.abc import Sequence
|
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
|
|
import regex as re
|
|
|
|
|
|
|
|
|
|
|
|
from vllm.entrypoints.openai.chat_completion.protocol import (
|
|
|
|
|
|
ChatCompletionRequest,
|
|
|
|
|
|
)
|
|
|
|
|
|
from vllm.entrypoints.openai.engine.protocol import (
|
|
|
|
|
|
DeltaFunctionCall,
|
|
|
|
|
|
DeltaMessage,
|
|
|
|
|
|
DeltaToolCall,
|
|
|
|
|
|
ExtractedToolCallInformation,
|
|
|
|
|
|
FunctionCall,
|
|
|
|
|
|
ToolCall,
|
|
|
|
|
|
)
|
|
|
|
|
|
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
|
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
|
|
from vllm.tokenizers import TokenizerLike
|
|
|
|
|
|
from vllm.tool_parsers.abstract_tool_parser import (
|
|
|
|
|
|
Tool,
|
|
|
|
|
|
ToolParser,
|
|
|
|
|
|
)
|
|
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
# 2026-04-14 00:51:11 +00:00
|
|
|
|
def partial_tag_overlap(text: str, tag: str) -> int:
    """Return the length of the longest proper prefix of *tag* that is also
    a suffix of *text*.

    E.g. text ending in ``"<tool_"`` returns 6 when tag is ``"<tool_call>"``.
    Returns 0 when the end of *text* cannot be the start of *tag*.
    """
    # Only proper prefixes are interesting: a full tag would have been
    # found by a plain substring search, so cap at len(tag) - 1.
    longest = min(len(text), len(tag) - 1)
    for size in range(longest, 0, -1):
        if text[-size:] == tag[:size]:
            return size
    return 0
|
|
|
|
|
|
|
|
|
|
|
|
|
# 2026-04-13 23:42:31 +00:00
|
|
|
|
class DeepSeekV32ToolParser(ToolParser):
    """
    Re-parse-and-diff tool parser for DeepSeek-V3.2 DSML format.

    On every streaming call the parser re-parses ``current_text`` to
    find ``<|DSML|invoke>`` regions, builds the JSON arguments string
    for each tool call, and diffs against what was previously sent to
    emit only new content. This is robust against multi-token deltas
    from MTP / EAGLE speculative decoding.

    Example tool call format::

        <|DSML|function_calls>
        <|DSML|invoke name="get_weather">
        <|DSML|parameter name="location" string="true">杭州</|DSML|parameter>
        <|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
        </|DSML|invoke>
        </|DSML|function_calls>
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        """Initialize tag constants, compiled regexes, and streaming state.

        Raises:
            ValueError: if no tokenizer was supplied (the base class needs
                one to decode special tokens).
        """
        super().__init__(tokenizer, tools)

        # ----- Tag constants -----
        self.tool_call_start_token: str = "<|DSML|function_calls>"
        self.tool_call_end_token: str = "</|DSML|function_calls>"
        self.invoke_end_token: str = "</|DSML|invoke>"
        self.param_end_token: str = "</|DSML|parameter>"

        # Alias expected by ToolParser base / adjust_request
        self.tool_calls_start_token = self.tool_call_start_token

        # ----- Compiled regexes -----
        # NOTE: "|" is the regex alternation metacharacter, so the literal
        # pipes inside the "<|DSML|...>" tags MUST be escaped ("\|").  The
        # previous unescaped patterns (e.g. r"<|DSML|function_calls>...")
        # were silently parsed as an alternation of fragments ("<", "DSML",
        # ...) and matched almost arbitrary text.
        # Matches a complete <|DSML|function_calls>…</|DSML|function_calls>
        self.tool_call_complete_regex = re.compile(
            r"<\|DSML\|function_calls>(.*?)</\|DSML\|function_calls>", re.DOTALL
        )
        # Opening tag of an invoke block — captures the function name.
        self.invoke_start_regex = re.compile(
            r'<\|DSML\|invoke\s+name="([^"]+)"\s*>', re.DOTALL
        )
        # Complete invoke block — captures (name, body).
        self.invoke_complete_regex = re.compile(
            r'<\|DSML\|invoke\s+name="([^"]+)"\s*>(.*?)</\|DSML\|invoke>',
            re.DOTALL,
        )
        # Complete parameter tag — captures (name, string_attr, value).
        self.parameter_complete_regex = re.compile(
            r'<\|DSML\|parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>'
            r"(.*?)"
            r"</\|DSML\|parameter>",
            re.DOTALL,
        )
        # Just the opening header of a parameter tag (for partial params).
        self.parameter_header_regex = re.compile(
            r'<\|DSML\|parameter\s+name="([^"]+)"\s+string="(true|false)"\s*>',
            re.DOTALL,
        )

        # ----- Streaming state (reset per request) -----
        # Index into current_text up to which plain content has been sent.
        self._sent_content_idx: int = 0
        # Stable per-tool-call ids, assigned on first sighting.
        self._tool_call_ids: list[str] = []
        # JSON argument text already streamed, per tool index.
        self.streamed_args_for_tool: list[str] = []
        # Per-tool dicts recording the name/arguments sent so far.
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        # Index of the invoke currently being generated (-1 = none yet).
        self.current_tool_id: int = -1

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        logger.debug(
            "Successfully initialized %s", self.__class__.__name__
        )
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Request adjustment
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
def adjust_request(
|
|
|
|
|
|
self, request: ChatCompletionRequest | ResponsesRequest
|
|
|
|
|
|
) -> ChatCompletionRequest | ResponsesRequest:
|
|
|
|
|
|
request = super().adjust_request(request)
|
|
|
|
|
|
if request.tools and request.tool_choice != "none":
|
|
|
|
|
|
# Ensure DSML tokens are not stripped during decoding.
|
|
|
|
|
|
request.skip_special_tokens = False
|
|
|
|
|
|
return request
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Static / utility helpers
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def _tools_enabled(request: ChatCompletionRequest) -> bool:
|
|
|
|
|
|
"""Check whether tool calling is active for this request."""
|
|
|
|
|
|
try:
|
|
|
|
|
|
tools = getattr(request, "tools", None)
|
|
|
|
|
|
tool_choice = getattr(request, "tool_choice", None)
|
|
|
|
|
|
return bool(tools) and tool_choice != "none"
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
logger.exception("Failed to determine if tools are enabled.")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_tool_call_id(self) -> str:
|
|
|
|
|
|
return f"call_{uuid.uuid4().hex[:24]}"
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def _json_escape_string_content(s: str) -> str:
|
|
|
|
|
|
"""JSON-escape a string value (without surrounding quotes)."""
|
|
|
|
|
|
if not s:
|
|
|
|
|
|
return ""
|
|
|
|
|
|
return json.dumps(s, ensure_ascii=False)[1:-1]
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Type conversion helpers
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_param_value_checked(self, value: str, param_type: str) -> Any:
|
|
|
|
|
|
"""Convert a raw string value to the type indicated by *param_type*.
|
|
|
|
|
|
|
|
|
|
|
|
Raises on failure so the caller can try the next candidate type.
|
|
|
|
|
|
"""
|
|
|
|
|
|
if value.lower() == "null":
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
param_type = param_type.lower()
|
|
|
|
|
|
if param_type in ("string", "str", "text"):
|
|
|
|
|
|
return value
|
|
|
|
|
|
elif param_type in ("integer", "int"):
|
|
|
|
|
|
return int(value)
|
|
|
|
|
|
elif param_type in ("number", "float"):
|
|
|
|
|
|
val = float(value)
|
|
|
|
|
|
return val if val != int(val) else int(val)
|
|
|
|
|
|
elif param_type in ("boolean", "bool"):
|
|
|
|
|
|
normed = value.strip().lower()
|
|
|
|
|
|
if normed not in ("false", "0", "true", "1"):
|
|
|
|
|
|
raise ValueError(f"Invalid boolean value: {value!r}")
|
|
|
|
|
|
return normed in ("true", "1")
|
|
|
|
|
|
elif param_type in ("object", "array"):
|
|
|
|
|
|
return json.loads(value)
|
|
|
|
|
|
else:
|
|
|
|
|
|
return json.loads(value)
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
|
|
|
|
|
|
"""Try each candidate type in turn; fall back to the raw string."""
|
|
|
|
|
|
if not isinstance(param_type, list):
|
|
|
|
|
|
param_type = [param_type]
|
|
|
|
|
|
for current_type in param_type:
|
|
|
|
|
|
try:
|
|
|
|
|
|
return self._convert_param_value_checked(value, current_type)
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
continue
|
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
|
|
def _get_param_schema_type(
|
|
|
|
|
|
self, func_name: str, param_name: str
|
|
|
|
|
|
) -> str | list[str]:
|
|
|
|
|
|
"""Look up the JSON-schema type for a parameter, defaulting to
|
|
|
|
|
|
``"string"``."""
|
|
|
|
|
|
if self.tools:
|
|
|
|
|
|
for tool in self.tools:
|
|
|
|
|
|
if (
|
|
|
|
|
|
hasattr(tool, "function")
|
|
|
|
|
|
and tool.function.name == func_name
|
|
|
|
|
|
and hasattr(tool.function, "parameters")
|
|
|
|
|
|
):
|
|
|
|
|
|
schema = tool.function.parameters
|
|
|
|
|
|
if isinstance(schema, dict) and "properties" in schema:
|
|
|
|
|
|
prop = schema["properties"].get(param_name, {})
|
|
|
|
|
|
if isinstance(prop, dict):
|
|
|
|
|
|
return prop.get("type", "string")
|
|
|
|
|
|
break
|
|
|
|
|
|
return "string"
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_with_schema(
|
|
|
|
|
|
self, func_name: str, param_name: str, value: str
|
|
|
|
|
|
) -> Any:
|
|
|
|
|
|
"""Convert *value* using the tool schema for *func_name*.*param_name*."""
|
|
|
|
|
|
param_type = self._get_param_schema_type(func_name, param_name)
|
|
|
|
|
|
return self._convert_param_value(value, param_type)
|
|
|
|
|
|
|
|
|
|
|
|
def _is_string_type(self, func_name: str, param_name: str) -> bool:
|
|
|
|
|
|
"""Return True if the schema says this parameter is a string."""
|
|
|
|
|
|
ptype = self._get_param_schema_type(func_name, param_name)
|
|
|
|
|
|
if isinstance(ptype, list):
|
|
|
|
|
|
return "string" in ptype
|
|
|
|
|
|
return ptype in ("string", "str", "text")
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Non-streaming extraction (unchanged logic, shared helpers)
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
def extract_tool_calls(
|
|
|
|
|
|
self,
|
|
|
|
|
|
model_output: str,
|
|
|
|
|
|
request: ChatCompletionRequest,
|
|
|
|
|
|
) -> ExtractedToolCallInformation:
|
|
|
|
|
|
"""Extract tool calls from complete model output (non-streaming)."""
|
|
|
|
|
|
if self.tool_call_start_token not in model_output:
|
|
|
|
|
|
return ExtractedToolCallInformation(
|
|
|
|
|
|
tools_called=False, tool_calls=[], content=model_output
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
tool_calls: list[ToolCall] = []
|
|
|
|
|
|
|
|
|
|
|
|
for fc_block in self.tool_call_complete_regex.findall(model_output):
|
|
|
|
|
|
for invoke_name, invoke_body in self.invoke_complete_regex.findall(
|
|
|
|
|
|
fc_block
|
|
|
|
|
|
):
|
|
|
|
|
|
# Parse all parameters in this invoke.
|
|
|
|
|
|
raw_params: dict[str, str] = {}
|
|
|
|
|
|
for pname, _str_attr, pval in (
|
|
|
|
|
|
self.parameter_complete_regex.findall(invoke_body)
|
|
|
|
|
|
):
|
|
|
|
|
|
raw_params[pname] = pval
|
|
|
|
|
|
|
|
|
|
|
|
# Convert types via schema.
|
|
|
|
|
|
converted: dict[str, Any] = {}
|
|
|
|
|
|
for pname, pval in raw_params.items():
|
|
|
|
|
|
converted[pname] = self._convert_with_schema(
|
|
|
|
|
|
invoke_name, pname, pval
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
tool_calls.append(
|
|
|
|
|
|
ToolCall(
|
|
|
|
|
|
type="function",
|
|
|
|
|
|
function=FunctionCall(
|
|
|
|
|
|
name=invoke_name,
|
|
|
|
|
|
arguments=json.dumps(
|
|
|
|
|
|
converted, ensure_ascii=False
|
|
|
|
|
|
),
|
|
|
|
|
|
),
|
|
|
|
|
|
)
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if not tool_calls:
|
|
|
|
|
|
return ExtractedToolCallInformation(
|
|
|
|
|
|
tools_called=False, tool_calls=[], content=model_output
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
first_idx = model_output.find(self.tool_call_start_token)
|
|
|
|
|
|
content = model_output[:first_idx] if first_idx > 0 else None
|
|
|
|
|
|
|
|
|
|
|
|
return ExtractedToolCallInformation(
|
|
|
|
|
|
tools_called=True, tool_calls=tool_calls, content=content
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
logger.exception("Error extracting tool calls from complete output")
|
|
|
|
|
|
return ExtractedToolCallInformation(
|
|
|
|
|
|
tools_called=False, tool_calls=[], content=model_output
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Streaming helpers — re-parse-and-diff
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
def _reset_streaming_state(self) -> None:
|
|
|
|
|
|
self._sent_content_idx = 0
|
|
|
|
|
|
self._tool_call_ids.clear()
|
|
|
|
|
|
self.streamed_args_for_tool.clear()
|
|
|
|
|
|
self.prev_tool_call_arr.clear()
|
|
|
|
|
|
self.current_tool_id = -1
|
|
|
|
|
|
|
|
|
|
|
|
    def _extract_content(self, current_text: str) -> str | None:
        """Return any non-tool-call text that hasn't been sent yet.

        Walks *current_text* from ``_sent_content_idx``, collecting text
        outside ``<|DSML|function_calls>`` regions. Uses
        ``partial_tag_overlap`` to avoid emitting bytes that might turn
        out to be the start of the function-calls tag once the next
        chunk arrives.

        Side effect: advances ``self._sent_content_idx`` past everything
        that was either emitted or skipped (closed tool-call regions), so
        subsequent calls never re-send the same content.
        """
        content_segments: list[str] = []
        pos: int = self._sent_content_idx

        while pos < len(current_text):
            start = current_text.find(self.tool_call_start_token, pos)
            if start == -1:
                # No (more) tool-call regions — send the tail, minus
                # any suffix that could be the beginning of the tag.
                tail = current_text[pos:]
                overlap = partial_tag_overlap(tail, self.tool_call_start_token)
                sendable = tail[: len(tail) - overlap] if overlap else tail
                if sendable:
                    content_segments.append(sendable)
                # Park the cursor just before the withheld overlap bytes.
                pos = len(current_text) - overlap
                break

            # Text between previous position and the tag start is content.
            if start > pos:
                content_segments.append(current_text[pos:start])

            # Skip past the tool-call region.
            end = current_text.find(self.tool_call_end_token, start)
            if end != -1:
                pos = end + len(self.tool_call_end_token)
            else:
                # Region still open — park cursor at start, stop.
                pos = start
                break

        if content_segments:
            self._sent_content_idx = pos
            return "".join(content_segments)
        # Nothing to emit, but the cursor may still have advanced past a
        # closed tool-call region — record that so it isn't re-scanned.
        if pos > self._sent_content_idx:
            self._sent_content_idx = pos
        return None
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_invoke_regions(
|
|
|
|
|
|
self, text: str
|
|
|
|
|
|
) -> list[tuple[str, str, bool]]:
|
|
|
|
|
|
"""Find all invoke blocks inside the function_calls region.
|
|
|
|
|
|
|
|
|
|
|
|
Returns a list of ``(func_name, inner_text, is_complete)``
|
|
|
|
|
|
tuples. *inner_text* is everything between the invoke open
|
|
|
|
|
|
tag and the close tag (or the end of available text for the
|
|
|
|
|
|
last, potentially incomplete, invoke).
|
|
|
|
|
|
"""
|
|
|
|
|
|
results: list[tuple[str, str, bool]] = []
|
|
|
|
|
|
|
|
|
|
|
|
fc_start = text.find(self.tool_call_start_token)
|
|
|
|
|
|
if fc_start == -1:
|
|
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
|
|
region_start = fc_start + len(self.tool_call_start_token)
|
|
|
|
|
|
fc_end = text.find(self.tool_call_end_token, region_start)
|
|
|
|
|
|
region = text[region_start:fc_end] if fc_end != -1 else text[region_start:]
|
|
|
|
|
|
|
|
|
|
|
|
pos = 0
|
|
|
|
|
|
while pos < len(region):
|
|
|
|
|
|
inv_match = self.invoke_start_regex.search(region, pos)
|
|
|
|
|
|
if not inv_match:
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
func_name = inv_match.group(1)
|
|
|
|
|
|
body_start = inv_match.end()
|
|
|
|
|
|
|
|
|
|
|
|
inv_end_pos = region.find(self.invoke_end_token, body_start)
|
|
|
|
|
|
if inv_end_pos != -1:
|
|
|
|
|
|
# Complete invoke block.
|
|
|
|
|
|
body = region[body_start:inv_end_pos]
|
|
|
|
|
|
results.append((func_name, body, True))
|
|
|
|
|
|
pos = inv_end_pos + len(self.invoke_end_token)
|
|
|
|
|
|
else:
|
|
|
|
|
|
# Incomplete — still being generated.
|
|
|
|
|
|
body = region[body_start:]
|
|
|
|
|
|
overlap = partial_tag_overlap(body, self.invoke_end_token)
|
|
|
|
|
|
if overlap:
|
|
|
|
|
|
body = body[:-overlap]
|
|
|
|
|
|
results.append((func_name, body, False))
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
|
|
    def _build_args_json_so_far(
        self,
        func_name: str,
        inner_text: str,
        is_complete: bool,
    ) -> str:
        """Build a JSON arguments string from the parameters found so far.

        Handles both fully-closed ``<|DSML|parameter>`` tags and the
        single trailing partial parameter whose value is still being
        streamed.

        Returns a prefix of the final JSON object: ``"{...}"`` (closed)
        when *is_complete* is True, an unterminated ``"{..."`` prefix
        while the invoke is still streaming, ``""`` when nothing is
        parseable yet.
        """
        # ---- Collect all fully-closed parameters ----
        complete_params = self.parameter_complete_regex.findall(inner_text)
        parts: list[str] = []

        for param_name, string_attr, param_value in complete_params:
            key_json = json.dumps(param_name, ensure_ascii=False)
            if string_attr == "true":
                # string="true": the raw text IS the value, just quote it.
                val_json = json.dumps(param_value, ensure_ascii=False)
            else:
                # string="false": convert via the tool schema first.
                converted = self._convert_with_schema(
                    func_name, param_name, param_value
                )
                val_json = json.dumps(converted, ensure_ascii=False)
            parts.append(f"{key_json}: {val_json}")

        # ---- Handle a trailing partial parameter ----
        # A parameter is "partial" when its last opening tag appears after
        # the last closing tag (or no closing tag exists at all).
        last_param_open = inner_text.rfind("<|DSML|parameter")
        last_param_close = inner_text.rfind(self.param_end_token)
        has_partial = last_param_open != -1 and (
            last_param_close == -1 or last_param_close < last_param_open
        )

        if has_partial:
            partial_text = inner_text[last_param_open:]
            header_match = self.parameter_header_regex.search(partial_text)

            # Without a complete header we can't even name the parameter
            # yet, so emit nothing for it this round.
            if header_match:
                param_name = header_match.group(1)
                string_attr = header_match.group(2)
                partial_value = partial_text[header_match.end():]

                # Strip any bytes that might be the beginning of the
                # closing </|DSML|parameter> tag.
                overlap = partial_tag_overlap(
                    partial_value, self.param_end_token
                )
                if overlap:
                    partial_value = partial_value[:-overlap]

                key_json = json.dumps(param_name, ensure_ascii=False)

                if is_complete:
                    # Invoke is closed — treat whatever we have as final.
                    if string_attr == "true":
                        val_json = json.dumps(
                            partial_value, ensure_ascii=False
                        )
                    else:
                        converted = self._convert_with_schema(
                            func_name, param_name, partial_value
                        )
                        val_json = json.dumps(converted, ensure_ascii=False)
                    parts.append(f"{key_json}: {val_json}")
                elif string_attr == "true" or self._is_string_type(
                    func_name, param_name
                ):
                    # Stream as an open JSON string (no closing quote).
                    escaped = self._json_escape_string_content(partial_value)
                    parts.append(f'{key_json}: "{escaped}')
                else:
                    # Non-string — emit raw partial value.
                    parts.append(f"{key_json}: {partial_value}")

        # ---- Assemble ----
        if not parts:
            return "{}" if is_complete else ""

        joined = "{" + ", ".join(parts)
        if is_complete:
            joined += "}"
        return joined
|
|
|
|
|
|
|
|
|
|
|
|
def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
|
|
|
|
|
|
"""Return only the characters in *args_so_far* that haven't been
|
|
|
|
|
|
sent yet, or ``None`` if there's nothing new."""
|
|
|
|
|
|
prev = self.streamed_args_for_tool[index]
|
|
|
|
|
|
if not args_so_far or len(args_so_far) <= len(prev):
|
|
|
|
|
|
return None
|
|
|
|
|
|
diff = args_so_far[len(prev):]
|
|
|
|
|
|
self.streamed_args_for_tool[index] = args_so_far
|
|
|
|
|
|
self.prev_tool_call_arr[index]["arguments"] = args_so_far
|
|
|
|
|
|
return diff
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure_tool_state_for(self, index: int) -> None:
|
|
|
|
|
|
"""Grow the streaming-state arrays so *index* is valid."""
|
|
|
|
|
|
while len(self._tool_call_ids) <= index:
|
|
|
|
|
|
self._tool_call_ids.append(self._generate_tool_call_id())
|
|
|
|
|
|
while len(self.streamed_args_for_tool) <= index:
|
|
|
|
|
|
self.streamed_args_for_tool.append("")
|
|
|
|
|
|
while len(self.prev_tool_call_arr) <= index:
|
|
|
|
|
|
self.prev_tool_call_arr.append({})
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Main streaming entry point
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
def extract_tool_calls_streaming(
|
|
|
|
|
|
self,
|
|
|
|
|
|
previous_text: str,
|
|
|
|
|
|
current_text: str,
|
|
|
|
|
|
delta_text: str,
|
|
|
|
|
|
previous_token_ids: Sequence[int],
|
|
|
|
|
|
current_token_ids: Sequence[int],
|
|
|
|
|
|
delta_token_ids: Sequence[int],
|
|
|
|
|
|
request: ChatCompletionRequest,
|
|
|
|
|
|
) -> DeltaMessage | None:
|
|
|
|
|
|
"""Extract tool calls from streaming output using re-parse-and-diff.
|
|
|
|
|
|
|
|
|
|
|
|
On every call we:
|
|
|
|
|
|
1. Re-scan *current_text* for content outside tool-call regions.
|
|
|
|
|
|
2. Find all ``<|DSML|invoke>`` regions (complete + partial).
|
|
|
|
|
|
3. Build JSON args for each, diff against previous, emit deltas.
|
|
|
|
|
|
|
|
|
|
|
|
Because the entire text is re-parsed each time, the result is
|
|
|
|
|
|
correct regardless of how many tokens arrived in this step.
|
|
|
|
|
|
"""
|
|
|
|
|
|
# First chunk of a new stream — reset state.
|
|
|
|
|
|
if not previous_text:
|
|
|
|
|
|
self._reset_streaming_state()
|
|
|
|
|
|
|
|
|
|
|
|
# If tools aren't enabled, just forward content.
|
|
|
|
|
|
if not self._tools_enabled(request):
|
|
|
|
|
|
return DeltaMessage(content=delta_text) if delta_text else None
|
|
|
|
|
|
|
|
|
|
|
|
# 1. Extract any content outside tool-call regions.
|
|
|
|
|
|
content = self._extract_content(current_text)
|
|
|
|
|
|
|
|
|
|
|
|
# 2. Find all invoke regions.
|
|
|
|
|
|
regions = self._extract_invoke_regions(current_text)
|
|
|
|
|
|
tool_call_deltas: list[DeltaToolCall] = []
|
|
|
|
|
|
|
|
|
|
|
|
for i, (func_name, inner_text, is_complete) in enumerate(regions):
|
|
|
|
|
|
self._ensure_tool_state_for(i)
|
|
|
|
|
|
|
|
|
|
|
|
# Emit the tool name (once per tool call).
|
|
|
|
|
|
if "name" not in self.prev_tool_call_arr[i]:
|
|
|
|
|
|
self.prev_tool_call_arr[i]["name"] = func_name
|
|
|
|
|
|
tool_call_deltas.append(
|
|
|
|
|
|
DeltaToolCall(
|
|
|
|
|
|
index=i,
|
|
|
|
|
|
id=self._tool_call_ids[i],
|
|
|
|
|
|
type="function",
|
|
|
|
|
|
function=DeltaFunctionCall(
|
|
|
|
|
|
name=func_name,
|
|
|
|
|
|
arguments="",
|
|
|
|
|
|
),
|
|
|
|
|
|
)
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Build the JSON args so far and emit the diff.
|
|
|
|
|
|
args_so_far = self._build_args_json_so_far(
|
|
|
|
|
|
func_name, inner_text, is_complete
|
|
|
|
|
|
)
|
|
|
|
|
|
diff = self._compute_args_diff(i, args_so_far)
|
|
|
|
|
|
if diff:
|
|
|
|
|
|
tool_call_deltas.append(
|
|
|
|
|
|
DeltaToolCall(
|
|
|
|
|
|
index=i,
|
|
|
|
|
|
function=DeltaFunctionCall(arguments=diff),
|
|
|
|
|
|
)
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if regions:
|
|
|
|
|
|
self.current_tool_id = len(regions) - 1
|
|
|
|
|
|
|
|
|
|
|
|
# 3. Return a delta if we have content or tool-call updates.
|
|
|
|
|
|
if content or tool_call_deltas:
|
|
|
|
|
|
return DeltaMessage(
|
|
|
|
|
|
content=content,
|
|
|
|
|
|
tool_calls=tool_call_deltas,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Empty delta with token ids means EOS or closing tag — return
|
|
|
|
|
|
# non-None so the serving framework can finalize finish_reason.
|
|
|
|
|
|
if not delta_text and delta_token_ids and self.prev_tool_call_arr:
|
|
|
|
|
|
return DeltaMessage(content="")
|
|
|
|
|
|
|
|
|
|
|
|
return None
|