# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import json
|
|
import uuid
|
|
from collections.abc import Sequence
|
|
from typing import Any
|
|
|
|
import regex as re
|
|
|
|
from vllm.entrypoints.openai.chat_completion.protocol import (
|
|
ChatCompletionRequest,
|
|
)
|
|
from vllm.entrypoints.openai.engine.protocol import (
|
|
DeltaFunctionCall,
|
|
DeltaMessage,
|
|
DeltaToolCall,
|
|
ExtractedToolCallInformation,
|
|
FunctionCall,
|
|
ToolCall,
|
|
)
|
|
from vllm.logger import init_logger
|
|
from vllm.tokenizers import TokenizerLike
|
|
from vllm.tool_parsers.abstract_tool_parser import (
|
|
ToolParser,
|
|
)
|
|
|
|
# Module-level logger obtained from vLLM's ``init_logger`` for this module.
logger = init_logger(__name__)
|
|
|
|
|
|
class MinimaxM2ToolParser(ToolParser):
|
|
def __init__(self, tokenizer: TokenizerLike):
    """Prepare sentinel tokens, streaming state and the parsing regexes.

    Args:
        tokenizer: Tokenizer forwarded to the ``ToolParser`` base class.

    Raises:
        ValueError: If ``self.model_tokenizer`` is falsy once the base
            class has been initialised.
        RuntimeError: If the start or end sentinel token cannot be found
            in ``self.vocab``.
    """
    super().__init__(tokenizer)

    self.prev_tool_call_arr: list[dict] = []

    # Sentinel tokens that delimit a tool-call section in the model output.
    self.tool_call_start_token: str = "<minimax:tool_call>"
    self.tool_call_end_token: str = "</minimax:tool_call>"

    # Streaming state: whether a tool-call section has been seen, and how
    # many complete <invoke> blocks have been handled so far.
    self.is_tool_call_started: bool = False
    self.current_tool_index: int = 0

    # Patterns used on complete text. Each one captures lazily up to the
    # nearest closing tag; DOTALL lets the capture span multiple lines.
    dotall = re.DOTALL
    self.tool_call_complete_regex = re.compile(
        r"<minimax:tool_call>(.*?)</minimax:tool_call>", dotall
    )
    self.invoke_complete_regex = re.compile(
        r"<invoke name=(.*?)</invoke>", dotall
    )
    self.parameter_complete_regex = re.compile(
        r"<parameter name=(.*?)</parameter>", dotall
    )

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )

    # Resolve the sentinel tokens to vocabulary ids; both must exist.
    vocab = self.vocab
    self.tool_call_start_token_id = vocab.get(self.tool_call_start_token)
    self.tool_call_end_token_id = vocab.get(self.tool_call_end_token)

    sentinel_ids = (self.tool_call_start_token_id, self.tool_call_end_token_id)
    if any(token_id is None for token_id in sentinel_ids):
        raise RuntimeError(
            "MiniMax M2 Tool parser could not locate tool call start/end "
            "tokens in the tokenizer!"
        )

    logger.debug(
        "vLLM Successfully import tool parser %s !", self.__class__.__name__
    )
