Files
vllm/vllm/tool_parsers/deepseekv32_tool_parser.py

318 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
logger = init_logger(__name__)
class DeepSeekV32ToolParser(ToolParser):
"""
example tool call content:
<DSMLfunction_calls>
<DSMLinvoke name="get_weather">
<DSMLparameter name="location" string="true">杭州</DSMLparameter>
<DSMLparameter name="date" string="true">2024-01-16</DSMLparameter>
</DSMLinvoke>
<DSMLinvoke name="get_weather">
<DSMLparameter name="location" string="true">北京</DSMLparameter>
<DSMLparameter name="date" string="true">2024-01-16</DSMLparameter>
</DSMLinvoke>
</DSMLfunction_calls>
"""
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.prev_tool_call_arr: list[dict] = []
# Sentinel token
self.tool_call_start_token: str = "<DSMLfunction_calls>"
# Streaming state
self.is_tool_call_started: bool = False
self.current_tool_index: int = 0
# Regex patterns for complete parsing
self.tool_call_complete_regex = re.compile(
r"<DSMLfunction_calls>(.*?)</DSMLfunction_calls>", re.DOTALL
)
self.invoke_complete_regex = re.compile(
r'<DSMLinvoke\s+name="([^"]+)"\s*>(.*?)</DSMLinvoke>', re.DOTALL
)
self.parameter_complete_regex = re.compile(
r'<DSMLparameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>(.*?)</DSMLparameter>',
re.DOTALL,
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
logger.debug(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# Ensure tool call tokens
# (<DSMLfunction_calls>, </DSMLfunction_calls>)
# are not skippedduring decoding.
# Even though they are not marked as special tokens,
# setting skip_special_tokens=False ensures proper handling in
# transformers 5.x where decoding behavior may have changed.
request.skip_special_tokens = False
return request
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _parse_invoke_params(self, invoke_str: str) -> dict:
param_dict = dict()
for param_name, param_val in self.parameter_complete_regex.findall(invoke_str):
param_dict[param_name] = param_val
return param_dict
def _convert_param_value_checked(self, value: str, param_type: str) -> Any:
"""Convert parameter value to the correct type."""
if value.lower() == "null":
return None
param_type = param_type.lower()
if param_type in ["string", "str", "text"]:
return value
elif param_type in ["integer", "int"]:
return int(value)
elif param_type in ["number", "float"]:
val = float(value)
return val if val != int(val) else int(val)
elif param_type in ["boolean", "bool"]:
value = value.strip()
if value.lower() not in ["false", "0", "true", "1"]:
raise ValueError("Invalid boolean value")
return value.lower() in ["true", "1"]
elif param_type in ["object", "array"]:
return json.loads(value)
else:
return json.loads(value)
def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
"""Convert parameter value to the correct type."""
if not isinstance(param_type, list):
param_type = [param_type]
for current_type in param_type:
try:
return self._convert_param_value_checked(value, current_type)
except Exception:
continue
# return value as fallback
return value
def _convert_params_with_schema(
self,
function_name: str,
param_dict: dict[str, str],
) -> dict[str, Any]:
"""Convert raw string param values using the tool schema types."""
param_config: dict = {}
if self.tools:
for tool in self.tools:
if (
hasattr(tool, "function")
and tool.function.name == function_name
and hasattr(tool.function, "parameters")
):
schema = tool.function.parameters
if isinstance(schema, dict) and "properties" in schema:
param_config = schema["properties"]
break
converted: dict[str, Any] = {}
for name, value in param_dict.items():
param_type = "string"
if name in param_config and isinstance(param_config[name], dict):
param_type = param_config[name].get("type", "string")
converted[name] = self._convert_param_value(value, param_type)
return converted
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""Extract tool calls from complete model output (non-streaming)."""
# Quick check
if self.tool_call_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
tool_calls = []
# Find all complete tool_call blocks
for tool_call_match in self.tool_call_complete_regex.findall(model_output):
# Find all invokes within this tool_call
for invoke_name, invoke_content in self.invoke_complete_regex.findall(
tool_call_match
):
param_dict = self._parse_invoke_params(invoke_content)
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=invoke_name,
arguments=json.dumps(param_dict, ensure_ascii=False),
),
)
)
if not tool_calls:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# Extract content before first tool call
first_tool_idx = model_output.find(self.tool_call_start_token)
content = model_output[:first_tool_idx] if first_tool_idx > 0 else None
return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=content
)
except Exception:
logger.exception("Error extracting tool calls")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.is_tool_call_started = False
self.prev_tool_call_arr.clear()
self.streamed_args_for_tool.clear()
def _extract_delta_tool_calls(
self,
current_text: str,
request: ChatCompletionRequest | None,
) -> list[DeltaToolCall]:
"""Extract DeltaToolCalls from newly completed <invoke> blocks.
Tracks progress via ``current_tool_index`` so each block is
extracted exactly once across successive streaming calls.
"""
complete_invokes = self.invoke_complete_regex.findall(current_text)
delta_tool_calls: list[DeltaToolCall] = []
while len(complete_invokes) > self.current_tool_index:
invoke_name, invoke_body = complete_invokes[self.current_tool_index]
param_dict = self._parse_invoke_params(invoke_body)
converted = self._convert_params_with_schema(invoke_name, param_dict)
args_json = json.dumps(converted, ensure_ascii=False)
idx = self.current_tool_index
self.current_tool_index += 1
self.prev_tool_call_arr.append(
{"name": invoke_name, "arguments": converted}
)
self.streamed_args_for_tool.append(args_json)
delta_tool_calls.append(
DeltaToolCall(
index=idx,
id=self._generate_tool_call_id(),
function=DeltaFunctionCall(
name=invoke_name,
arguments=args_json,
),
type="function",
)
)
return delta_tool_calls
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int], # pylint: disable=unused-argument
current_token_ids: Sequence[int], # pylint: disable=unused-argument
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Extract tool calls from streaming model output.
Uses a buffer-until-complete-invoke strategy: tokens are buffered
until a complete invoke block is available, then parsed and emitted
in one shot.
"""
# First chunk of a new stream — reset state from prior request.
if not previous_text:
self._reset_streaming_state()
# Detect whether we've entered the tool-call region.
# Use current_text (not delta_text) since the start token may
# be split across chunks.
content_before = None
if self.is_tool_call_started:
pass
elif self.tool_call_start_token in current_text:
# Tool-call region found, capture any plain text before it.
self.is_tool_call_started = True
start_idx = current_text.index(self.tool_call_start_token)
content_before = current_text[len(previous_text) : start_idx] or None
else:
# Still in plain-text region, forward as content.
return DeltaMessage(content=delta_text) if delta_text else None
# Inside tool-call region: emit any newly completed invokes.
delta_tool_calls = self._extract_delta_tool_calls(current_text, request)
if delta_tool_calls or content_before:
return DeltaMessage(
content=content_before,
tool_calls=delta_tool_calls,
)
# Empty delta with token ids means EOS or closing tag; return
# non-None so the serving framework can finalize finish_reason.
if not delta_text and delta_token_ids and self.prev_tool_call_arr:
return DeltaMessage(content="")
return None