Some Kimi K2.5 model variants (e.g. nvidia/Kimi-K2.5-NVFP4) omit <|tool_calls_section_begin|> and emit <|tool_call_begin|> directly. The tool parser looked only for section-level markers, so these tool calls were forwarded as raw content text instead of being parsed. Fix: when no section-level marker is found, _find_section_start and _find_section_start_end now fall back to treating <|tool_call_begin|> as the section start, and the section end falls back to <|tool_call_end|>.
590 lines
23 KiB
Python
590 lines
23 KiB
Python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Kimi-K2 Tool Call Parser — re-parse-and-diff version.

Adapted from the GLM-4/DeepSeek-V3.2 streaming fix to make the
streaming path robust against multi-token deltas produced by MTP
speculative decoding.

Instead of counting start/end tokens to maintain an incremental state
machine, the streaming path re-parses the *entire* current_text on
every call, finds all <|tool_call_begin|> regions (complete and
in-progress), extracts the JSON arguments for each, and diffs against
what was previously sent.

Key changes vs. the upstream token-count parser:

1. No token-count state machine — the parser does not depend on how
   many tokens arrive per step.
2. Content forwarding uses delta_text (not re-parsed current_text),
   so reasoning text is never re-emitted as content.
3. _extract_tool_call_regions() finds both complete and incomplete
   tool-call blocks, enabling argument streaming.
4. _compute_args_diff() emits only newly added characters.
5. Handles singular/plural section marker variants, and falls back to
   treating <|tool_call_begin|> as the section start for model variants
   that omit the section-level marker.
6. Returns None (no delta) inside open sections while tool-call tokens
   are still arriving, instead of emitting empty-string content deltas.

Drop-in replacement: same class name, same interface.

Example tool call format::

    <|tool_calls_section_begin|>
    <|tool_call_begin|>
    functions.get_weather:0
    <|tool_call_argument_begin|>
    {"location": "杭州", "date": "2024-01-16"}
    <|tool_call_end|>
    <|tool_call_begin|>
    functions.get_time:1
    <|tool_call_argument_begin|>
    {"timezone": "Asia/Shanghai"}
    <|tool_call_end|>
    <|tool_calls_section_end|>