|
|
|
|
def _generate_tool_call_id(self) -> str:
|
|
"""Generate a unique tool call ID."""
|
|
return f"call_{uuid.uuid4().hex[:24]}"
|
|
|
|
def _extract_name(self, name_str: str) -> str:
|
|
"""Extract name from quoted string."""
|
|
name_str = name_str.strip()
|
|
if (name_str.startswith('"') and name_str.endswith('"')) or (
|
|
name_str.startswith("'") and name_str.endswith("'")
|
|
):
|
|
return name_str[1:-1]
|
|
return name_str
|
|
|
|
def _extract_types_from_schema(self, schema: Any) -> list[str]:
|
|
"""
|
|
Extract all possible types from a JSON schema definition.
|
|
Handles anyOf, oneOf, allOf, type arrays, and enum fields.
|
|
|
|
Args:
|
|
schema: The JSON schema definition for a parameter
|
|
|
|
Returns:
|
|
List of type strings (e.g., ["string", "integer", "null"])
|
|
"""
|
|
if schema is None:
|
|
return ["string"]
|
|
|
|
if not isinstance(schema, dict):
|
|
return ["string"]
|
|
|
|
types: set[str] = set()
|
|
|
|
# Handle direct "type" field
|
|
if "type" in schema:
|
|
type_value = schema["type"]
|
|
if isinstance(type_value, str):
|
|
types.add(type_value)
|
|
elif isinstance(type_value, list):
|
|
for t in type_value:
|
|
if isinstance(t, str):
|
|
types.add(t)
|
|
|
|
# Handle enum - infer types from enum values
|
|
if "enum" in schema and isinstance(schema["enum"], list) and schema["enum"]:
|
|
for value in schema["enum"]:
|
|
if value is None:
|
|
types.add("null")
|
|
elif isinstance(value, bool):
|
|
types.add("boolean")
|
|
elif isinstance(value, int):
|
|
types.add("integer")
|
|
elif isinstance(value, float):
|
|
types.add("number")
|
|
elif isinstance(value, str):
|
|
types.add("string")
|
|
elif isinstance(value, list):
|
|
types.add("array")
|
|
elif isinstance(value, dict):
|
|
types.add("object")
|
|
|
|
# Handle anyOf, oneOf, allOf - recursively extract types
|
|
for choice_field in ("anyOf", "oneOf", "allOf"):
|
|
if choice_field in schema and isinstance(schema[choice_field], list):
|
|
for choice in schema[choice_field]:
|
|
extracted = self._extract_types_from_schema(choice)
|
|
types.update(extracted)
|
|
|
|
# If no types found, default to string
|
|
if not types:
|
|
return ["string"]
|
|
|
|
return list(types)
|
|
|
|
def _convert_param_value_with_types(
|
|
self, value: str, param_types: list[str]
|
|
) -> Any:
|
|
"""
|
|
Convert parameter value to the correct type based on a list of possible types.
|
|
Tries each type in order until one succeeds.
|
|
|
|
Args:
|
|
value: The string value to convert
|
|
param_types: List of possible type strings
|
|
|
|
Returns:
|
|
The converted value
|
|
"""
|
|
# Check if the VALUE itself indicates null (not just if null is allowed)
|
|
if value.lower() in ("null", "none", "nil"):
|
|
return None
|
|
|
|
# Normalize types
|
|
normalized_types = [t.lower() for t in param_types]
|
|
|
|
# Try each type in order of preference (most specific first, string as fallback)
|
|
# Priority: integer > number > boolean > object > array > string
|
|
type_priority = [
|
|
"integer",
|
|
"int",
|
|
"number",
|
|
"float",
|
|
"boolean",
|
|
"bool",
|
|
"object",
|
|
"array",
|
|
"string",
|
|
"str",
|
|
"text",
|
|
]
|
|
|
|
for param_type in type_priority:
|
|
if param_type not in normalized_types:
|
|
continue
|
|
|
|
if param_type in ["string", "str", "text"]:
|
|
return value
|
|
elif param_type in ["integer", "int"]:
|
|
try:
|
|
return int(value)
|
|
except (ValueError, TypeError):
|
|
continue
|
|
elif param_type in ["number", "float"]:
|
|
try:
|
|
val = float(value)
|
|
return val if val != int(val) else int(val)
|
|
except (ValueError, TypeError):
|
|
continue
|
|
elif param_type in ["boolean", "bool"]:
|
|
lower_val = value.lower().strip()
|
|
if lower_val in ["true", "1", "yes", "on"]:
|
|
return True
|
|
elif lower_val in ["false", "0", "no", "off"]:
|
|
return False
|
|
continue
|
|
elif param_type in ["object", "array"]:
|
|
try:
|
|
return json.loads(value)
|
|
except json.JSONDecodeError:
|
|
continue
|
|
|
|
# Fallback: try JSON parse, then return as string
|
|
try:
|
|
return json.loads(value)
|
|
except json.JSONDecodeError:
|
|
return value
|
|
|
|
def _get_param_types_from_config(
|
|
self, param_name: str, param_config: dict
|
|
) -> list[str]:
|
|
"""
|
|
Get parameter types from parameter configuration.
|
|
Handles anyOf, oneOf, allOf, and direct type definitions.
|
|
|
|
Args:
|
|
param_name: The name of the parameter
|
|
param_config: The properties dict from the tool schema
|
|
|
|
Returns:
|
|
List of type strings
|
|
"""
|
|
if param_name not in param_config:
|
|
return ["string"]
|
|
|
|
param_schema = param_config[param_name]
|
|
if not isinstance(param_schema, dict):
|
|
return ["string"]
|
|
|
|
return self._extract_types_from_schema(param_schema)
|
|
|
|
def _parse_single_invoke(
    self, invoke_str: str, tools: list | None
) -> ToolCall | None:
    """Turn the body of one ``<invoke>`` block into a ``ToolCall``.

    Args:
        invoke_str: Text of a single invoke block as captured by
            ``invoke_complete_regex``: everything after ``<invoke name=``
            up to the closing tag.
        tools: Tool definitions from the request, used to look up
            parameter types; may be ``None`` or empty.

    Returns:
        A function ``ToolCall`` whose arguments are a JSON-encoded object,
        or ``None`` when no function name can be read from ``invoke_str``.
    """
    # The function name is everything before the first ">".
    header = re.search(r"^([^>]+)", invoke_str)
    if header is None:
        return None
    function_name = self._extract_name(header.group(1))

    # Find the "properties" schema of the first tool with a matching name.
    properties = {}
    for tool in tools or ():
        if not hasattr(tool, "function"):
            continue
        if tool.function.name != function_name or not hasattr(
            tool.function, "parameters"
        ):
            continue
        declared = tool.function.parameters
        if isinstance(declared, dict) and "properties" in declared:
            properties = declared["properties"]
        break

    # Each match is `<name>>value`; split at the first ">" and convert.
    arguments = {}
    for raw_parameter in self.parameter_complete_regex.findall(invoke_str):
        pieces = re.search(r"^([^>]+)>(.*)", raw_parameter, re.DOTALL)
        if pieces is None:
            continue
        raw_name, raw_value = pieces.groups()
        key = self._extract_name(raw_name)
        text = raw_value.strip()
        # Allowed types for this parameter (supports anyOf/oneOf/allOf).
        allowed_types = self._get_param_types_from_config(key, properties)
        arguments[key] = self._convert_param_value_with_types(text, allowed_types)

    encoded_arguments = json.dumps(arguments, ensure_ascii=False)
    return ToolCall(
        type="function",
        function=FunctionCall(name=function_name, arguments=encoded_arguments),
    )
