Files
nvfp4-megamoe-kernel/reference/vllm/tool_parsers/deepseekv32_tool_parser.py
biondizzle ca7c309463 Add reference/ dir: vLLM tokenizers, reasoning parsers, tool parsers, official inference
- reference/vllm/tokenizers/ — official DSV4 tokenizer + encoding (read-only)
- reference/vllm/reasoning/ — thinking mode parsers (DeepSeekR1 style)
- reference/vllm/tool_parsers/ — DSML tool call parsers (V3.2 base, V4 variant)
- reference/official_inference/ — original weight's generate.py, model.py, kernel.py
- reference/README.md documents the layout and which files matter for our pipeline
- These are read-only references for cross-checking, not imported by production code
2026-06-03 10:25:23 +00:00

323 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import partial_tag_overlap
# Module-level logger for this parser, created via vLLM's ``init_logger``.
logger = init_logger(__name__)
class DeepSeekV32ToolParser(ToolParser):
    """Parse DSML-style tool calls out of model output.

    Example tool call content:

    <DSMLfunction_calls>
    <DSMLinvoke name="get_weather">
    <DSMLparameter name="location" string="true">杭州</DSMLparameter>
    <DSMLparameter name="date" string="true">2024-01-16</DSMLparameter>
    </DSMLinvoke>
    <DSMLinvoke name="get_weather">
    <DSMLparameter name="location" string="true">北京</DSMLparameter>
    <DSMLparameter name="date" string="true">2024-01-16</DSMLparameter>
    </DSMLinvoke>
    </DSMLfunction_calls>

    Parameter values are captured as raw text and converted to Python
    values using the ``type`` declared for that parameter in the matching
    tool's schema (``"string"`` when no type is found). The ``string``
    attribute on a parameter tag is matched but not captured.

    NOTE(review): the page this file was taken from flagged ambiguous
    Unicode characters, so the DSML marker literals below may have lost
    non-ASCII delimiter characters in extraction. Verify them against the
    model's tokenizer vocabulary before relying on them.
    """

    tool_call_start_token: str = "<DSMLfunction_calls>"
    tool_call_end_token: str = "</DSMLfunction_calls>"

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        """Set up streaming state and compile the DSML regexes.

        Raises:
            ValueError: if ``self.model_tokenizer`` is falsy after the base
                class constructor has run.
        """
        super().__init__(tokenizer, tools)
        # One {"name", "arguments"} dict per tool call emitted while streaming.
        self.prev_tool_call_arr: list[dict] = []

        # Streaming state
        # Number of complete invoke blocks already emitted as tool calls.
        self.current_tool_index: int = 0
        # Offset into the accumulated text up to which content was emitted.
        self._sent_content_idx: int = 0

        # Regex patterns for complete parsing
        self.tool_call_complete_regex = re.compile(
            re.escape(self.tool_call_start_token)
            + r"(.*?)"
            + re.escape(self.tool_call_end_token),
            re.DOTALL,
        )
        self.invoke_complete_regex = re.compile(
            r'<DSMLinvoke\s+name="([^"]+)"\s*>(.*?)</DSMLinvoke>', re.DOTALL
        )
        self.parameter_complete_regex = re.compile(
            r'<DSMLparameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>(.*?)</DSMLparameter>',
            re.DOTALL,
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        logger.debug(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Apply the base adjustment, then keep special tokens when tools are on.

        Returns the (possibly modified) request object.
        """
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # Ensure tool call tokens
            # (e.g. <DSMLfunction_calls>, </DSMLfunction_calls>)
            # are not skipped during decoding.
            # Even though they are not marked as special tokens,
            # setting skip_special_tokens=False ensures proper handling in
            # transformers 5.x where decoding behavior may have changed.
            request.skip_special_tokens = False
        return request

    def _generate_tool_call_id(self) -> str:
        """Generate a unique tool call ID."""
        return f"call_{uuid.uuid4().hex[:24]}"

    def _parse_invoke_params(self, invoke_str: str) -> dict:
        """Return ``{name: raw_text}`` for each parameter tag in an invoke body.

        When a parameter name repeats, the last occurrence wins.
        """
        return dict(self.parameter_complete_regex.findall(invoke_str))

    def _convert_param_value_checked(self, value: str, param_type: str) -> Any:
        """Convert a raw parameter string to ``param_type`` or raise.

        A case-insensitive ``"null"`` yields ``None`` for every type.
        Failures raise (e.g. ``ValueError`` from ``int``/``float``/``json``);
        the caller decides how to fall back.
        """
        if value.lower() == "null":
            return None

        param_type = param_type.lower()
        if param_type in ["string", "str", "text"]:
            return value
        if param_type in ["integer", "int"]:
            return int(value)
        if param_type in ["number", "float"]:
            val = float(value)
            # Whole numbers collapse to int. Non-finite values make int()
            # raise, which propagates to the caller.
            return val if val != int(val) else int(val)
        if param_type in ["boolean", "bool"]:
            value = value.strip()
            if value.lower() not in ["false", "0", "true", "1"]:
                raise ValueError("Invalid boolean value")
            return value.lower() in ["true", "1"]
        # "object", "array", and any unrecognised type are parsed as JSON.
        return json.loads(value)

    def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any:
        """Convert ``value`` using the first candidate type that succeeds.

        ``param_type`` may be one type name or a list of them. If every
        candidate fails, the raw string is returned unchanged.
        """
        if not isinstance(param_type, list):
            param_type = [param_type]
        for current_type in param_type:
            try:
                return self._convert_param_value_checked(value, current_type)
            except Exception:
                # Deliberate best-effort: try the next candidate type.
                continue
        # return value as fallback
        return value

    def _convert_params_with_schema(
        self,
        function_name: str,
        param_dict: dict[str, str],
    ) -> dict[str, Any]:
        """Convert raw string param values using the tool schema types.

        Looks up ``function_name`` in ``self.tools`` and reads
        ``parameters["properties"]`` when it is available. Parameters
        without a usable schema entry are treated as ``"string"``.
        """
        param_config: dict = {}
        if self.tools:
            for tool in self.tools:
                if (
                    hasattr(tool, "function")
                    and tool.function.name == function_name
                    and hasattr(tool.function, "parameters")
                ):
                    schema = tool.function.parameters
                    if isinstance(schema, dict) and "properties" in schema:
                        param_config = schema["properties"]
                    # NOTE(review): indentation was lost in the extracted
                    # source; ``break`` is placed at the name-match level so
                    # the first matching tool ends the search -- confirm.
                    break

        converted: dict[str, Any] = {}
        for name, value in param_dict.items():
            param_type = "string"
            if name in param_config and isinstance(param_config[name], dict):
                param_type = param_config[name].get("type", "string")
            converted[name] = self._convert_param_value(value, param_type)
        return converted

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from complete model output (non-streaming).

        Returns ``tools_called=False`` with the full output as content when
        the start marker is absent, when no complete invoke is found inside
        a complete function_calls section, or when parsing raises.
        Otherwise ``content`` is the text before the first start marker, or
        ``None`` if there is none.
        """
        # Quick check
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            tool_calls = []
            # Find all complete tool_call blocks
            for tool_call_match in self.tool_call_complete_regex.findall(model_output):
                # Find all invokes within this tool_call
                for invoke_name, invoke_content in self.invoke_complete_regex.findall(
                    tool_call_match
                ):
                    param_dict = self._parse_invoke_params(invoke_content)
                    params = self._convert_params_with_schema(invoke_name, param_dict)
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=invoke_name,
                                arguments=json.dumps(params, ensure_ascii=False),
                            ),
                        )
                    )

            if not tool_calls:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            # Extract content before first tool call
            first_tool_idx = model_output.find(self.tool_call_start_token)
            content = model_output[:first_tool_idx] if first_tool_idx > 0 else None
            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=content
            )
        except Exception:
            logger.exception("Error extracting tool calls")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def _reset_streaming_state(self):
        """Reset all streaming state."""
        self.current_tool_index = 0
        self._sent_content_idx = 0
        self.prev_tool_call_arr.clear()
        # Not assigned in this class; presumably created by the ToolParser
        # base constructor -- confirm.
        self.streamed_args_for_tool.clear()

    def _extract_delta_tool_calls(
        self,
        current_text: str,
        request: ChatCompletionRequest | None,
    ) -> list[DeltaToolCall]:
        """Extract DeltaToolCalls from newly completed invoke blocks.

        Tracks progress via ``current_tool_index`` so each block is
        extracted exactly once across successive streaming calls. Only
        text from the first start marker onward is scanned.
        """
        # Text before the start marker is streamed as content by
        # _extract_content, and the non-streaming path ignores invoke blocks
        # outside a function_calls section. Scanning the whole text here
        # would emit such a block as both content and a tool call.
        section_start = current_text.find(self.tool_call_start_token)
        if section_start == -1:
            return []

        complete_invokes = self.invoke_complete_regex.findall(
            current_text, section_start
        )
        delta_tool_calls: list[DeltaToolCall] = []
        while len(complete_invokes) > self.current_tool_index:
            invoke_name, invoke_body = complete_invokes[self.current_tool_index]
            param_dict = self._parse_invoke_params(invoke_body)
            converted = self._convert_params_with_schema(invoke_name, param_dict)
            args_json = json.dumps(converted, ensure_ascii=False)

            idx = self.current_tool_index
            self.current_tool_index += 1

            self.prev_tool_call_arr.append(
                {"name": invoke_name, "arguments": converted}
            )
            self.streamed_args_for_tool.append(args_json)

            delta_tool_calls.append(
                DeltaToolCall(
                    index=idx,
                    id=self._generate_tool_call_id(),
                    function=DeltaFunctionCall(
                        name=invoke_name,
                        arguments=args_json,
                    ),
                    type="function",
                )
            )
        return delta_tool_calls

    def _extract_content(self, current_text: str) -> str | None:
        """Return unsent non-tool-call text, or None.

        Holds back any suffix that could be a partial start marker
        so that split markers are never leaked as content.
        """
        if self.tool_call_start_token not in current_text:
            overlap = partial_tag_overlap(current_text, self.tool_call_start_token)
            sendable_idx = len(current_text) - overlap
        else:
            # Once the marker is present, content ends where it begins.
            sendable_idx = current_text.index(self.tool_call_start_token)

        if sendable_idx > self._sent_content_idx:
            content = current_text[self._sent_content_idx : sendable_idx]
            self._sent_content_idx = sendable_idx
            return content
        return None

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],  # pylint: disable=unused-argument
        current_token_ids: Sequence[int],  # pylint: disable=unused-argument
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming model output.

        Uses a buffer-until-complete-invoke strategy: tokens are buffered
        until a complete invoke block is available, then parsed and emitted
        in one shot.

        Returns a ``DeltaMessage`` when there is new content or a new tool
        call, an empty-content ``DeltaMessage`` for an empty delta that
        follows at least one emitted tool call, and ``None`` otherwise.
        """
        # First chunk of a new stream — reset state from prior request.
        if not previous_text:
            self._reset_streaming_state()

        content = self._extract_content(current_text)
        delta_tool_calls = self._extract_delta_tool_calls(current_text, request)

        if delta_tool_calls or content:
            return DeltaMessage(content=content, tool_calls=delta_tool_calls)

        # Empty delta with token ids means EOS or closing tag; return
        # non-None so the serving framework can finalize finish_reason.
        if not delta_text and delta_token_ids and self.prev_tool_call_arr:
            return DeltaMessage(content="")
        return None