"""

from collections.abc import Sequence
|
|
from typing import Any
|
|
|
|
import regex as re
|
|
|
|
from vllm.entrypoints.openai.chat_completion.protocol import (
|
|
ChatCompletionRequest,
|
|
)
|
|
from vllm.entrypoints.openai.engine.protocol import (
|
|
DeltaFunctionCall,
|
|
DeltaMessage,
|
|
DeltaToolCall,
|
|
ExtractedToolCallInformation,
|
|
FunctionCall,
|
|
ToolCall,
|
|
)
|
|
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
|
from vllm.logger import init_logger
|
|
from vllm.tokenizers import TokenizerLike
|
|
from vllm.tool_parsers.abstract_tool_parser import (
|
|
Tool,
|
|
ToolParser,
|
|
)
|
|
|
|
# Module-level logger obtained from vLLM's init_logger, named after this module.
logger = init_logger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Utility — inlined to avoid import issues across vLLM versions
# ---------------------------------------------------------------------------


def partial_tag_overlap(text: str, tag: str) -> int:
    """Count how many leading characters of *tag* dangle at the end of *text*.

    Only proper prefixes of *tag* are considered (a fully present tag is
    not counted), and the longest such prefix wins. For example, with
    tag ``"<|tool_call_begin|>"`` a *text* ending in ``"<|tool_call"``
    yields 11. Returns 0 when *text* does not end with any prefix of *tag*.
    """
    longest = min(len(tag) - 1, len(text))
    # Walk candidate lengths from longest to shortest; first hit wins.
    return next(
        (size for size in range(longest, 0, -1) if text.endswith(tag[:size])),
        0,
    )


class KimiK2ToolParser(ToolParser):
    """Re-parse-and-diff tool parser for Kimi-K2 format.

    On every streaming call the parser re-parses ``current_text`` to
    find tool-call regions, extracts the JSON arguments for each, and
    diffs against what was previously sent. This is robust against
    multi-token deltas from MTP / EAGLE speculative decoding.

    Tool calls are normally wrapped in a ``<|tool_calls_section_begin|>``
    ... ``<|tool_calls_section_end|>`` section (singular spellings are
    accepted too). Some model variants omit the section wrapper and emit
    ``<|tool_call_begin|>`` directly. In that case the first bare
    ``<|tool_call_begin|>`` marks the start of the tool-call region, and
    the region stays open until a section-end marker (if any) appears.
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        super().__init__(tokenizer, tools)

        # ----- Tag constants -----
        # Section wrappers (support singular & plural variants)
        self.tool_calls_section_start_variants: list[str] = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",
        ]
        self.tool_calls_section_end_variants: list[str] = [
            "<|tool_calls_section_end|>",
            "<|tool_call_section_end|>",
        ]
        # Some model variants omit the section-level marker and go
        # directly to <|tool_call_begin|>. Treat it as a fallback section
        # start. Set to an empty string to disable the fallback.
        self._fallback_section_start: str = "<|tool_call_begin|>"
        # Primary variant for ToolParser base class / adjust_request
        self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
        self.tool_calls_end_token: str = "<|tool_calls_section_end|>"

        # Individual tool-call markers
        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"
        self.tool_call_arg_begin: str = "<|tool_call_argument_begin|>"

        # ----- Compiled regexes -----
        # Complete tool call block.
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*"
            r"(?P<tool_call_id>[^<]+:\d+)\s*"
            r"<\|tool_call_argument_begin\|>\s*"
            r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
            r"<\|tool_call_end\|>",
            re.DOTALL,
        )

        # For extracting tool ID from the start of a tool-call region.
        self.tool_id_regex = re.compile(
            r"\s*(?P<tool_id>[^\s<]+:\d+)\s*", re.DOTALL
        )

        # ----- Streaming state (reset per request) -----
        self._tool_call_ids: list[str] = []
        self.streamed_args_for_tool: list[str] = []
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        self.current_tool_id: int = -1

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Validate that the primary section tokens exist in vocab.
        self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if (
            self.tool_calls_start_token_id is None
            or self.tool_calls_end_token_id is None
        ):
            raise RuntimeError(
                "Kimi-K2 Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

        logger.debug("Successfully initialized %s", self.__class__.__name__)

    # ------------------------------------------------------------------
    # Request adjustment
    # ------------------------------------------------------------------

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Keep special tokens in the decoded text when tools are enabled."""
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # Ensure tool-call tokens (<|tool_calls_section_begin|>,
            # <|tool_call_begin|>, etc.) are not stripped during decoding.
            request.skip_special_tokens = False
        return request

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _tools_enabled(request: ChatCompletionRequest) -> bool:
        """True when the request carries tools and tool_choice isn't "none"."""
        try:
            tools = getattr(request, "tools", None)
            tool_choice = getattr(request, "tool_choice", None)
            return bool(tools) and tool_choice != "none"
        except Exception:
            logger.exception("Failed to determine if tools are enabled.")
            return False

    @staticmethod
    def _parse_tool_id(raw_id: str) -> tuple[str, str]:
        """Split a raw tool ID into ``(function_name, full_id)``.

        ``'functions.get_weather:0'`` → ``('get_weather', 'functions.get_weather:0')``.
        """
        raw_id = raw_id.strip()
        function_name = raw_id.split(":")[0].split(".")[-1]
        return function_name, raw_id

    @staticmethod
    def _find_earliest(
        text: str, markers: Sequence[str], start: int = 0
    ) -> tuple[int, str | None]:
        """Return ``(index, marker)`` for the earliest of *markers* in
        ``text[start:]``, or ``(-1, None)`` if none occur.

        Ties (same index) go to the marker listed first. Empty markers are
        ignored, since ``str.find("")`` would always match.
        """
        best_idx: int = -1
        best_marker: str | None = None
        for marker in markers:
            if not marker:
                continue
            idx = text.find(marker, start)
            if idx != -1 and (best_idx == -1 or idx < best_idx):
                best_idx, best_marker = idx, marker
        return best_idx, best_marker

    def _locate_section_start(self, text: str) -> tuple[int, int]:
        """Return ``(marker_index, inner_start)`` for the tool-call region.

        *marker_index* is where the earliest section-start marker begins.
        For a section-level marker, *inner_start* is just past it. For the
        bare ``<|tool_call_begin|>`` fallback, *inner_start* equals
        *marker_index*: that marker belongs to the first tool call and must
        stay inside the region, otherwise the first call is never parsed.

        Taking the earliest marker (rather than preferring section-level
        markers) keeps the section start stable as the text grows during
        streaming. Both values are -1 when no marker is present.
        """
        idx, marker = self._find_earliest(
            text, self.tool_calls_section_start_variants
        )
        inner_start = idx + len(marker) if marker is not None else -1

        if self._fallback_section_start:
            fb_idx = text.find(self._fallback_section_start)
            if fb_idx != -1 and (idx == -1 or fb_idx < idx):
                return fb_idx, fb_idx

        return idx, inner_start

    def _find_section_start(self, text: str) -> int:
        """Return the index of the first section-start marker, or -1.

        Falls back to ``<|tool_call_begin|>`` for model variants that skip
        ``<|tool_calls_section_begin|>`` and go directly to the tool call.
        """
        return self._locate_section_start(text)[0]

    def _find_section_start_end(self, text: str) -> tuple[int, int]:
        """Return ``(start_of_inner, end_of_inner)`` for the section region.

        *start_of_inner* is where the region's contents begin (see
        ``_locate_section_start``), or -1 if there is no section.
        *end_of_inner* is the index of the earliest section-end marker after
        it, or -1 if the section is still open. A bare-fallback section has
        no section-level end marker of its own, so it normally stays open.
        ``<|tool_call_end|>`` is deliberately NOT used as a section end,
        because that would close the section after the first call.
        """
        inner_start = self._locate_section_start(text)[1]
        if inner_start == -1:
            return -1, -1
        end_idx, _ = self._find_earliest(
            text, self.tool_calls_section_end_variants, inner_start
        )
        return inner_start, end_idx

    def _text_after_section(self, text: str) -> str | None:
        """Return the text following a closed section's end marker.

        Returns ``None`` if there is no section or it is still open.
        """
        inner_start = self._locate_section_start(text)[1]
        if inner_start == -1:
            return None
        end_idx, end_marker = self._find_earliest(
            text, self.tool_calls_section_end_variants, inner_start
        )
        if end_marker is None:
            return None
        return text[end_idx + len(end_marker):]

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all complete tool calls from a finished model output.

        The tool-call region is detected with ``_find_section_start``, so
        both section-marker spellings and the bare ``<|tool_call_begin|>``
        fallback are recognized. *content* is the text before that region,
        or None if it is empty.

        If no complete tool call can be parsed, or parsing raises, the whole
        output is returned as content with ``tools_called=False``.
        """
        section_start = self._find_section_start(model_output)
        if section_start == -1:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            function_call_tuples = self.tool_call_regex.findall(model_output)
            logger.debug("function_call_tuples: %s", function_call_tuples)

            tool_calls: list[ToolCall] = []
            for function_id, function_args in function_call_tuples:
                function_name, full_id = self._parse_tool_id(function_id)
                tool_calls.append(
                    ToolCall(
                        id=full_id,
                        type="function",
                        function=FunctionCall(
                            name=function_name,
                            arguments=function_args,
                        ),
                    )
                )

            if not tool_calls:
                # A marker was present but no complete call parsed (e.g.
                # truncated output). Don't report a tool call with an
                # empty call list.
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            content = model_output[:section_start]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content or None,
            )

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    # ------------------------------------------------------------------
    # Streaming helpers — re-parse-and-diff
    # ------------------------------------------------------------------

    def _reset_streaming_state(self) -> None:
        """Clear all per-request streaming state."""
        self._tool_call_ids.clear()
        self.streamed_args_for_tool.clear()
        self.prev_tool_call_arr.clear()
        self.current_tool_id = -1

    def _extract_tool_call_regions(self, text: str) -> list[tuple[str, bool]]:
        """Find all ``<|tool_call_begin|>`` … ``<|tool_call_end|>``
        blocks inside the tool-calls section.

        Returns a list of ``(inner_text, is_complete)`` tuples.
        *inner_text* is everything between the tool-call open and close
        tags (or end-of-available-text for the last partial block, minus
        any partially-arrived close tag).
        """
        results: list[tuple[str, bool]] = []

        inner_start, inner_end = self._find_section_start_end(text)
        if inner_start == -1:
            return results

        region = text[inner_start:inner_end] if inner_end != -1 else text[inner_start:]

        pos = 0
        while pos < len(region):
            tc_start = region.find(self.tool_call_start_token, pos)
            if tc_start == -1:
                break

            body_start = tc_start + len(self.tool_call_start_token)
            tc_end = region.find(self.tool_call_end_token, body_start)

            if tc_end != -1:
                results.append((region[body_start:tc_end], True))
                pos = tc_end + len(self.tool_call_end_token)
            else:
                # Incomplete — still being generated.
                body = region[body_start:]
                overlap = partial_tag_overlap(body, self.tool_call_end_token)
                if overlap:
                    body = body[:-overlap]
                results.append((body, False))
                break

        return results

    def _parse_tool_call_body(
        self, body: str, is_complete: bool
    ) -> tuple[str | None, str | None, str]:
        """Parse a tool-call body into ``(func_name, tool_id, args_so_far)``.

        The body looks like::

            functions.get_weather:0
            <|tool_call_argument_begin|>
            {"location": "杭州"}

        Returns ``(None, None, "")`` if the body doesn't identify the tool
        yet. For an in-progress block the ID is trusted only once something
        follows it (whitespace or ``<|tool_call_argument_begin|>``). Until
        then the index digits may still be arriving (``:1`` → ``:12``), and
        the streaming path emits the ID only once.
        """
        arg_begin_idx = body.find(self.tool_call_arg_begin)
        if arg_begin_idx != -1:
            id_portion = body[:arg_begin_idx]
            args_portion = body[arg_begin_idx + len(self.tool_call_arg_begin):]
        else:
            id_portion = body
            args_portion = ""

        id_match = self.tool_id_regex.match(id_portion)
        if not id_match:
            # Not enough tokens yet to identify the tool.
            return None, None, ""

        id_may_be_truncated = (
            not is_complete
            and arg_begin_idx == -1
            and id_match.end("tool_id") == len(id_portion)
        )
        if id_may_be_truncated:
            return None, None, ""

        func_name, full_id = self._parse_tool_id(id_match.group("tool_id"))
        # strip() keeps successive results prefix-compatible: trailing
        # whitespace dropped now reappears (followed by new text) later.
        return func_name, full_id, args_portion.strip()

    def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
        """Return only the characters in *args_so_far* that haven't been
        sent yet, or ``None`` if there's nothing new."""
        prev = self.streamed_args_for_tool[index]
        if not args_so_far or len(args_so_far) <= len(prev):
            return None
        diff = args_so_far[len(prev):]
        self.streamed_args_for_tool[index] = args_so_far
        self.prev_tool_call_arr[index]["arguments"] = args_so_far
        return diff

    def _ensure_tool_state_for(self, index: int) -> None:
        """Grow the streaming-state arrays so *index* is valid."""
        while len(self._tool_call_ids) <= index:
            self._tool_call_ids.append("")
        while len(self.streamed_args_for_tool) <= index:
            self.streamed_args_for_tool.append("")
        while len(self.prev_tool_call_arr) <= index:
            self.prev_tool_call_arr.append({})

    def _collect_tool_call_deltas(self, current_text: str) -> list[DeltaToolCall]:
        """Re-parse *current_text* and return deltas for new names/arguments.

        Updates the per-tool streaming state (IDs, names, streamed args).
        """
        regions = self._extract_tool_call_regions(current_text)
        tool_call_deltas: list[DeltaToolCall] = []

        for i, (body, is_complete) in enumerate(regions):
            self._ensure_tool_state_for(i)

            func_name, tool_id, args_so_far = self._parse_tool_call_body(
                body, is_complete
            )
            if func_name is None:
                # Not enough data to identify the tool yet.
                break

            # Emit the tool name (once per tool call).
            if "name" not in self.prev_tool_call_arr[i]:
                self.prev_tool_call_arr[i]["name"] = func_name
                self._tool_call_ids[i] = tool_id or ""
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        id=tool_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments="",
                        ),
                    )
                )

            # Diff the arguments and emit any new characters.
            diff = self._compute_args_diff(i, args_so_far)
            if diff:
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        function=DeltaFunctionCall(arguments=diff),
                    )
                )

        if regions:
            self.current_tool_id = len(regions) - 1

        return tool_call_deltas

    def _collect_content_delta(
        self, previous_text: str, current_text: str
    ) -> str | None:
        """Return new plain-text content surrounding the tool-call section.

        There are two sources, and a single multi-token delta can contain
        both:

        * text before the section marker, when the section first appears
          in this delta (whitespace-only text is dropped);
        * text after a closed section's end marker that was not already
          present in *previous_text*.

        Returns ``None`` when there is nothing new.
        """
        parts: list[str] = []

        if self._find_section_start(previous_text) == -1:
            section_start = self._find_section_start(current_text)
            pre_section = current_text[len(previous_text):section_start]
            if pre_section.strip():
                parts.append(pre_section)

        after_section = self._text_after_section(current_text)
        if after_section is not None:
            prev_after = self._text_after_section(previous_text)
            new_after = (
                after_section
                if prev_after is None
                else after_section[len(prev_after):]
            )
            if new_after:
                parts.append(new_after)

        return "".join(parts) or None

    # ------------------------------------------------------------------
    # Main streaming entry point
    # ------------------------------------------------------------------

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming output.

        Hybrid approach:

        - **Content forwarding** before any tool section uses
          ``delta_text``, so text that the reasoning parser already
          handled is never re-emitted.
        - **Tool call detection** re-parses ``current_text`` on every
          call (the re-parse-and-diff approach). It therefore does not
          depend on how many tokens arrived per step, which makes it
          robust against MTP.

        Content that shares a delta with tool-call tokens (before the
        section start, or after the section end) is emitted in the same
        DeltaMessage instead of being dropped. Returns ``None`` when there
        is nothing new to send.
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)

        # First chunk of a new stream — reset state.
        if not previous_text:
            self._reset_streaming_state()

        # If tools aren't enabled, just forward content.
        if not self._tools_enabled(request):
            return DeltaMessage(content=delta_text) if delta_text else None

        # No tool section has started yet — forward delta_text as content.
        # The reasoning parser handles the reasoning/content split; we just
        # pass through whatever delta the serving layer gave us.
        if self._find_section_start(current_text) == -1:
            return DeltaMessage(content=delta_text) if delta_text else None

        tool_call_deltas = self._collect_tool_call_deltas(current_text)
        content = self._collect_content_delta(previous_text, current_text)

        if tool_call_deltas:
            return DeltaMessage(content=content, tool_calls=tool_call_deltas)
        if content:
            return DeltaMessage(content=content)

        # Inside an open tool section with nothing parseable yet. Return
        # None rather than an empty-string content delta: empty content
        # deltas confuse clients that distinguish content=null from "".
        return None