|
|
|
|
def _extract_delta_tool_calls(
    self,
    current_text: str,
    request: ChatCompletionRequest | None,
) -> list[DeltaToolCall]:
    """Build DeltaToolCalls for ``<invoke>`` blocks not yet handled.

    ``current_tool_index`` counts the complete invoke blocks already
    processed, so repeated calls with a growing ``current_text`` emit each
    block once. Every emitted call is also recorded in
    ``prev_tool_call_arr`` (name plus parsed arguments) and
    ``streamed_args_for_tool`` (the JSON argument string).

    Args:
        current_text: All text generated so far.
        request: The originating request; its ``tools`` are used for
            parameter typing when a request is given.

    Returns:
        The newly emitted tool-call deltas, possibly empty.
    """
    finished_invokes = self.invoke_complete_regex.findall(current_text)
    emitted: list[DeltaToolCall] = []

    while True:
        position = self.current_tool_index
        if position >= len(finished_invokes):
            break

        parsed = self._parse_single_invoke(
            finished_invokes[position],
            request.tools if request else None,
        )
        # Advance past this block whether or not it could be parsed.
        self.current_tool_index = position + 1
        if not parsed:
            continue

        function_name = parsed.function.name
        arguments_json = parsed.function.arguments

        self.prev_tool_call_arr.append(
            {
                "name": function_name,
                "arguments": json.loads(arguments_json),
            }
        )
        self.streamed_args_for_tool.append(arguments_json)

        function_delta = DeltaFunctionCall(
            name=function_name,
            arguments=arguments_json,
        )
        emitted.append(
            DeltaToolCall(
                index=position,
                id=self._generate_tool_call_id(),
                function=function_delta,
                type="function",
            )
        )

    return emitted
