Files
vllm-kimi25-eagle/kimi_k2_tool_parser.py
biondizzle 3ee933951c Tool parser: fallback to <|tool_call_begin|> when no section marker
Some Kimi K2.5 model variants (nvidia/Kimi-K2.5-NVFP4) omit
<|tool_calls_section_begin|> and go directly to <|tool_call_begin|>.
The tool parser was only looking for section-level markers, so these
tool calls were forwarded as raw content text instead of being parsed.

Fix: _find_section_start and _find_section_start_end now fall back to
<|tool_call_begin|> as a section start when no section-level marker
is found. The section end falls back to <|tool_call_end|>.
2026-04-14 11:25:11 +00:00

590 lines
23 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Kimi-K2 Tool Call Parser — re-parse-and-diff version.
Adapted from the GLM-4/DeepSeek-V3.2 streaming fix to make the
streaming path robust against multi-token deltas produced by MTP
speculative decoding.
Instead of counting start/end tokens to maintain an incremental state
machine, the streaming path re-parses the *entire* current_text on
every call, finds all <|tool_call_begin|> regions (complete and
in-progress), extracts the JSON arguments for each, and diffs against
what was previously sent.
Key changes vs. the upstream token-count parser:
1. No token-count state machine — the parser is stateless w.r.t.
how many tokens arrived per step.
2. Content forwarding uses delta_text (not re-parsed current_text)
so reasoning text is never re-emitted as content.
3. _extract_tool_call_regions() finds both complete and incomplete
tool-call blocks, enabling argument streaming.
4. _compute_args_diff() emits only newly-added characters.
5. Handles singular/plural section marker variants.
6. Returns empty deltas inside open sections to keep the stream
alive while tool call tokens are still arriving.
Drop-in replacement: same class name, same interface.
Example tool call format::
<|tool_calls_section_begin|>
<|tool_call_begin|>
functions.get_weather:0
<|tool_call_argument_begin|>
{"location": "杭州", "date": "2024-01-16"}
<|tool_call_end|>
<|tool_call_begin|>
functions.get_time:1
<|tool_call_argument_begin|>
{"timezone": "Asia/Shanghai"}
<|tool_call_end|>
<|tool_calls_section_end|>
"""
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# Utility — inlined to avoid import issues across vLLM versions
# ---------------------------------------------------------------------------
def partial_tag_overlap(text: str, tag: str) -> int:
    """Length of the longest proper prefix of *tag* that *text* ends with.

    E.g. a *text* ending in ``"<|tool_call"`` yields 11 when *tag* is
    ``"<|tool_call_begin|>"``. Returns 0 when there is no overlap.
    """
    # Only proper prefixes are considered (len(tag) - 1 at most), and
    # never more characters than *text* itself holds.
    longest = min(len(tag) - 1, len(text))
    return next(
        (k for k in range(longest, 0, -1) if text.endswith(tag[:k])),
        0,
    )
class KimiK2ToolParser(ToolParser):
    """Re-parse-and-diff tool parser for the Kimi-K2 tool-call format.

    On every streaming call the parser re-parses ``current_text`` to
    find tool-call regions, extracts the JSON arguments for each, and
    diffs against what was previously sent.  Because no token-count
    state machine is kept, the parser is robust against multi-token
    deltas produced by MTP / EAGLE speculative decoding.
    """
    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        """Set up tag constants, compiled regexes, and streaming state.

        Raises:
            ValueError: if ``model_tokenizer`` is unset after base init.
            RuntimeError: if the section-level start/end markers are not
                present in the tokenizer vocabulary.
        """
        super().__init__(tokenizer, tools)
        # ----- Tag constants -----
        # Section wrappers (support singular & plural variants)
        self.tool_calls_section_start_variants: list[str] = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",
        ]
        self.tool_calls_section_end_variants: list[str] = [
            "<|tool_calls_section_end|>",
            "<|tool_call_section_end|>",
        ]
        # Some model variants omit the section-level marker and go
        # directly to <|tool_call_begin|>. Treat it as a fallback.
        self._fallback_section_start: str = "<|tool_call_begin|>"
        # Primary variant for ToolParser base class / adjust_request
        self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
        self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
        # Individual tool-call markers
        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"
        self.tool_call_arg_begin: str = "<|tool_call_argument_begin|>"
        # ----- Compiled regexes -----
        # Complete tool call block.  The negative lookahead in the
        # arguments group prevents a match from running across the next
        # <|tool_call_begin|> marker.
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*"
            r"(?P<tool_call_id>[^<]+:\d+)\s*"
            r"<\|tool_call_argument_begin\|>\s*"
            r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
            r"<\|tool_call_end\|>",
            re.DOTALL,
        )
        # For extracting tool ID from the start of a tool-call region.
        self.tool_id_regex = re.compile(
            r"\s*(?P<tool_id>[^\s<]+:\d+)\s*", re.DOTALL
        )
        # ----- Streaming state (reset per request) -----
        # Index-aligned, per-tool-call state used by the streaming path.
        self._tool_call_ids: list[str] = []
        self.streamed_args_for_tool: list[str] = []
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        self.current_tool_id: int = -1
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )
        # Validate that the primary section tokens exist in vocab.
        self.tool_calls_start_token_id = self.vocab.get(
            self.tool_calls_start_token
        )
        self.tool_calls_end_token_id = self.vocab.get(
            self.tool_calls_end_token
        )
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token
        )
        self.tool_call_end_token_id = self.vocab.get(
            self.tool_call_end_token
        )
        if (
            self.tool_calls_start_token_id is None
            or self.tool_calls_end_token_id is None
        ):
            raise RuntimeError(
                "Kimi-K2 Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )
        logger.debug(
            "Successfully initialized %s", self.__class__.__name__
        )
# ------------------------------------------------------------------
# Request adjustment
# ------------------------------------------------------------------
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# Ensure tool-call tokens (<|tool_calls_section_begin|>,
# <|tool_call_begin|>, etc.) are not stripped during decoding.
request.skip_special_tokens = False
return request
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _tools_enabled(request: ChatCompletionRequest) -> bool:
try:
tools = getattr(request, "tools", None)
tool_choice = getattr(request, "tool_choice", None)
return bool(tools) and tool_choice != "none"
except Exception:
logger.exception("Failed to determine if tools are enabled.")
return False
@staticmethod
def _parse_tool_id(raw_id: str) -> tuple[str, str]:
"""Parse ``'functions.get_weather:0'`` → ``('get_weather', 'functions.get_weather:0')``."""
raw_id = raw_id.strip()
function_name = raw_id.split(":")[0].split(".")[-1]
return function_name, raw_id
def _find_section_start(self, text: str) -> int:
"""Return the index of the first section-start marker, or -1.
Falls back to <|tool_call_begin|> if no section-level marker
is found. Some model variants skip <|tool_calls_section_begin|>
and go directly to <|tool_call_begin|>.
"""
best = -1
for variant in self.tool_calls_section_start_variants:
idx = text.find(variant)
if idx != -1 and (best == -1 or idx < best):
best = idx
# Fallback: if no section-level marker found, look for
# <|tool_call_begin|> directly.
if best == -1 and self._fallback_section_start:
idx = text.find(self._fallback_section_start)
if idx != -1:
best = idx
return best
def _find_section_start_end(self, text: str) -> tuple[int, int]:
"""Return (start_of_inner, end_of_inner) for the section region.
*start_of_inner* points just past the section-start marker.
*end_of_inner* is the index of the section-end marker, or -1
if the section is still open.
Falls back to <|tool_call_begin|> if no section-level marker
is found.
"""
for variant in self.tool_calls_section_start_variants:
idx = text.find(variant)
if idx != -1:
inner_start = idx + len(variant)
# Look for end marker
for end_variant in self.tool_calls_section_end_variants:
end_idx = text.find(end_variant, inner_start)
if end_idx != -1:
return inner_start, end_idx
return inner_start, -1
# Fallback: no section-level marker found. Look for
# <|tool_call_begin|> directly as the section start.
if self._fallback_section_start:
idx = text.find(self._fallback_section_start)
if idx != -1:
inner_start = idx + len(self._fallback_section_start)
# Look for <|tool_call_end|> as the section end
end_marker = self._fallback_section_start.replace("begin", "end")
end_idx = text.find(end_marker, inner_start)
if end_idx != -1:
return inner_start, end_idx
return inner_start, -1
return -1, -1
# ------------------------------------------------------------------
# Non-streaming extraction (logic preserved from original)
# ------------------------------------------------------------------
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
function_call_tuples = self.tool_call_regex.findall(model_output)
logger.debug("function_call_tuples: %s", function_call_tuples)
tool_calls: list[ToolCall] = []
for match in function_call_tuples:
function_id, function_args = match
function_name, full_id = self._parse_tool_id(function_id)
tool_calls.append(
ToolCall(
id=full_id,
type="function",
function=FunctionCall(
name=function_name,
arguments=function_args,
),
)
)
content_end = self._find_section_start(model_output)
content = model_output[:content_end] if content_end > 0 else None
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# ------------------------------------------------------------------
# Streaming helpers — re-parse-and-diff
# ------------------------------------------------------------------
def _reset_streaming_state(self) -> None:
self._tool_call_ids.clear()
self.streamed_args_for_tool.clear()
self.prev_tool_call_arr.clear()
self.current_tool_id = -1
def _extract_tool_call_regions(
self, text: str
) -> list[tuple[str, bool]]:
"""Find all ``<|tool_call_begin|>`` … ``<|tool_call_end|>``
blocks inside the tool-calls section.
Returns a list of ``(inner_text, is_complete)`` tuples.
*inner_text* is everything between the tool-call open and close
tags (or end-of-available-text for the last partial block).
"""
results: list[tuple[str, bool]] = []
# Find the section region.
inner_start, inner_end = self._find_section_start_end(text)
if inner_start == -1:
return results
region = text[inner_start:inner_end] if inner_end != -1 else text[inner_start:]
pos = 0
while pos < len(region):
tc_start = region.find(self.tool_call_start_token, pos)
if tc_start == -1:
break
body_start = tc_start + len(self.tool_call_start_token)
tc_end = region.find(self.tool_call_end_token, body_start)
if tc_end != -1:
body = region[body_start:tc_end]
results.append((body, True))
pos = tc_end + len(self.tool_call_end_token)
else:
# Incomplete — still being generated.
body = region[body_start:]
overlap = partial_tag_overlap(body, self.tool_call_end_token)
if overlap:
body = body[:-overlap]
results.append((body, False))
break
return results
def _parse_tool_call_body(
self, body: str, is_complete: bool
) -> tuple[str | None, str | None, str]:
"""Parse a tool-call body into (func_name, tool_id, args_so_far).
The body looks like::
functions.get_weather:0
<|tool_call_argument_begin|>
{"location": "杭州"}
Returns ``(None, None, "")`` if the body doesn't contain enough
information yet (e.g. the tool ID is still arriving).
"""
# Extract tool ID (everything before <|tool_call_argument_begin|>
# or end of string).
arg_begin_idx = body.find(self.tool_call_arg_begin)
if arg_begin_idx != -1:
id_portion = body[:arg_begin_idx]
args_portion = body[arg_begin_idx + len(self.tool_call_arg_begin):]
else:
id_portion = body
args_portion = ""
# Try to extract the tool ID.
id_match = self.tool_id_regex.match(id_portion)
if not id_match:
# Not enough tokens yet to identify the tool.
return None, None, ""
raw_id = id_match.group("tool_id")
func_name, full_id = self._parse_tool_id(raw_id)
# Build args string.
args = args_portion.strip()
if is_complete:
# For a complete block, args is the final JSON.
return func_name, full_id, args
else:
# For a partial block, strip any trailing partial-tag overlap
# against tool_call_end (already done in caller), but also
# check for partial overlap against tool_call_argument_begin
# in case it hasn't fully arrived yet.
if arg_begin_idx == -1:
# No argument section yet.
overlap = partial_tag_overlap(
id_portion, self.tool_call_arg_begin
)
if overlap:
# The tag is still arriving — we have the name but
# no args yet.
pass
return func_name, full_id, ""
return func_name, full_id, args
def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
"""Return only the characters in *args_so_far* that haven't been
sent yet, or ``None`` if there's nothing new."""
prev = self.streamed_args_for_tool[index]
if not args_so_far or len(args_so_far) <= len(prev):
return None
diff = args_so_far[len(prev):]
self.streamed_args_for_tool[index] = args_so_far
self.prev_tool_call_arr[index]["arguments"] = args_so_far
return diff
def _ensure_tool_state_for(self, index: int) -> None:
"""Grow the streaming-state arrays so *index* is valid."""
while len(self._tool_call_ids) <= index:
self._tool_call_ids.append("")
while len(self.streamed_args_for_tool) <= index:
self.streamed_args_for_tool.append("")
while len(self.prev_tool_call_arr) <= index:
self.prev_tool_call_arr.append({})
# ------------------------------------------------------------------
# Main streaming entry point
# ------------------------------------------------------------------
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming output.

        Hybrid approach:
        - **Content forwarding** uses ``delta_text`` (same as the
          original parser) so we never re-emit text that the reasoning
          parser already handled.
        - **Tool call detection** re-parses ``current_text`` on every
          call (the re-parse-and-diff approach) so it's agnostic to
          how many tokens arrived per step — robust against MTP.

        Returns ``None`` when there is nothing to emit this step (e.g.
        inside an open section while tool tokens are still arriving).
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)
        # First chunk of a new stream — reset state.  NOTE(review): this
        # assumes the serving layer calls with previous_text == "" only
        # at stream start — confirm against caller.
        if not previous_text:
            self._reset_streaming_state()
        # If tools aren't enabled, just forward content.
        if not self._tools_enabled(request):
            return DeltaMessage(content=delta_text) if delta_text else None
        # ── Determine section state from full text (MTP-safe) ──
        inner_start, inner_end = self._find_section_start_end(current_text)
        in_open_section = inner_start != -1 and inner_end == -1
        # Was the section already open in previous_text?
        prev_inner_start, _ = self._find_section_start_end(previous_text)
        section_existed_before = prev_inner_start != -1
        # ── Re-parse tool calls from current_text (MTP-safe) ──
        regions = self._extract_tool_call_regions(current_text)
        tool_call_deltas: list[DeltaToolCall] = []
        for i, (body, is_complete) in enumerate(regions):
            self._ensure_tool_state_for(i)
            func_name, tool_id, args_so_far = self._parse_tool_call_body(
                body, is_complete
            )
            if func_name is None:
                # Not enough data to identify the tool yet; later
                # regions cannot be complete either, so stop here.
                break
            # Emit the tool name (once per tool call) — the "name" key
            # in prev_tool_call_arr[i] marks that the header delta for
            # this index has already been sent.
            if "name" not in self.prev_tool_call_arr[i]:
                self.prev_tool_call_arr[i]["name"] = func_name
                self._tool_call_ids[i] = tool_id or ""
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        id=tool_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments="",
                        ),
                    )
                )
            # Diff the arguments and emit any new characters.
            diff = self._compute_args_diff(i, args_so_far)
            if diff:
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        function=DeltaFunctionCall(arguments=diff),
                    )
                )
        if regions:
            self.current_tool_id = len(regions) - 1
        # ── Emit results ──
        # Case 1: We have tool call updates — emit them.
        if tool_call_deltas:
            return DeltaMessage(tool_calls=tool_call_deltas)
        # Case 2: No tool section has started yet — forward delta_text
        # as content. The reasoning parser handles the reasoning/content
        # split; we just pass through whatever delta the serving layer
        # gave us.
        if inner_start == -1:
            return DeltaMessage(content=delta_text) if delta_text else None
        # Case 3: The section just appeared in this delta. Extract any
        # content that came before the section marker in this delta
        # (e.g. "Let me check.<|tool_calls_section_begin|>").
        if not section_existed_before:
            section_start_in_text = self._find_section_start(current_text)
            # If the marker began inside previous_text (partial tag),
            # this slice is empty/negative-ranged and yields "".
            pre_section = current_text[len(previous_text):section_start_in_text]
            if pre_section.strip():
                return DeltaMessage(content=pre_section)
            # No real content before the section — return None instead of
            # an empty-string delta. Empty content deltas confuse clients
            # that distinguish content=null from content="".
            return None
        # Case 4: Inside an open tool section but tool calls aren't
        # parseable yet — return None. The serving layer will emit
        # its own keep-alive if needed; we should not emit empty-string
        # content deltas that pollute the response.
        if in_open_section:
            return None
        # Case 5: Section is closed and we're past it — forward any
        # new content that appeared after the section end marker.
        # (Only explicit section-end variants are considered; in the
        # <|tool_call_begin|> fallback mode there is no section-end
        # marker, so trailing content is not forwarded here.)
        if inner_end != -1:
            for variant in self.tool_calls_section_end_variants:
                end_marker_pos = current_text.find(variant, inner_start)
                if end_marker_pos != -1:
                    after_section = current_text[
                        end_marker_pos + len(variant):
                    ]
                    # Only emit what's new (not previously seen)
                    prev_after_len = 0
                    prev_end_pos = previous_text.find(variant)
                    if prev_end_pos != -1:
                        prev_after_len = len(
                            previous_text[prev_end_pos + len(variant):]
                        )
                    new_after = after_section[prev_after_len:]
                    if new_after:
                        return DeltaMessage(content=new_after)
                    break
            return None
        return None