|
|
|
|
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract tool calls from complete model output (non-streaming).

    Args:
        model_output: The full generated text.
        request: The originating request; its ``tools`` drive parameter
            type conversion.

    Returns:
        With at least one parsed call: ``tools_called=True``, the calls,
        and as content the text before the first start token (``None``
        if there is none). Otherwise, including when parsing raises,
        ``tools_called=False`` with ``model_output`` as content.
    """

    def no_tool_calls() -> ExtractedToolCallInformation:
        # Shared "nothing parsed" result: hand the text back untouched.
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    start_token = self.tool_call_start_token

    # Quick check: without a start token there is nothing to parse.
    if start_token not in model_output:
        return no_tool_calls()

    try:
        # Walk every complete tool_call section and every invoke inside it.
        parsed_calls = (
            self._parse_single_invoke(
                invoke_body, request.tools if request else None
            )
            for section in self.tool_call_complete_regex.findall(model_output)
            for invoke_body in self.invoke_complete_regex.findall(section)
        )
        tool_calls = [call for call in parsed_calls if call]

        if not tool_calls:
            return no_tool_calls()

        # Replace prev_tool_call_arr with the calls found in this output.
        self.prev_tool_call_arr.clear()
        self.prev_tool_call_arr.extend(
            {
                "name": call.function.name,
                "arguments": call.function.arguments,
            }
            for call in tool_calls
        )

        # Text preceding the first start token is ordinary content.
        leading_text, _, _ = model_output.partition(start_token)
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=leading_text or None,
        )

    except Exception:
        logger.exception("Error extracting tool calls")
        return no_tool_calls()
|
|
|
|
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],  # pylint: disable=unused-argument
    current_token_ids: Sequence[int],  # pylint: disable=unused-argument
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Extract tool calls from streaming model output.

    Uses a buffer-until-complete-invoke strategy: tokens are buffered
    until a complete ``<invoke>...</invoke>`` block is available, then
    parsed and emitted in one shot.

    Args:
        previous_text: Text generated before this delta; an empty value
            is treated as the start of a new request.
        current_text: All text generated so far, including this delta.
        delta_text: Text added by this delta.
        previous_token_ids: Unused.
        current_token_ids: Unused.
        delta_token_ids: Token ids added by this delta.
        request: The originating request.

    Returns:
        A ``DeltaMessage`` carrying plain content (before any tool call),
        newly completed tool calls and/or content preceding the start
        token, or an empty-content message for a text-less token other
        than the end token once a call has been emitted. ``None`` when
        there is nothing to send for this delta.
    """

    start_in_text = self.tool_call_start_token in delta_text
    start_in_ids = self.tool_call_start_token_id in delta_token_ids
    tool_call_starting = start_in_text or start_in_ids

    # Reset per-request state on the first delta of a request (the parser
    # instance may be reused between requests).
    if not previous_text:
        self.current_tool_index = 0
        self.prev_tool_call_arr.clear()
        self.streamed_args_for_tool.clear()
        self.is_tool_call_started = False

    # A start token only switches the parser into tool-call mode. It must
    # not rewind current_tool_index: _extract_delta_tool_calls scans the
    # whole of current_text, so rewinding when a later tool-call block
    # opens would re-emit every invoke of the earlier blocks and reuse
    # their stream indices.
    if tool_call_starting:
        self.is_tool_call_started = True

    # Pass through content before any tool call.
    if not self.is_tool_call_started:
        return DeltaMessage(content=delta_text) if delta_text else None

    # Capture content that shares this delta with the start token.
    content_before = None
    if start_in_text:
        before = delta_text[: delta_text.index(self.tool_call_start_token)]
        content_before = before or None

    # Extract newly completed <invoke> blocks as DeltaToolCalls.
    delta_tool_calls = self._extract_delta_tool_calls(current_text, request)

    if delta_tool_calls or content_before:
        return DeltaMessage(
            content=content_before,
            tool_calls=delta_tool_calls,
        )

    # EOS and </minimax:tool_call> both arrive as special tokens with
    # no decoded text. Return non-None for EOS so the serving framework
    # reaches the finish-reason handling path instead of skipping.
    if (
        not delta_text
        and delta_token_ids
        and self.prev_tool_call_arr
        and self.tool_call_end_token_id not in delta_token_ids
    ):
        return DeltaMessage(content="")

    return None
|