add the tool call parser fixes for eagle decode

This commit is contained in:
2026-04-14 03:13:24 +00:00
parent 4de7496f5b
commit 9be82d3574
2 changed files with 571 additions and 1 deletions

View File

@@ -6,4 +6,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends unzip && rm -rf
ADD https://ewr1.vultrobjects.com/artifacts/models--nvidia--Kimi-K2.5-Thinking-Eagle3.zip /tmp/eagle3.zip
RUN unzip /tmp/eagle3.zip -d /opt/nvidia-Kimi-K2.5-Thinking-Eagle3 && \
rm /tmp/eagle3.zip && \
apt-get remove -y unzip && apt-get autoremove -y
apt-get remove -y unzip && apt-get autoremove -y
# Patch tool parser for MTP
COPY kimi_k2_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/kimi_k2_tool_parser.py

567
kimi_k2_tool_parser.py Normal file
View File

@@ -0,0 +1,567 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Kimi-K2 Tool Call Parser — re-parse-and-diff version.
Adapted from the GLM-4/DeepSeek-V3.2 streaming fix to make the
streaming path robust against multi-token deltas produced by MTP
speculative decoding.
Instead of counting start/end tokens to maintain an incremental state
machine, the streaming path re-parses the *entire* current_text on
every call, finds all <|tool_call_begin|> regions (complete and
in-progress), extracts the JSON arguments for each, and diffs against
what was previously sent.
Key changes vs. the upstream token-count parser:
1. No token-count state machine — the parser is stateless w.r.t.
how many tokens arrived per step.
2. _extract_content() uses partial_tag_overlap to safely handle
section-start tags split across chunk boundaries.
3. _extract_tool_call_regions() finds both complete and incomplete
tool-call blocks, enabling argument streaming.
4. _compute_args_diff() emits only newly-added characters.
5. Handles singular/plural section marker variants.
Drop-in replacement: same class name, same interface.
Example tool call format::
<|tool_calls_section_begin|>
<|tool_call_begin|>
functions.get_weather:0
<|tool_call_argument_begin|>
{"location": "杭州", "date": "2024-01-16"}
<|tool_call_end|>
<|tool_call_begin|>
functions.get_time:1
<|tool_call_argument_begin|>
{"timezone": "Asia/Shanghai"}
<|tool_call_end|>
<|tool_calls_section_end|>
"""
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# Utility — inlined to avoid import issues across vLLM versions
# ---------------------------------------------------------------------------
def partial_tag_overlap(text: str, tag: str) -> int:
    """Length of the longest proper prefix of *tag* that is a suffix of *text*.

    E.g. text ending in ``"<|tool_call"`` returns 11 when tag is
    ``"<|tool_call_begin|>"``. Returns 0 when there is no overlap.
    """
    # Only a *proper* prefix can be "still arriving", hence len(tag) - 1.
    limit = min(len(text), len(tag) - 1)
    return next(
        (n for n in range(limit, 0, -1) if text.endswith(tag[:n])),
        0,
    )
class KimiK2ToolParser(ToolParser):
"""Re-parse-and-diff tool parser for Kimi-K2 format.
On every streaming call the parser re-parses ``current_text`` to
find tool-call regions, extracts the JSON arguments for each, and
diffs against what was previously sent. This is robust against
multi-token deltas from MTP / EAGLE speculative decoding.
"""
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
    """Set up marker constants, compiled regexes, and per-request state.

    Raises:
        ValueError: if no tokenizer was supplied.
        RuntimeError: if the section begin/end markers are absent from
            the tokenizer vocabulary.
    """
    super().__init__(tokenizer, tools)

    # ----- Tag constants -----
    # Section wrappers; both singular and plural spellings are accepted.
    self.tool_calls_section_start_variants: list[str] = [
        "<|tool_calls_section_begin|>",
        "<|tool_call_section_begin|>",
    ]
    self.tool_calls_section_end_variants: list[str] = [
        "<|tool_calls_section_end|>",
        "<|tool_call_section_end|>",
    ]
    # Primary (plural) variant, used by the ToolParser base / adjust_request.
    self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
    self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
    # Markers delimiting one individual tool call.
    self.tool_call_start_token: str = "<|tool_call_begin|>"
    self.tool_call_end_token: str = "<|tool_call_end|>"
    self.tool_call_arg_begin: str = "<|tool_call_argument_begin|>"

    # ----- Compiled regexes -----
    # Matches one complete tool-call block (id + JSON arguments).
    self.tool_call_regex = re.compile(
        r"<\|tool_call_begin\|>\s*"
        r"(?P<tool_call_id>[^<]+:\d+)\s*"
        r"<\|tool_call_argument_begin\|>\s*"
        r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
        r"<\|tool_call_end\|>",
        re.DOTALL,
    )
    # Pulls the "namespace.name:index" id off the front of a region.
    self.tool_id_regex = re.compile(
        r"\s*(?P<tool_id>[^\s<]+:\d+)\s*", re.DOTALL
    )

    # ----- Streaming state (reset per request) -----
    self._sent_content_idx: int = 0
    self._tool_call_ids: list[str] = []
    self.streamed_args_for_tool: list[str] = []
    self.prev_tool_call_arr: list[dict[str, Any]] = []
    self.current_tool_id: int = -1

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )

    # Resolve marker token ids; only the section-level ids are mandatory.
    lookup = self.vocab.get
    self.tool_calls_start_token_id = lookup(self.tool_calls_start_token)
    self.tool_calls_end_token_id = lookup(self.tool_calls_end_token)
    self.tool_call_start_token_id = lookup(self.tool_call_start_token)
    self.tool_call_end_token_id = lookup(self.tool_call_end_token)
    if (
        self.tool_calls_start_token_id is None
        or self.tool_calls_end_token_id is None
    ):
        raise RuntimeError(
            "Kimi-K2 Tool parser could not locate tool call start/end "
            "tokens in the tokenizer!"
        )
    logger.debug(
        "Successfully initialized %s", self.__class__.__name__
    )
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _tools_enabled(request: ChatCompletionRequest) -> bool:
    """Return True when the request declares tools and tool_choice != 'none'."""
    try:
        if getattr(request, "tool_choice", None) == "none":
            return False
        return bool(getattr(request, "tools", None))
    except Exception:
        logger.exception("Failed to determine if tools are enabled.")
        return False
@staticmethod
def _parse_tool_id(raw_id: str) -> tuple[str, str]:
    """Parse ``'functions.get_weather:0'`` → ``('get_weather', 'functions.get_weather:0')``."""
    full_id = raw_id.strip()
    # Function name = text before the ":<index>", after the last ".".
    before_index = full_id.partition(":")[0]
    return before_index.rpartition(".")[2], full_id
def _find_section_start(self, text: str) -> int:
    """Return the index of the earliest section-start marker, or -1."""
    positions = [
        idx
        for idx in (
            text.find(v) for v in self.tool_calls_section_start_variants
        )
        if idx != -1
    ]
    return min(positions, default=-1)
def _find_section_start_end(self, text: str) -> tuple[int, int]:
    """Return (start_of_inner, end_of_inner) for the tool-calls section.

    *start_of_inner* points just past the section-start marker.
    *end_of_inner* is the index of the section-end marker, or -1
    if the section is still open. Returns ``(-1, -1)`` when no
    section-start marker is present.

    Fix vs. upstream: both markers are now chosen by earliest
    *position in the text* rather than by variant list order, matching
    ``_find_section_start`` — the old code preferred whichever variant
    came first in the variants list, even if it occurred later.
    """
    # Earliest start marker across all variants.
    best_idx = -1
    best_len = 0
    for variant in self.tool_calls_section_start_variants:
        idx = text.find(variant)
        if idx != -1 and (best_idx == -1 or idx < best_idx):
            best_idx = idx
            best_len = len(variant)
    if best_idx == -1:
        return -1, -1
    inner_start = best_idx + best_len
    # Earliest end marker (any variant) after the section opened.
    end_positions = [
        pos
        for variant in self.tool_calls_section_end_variants
        if (pos := text.find(variant, inner_start)) != -1
    ]
    return inner_start, min(end_positions, default=-1)
# ------------------------------------------------------------------
# Non-streaming extraction (logic preserved from original)
# ------------------------------------------------------------------
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Non-streaming extraction of tool calls from a complete response.

    Returns the parsed tool calls plus any content preceding the
    tool-calls section; when no section marker is found, the raw
    output is returned as plain content.

    Fix vs. upstream: section detection now accepts both the plural
    and singular section-begin variants (via ``_find_section_start``)
    instead of testing only ``tool_calls_start_token`` — previously a
    singular-variant response was mis-reported as plain content.
    """
    section_start = self._find_section_start(model_output)
    if section_start == -1:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )
    try:
        function_call_tuples = self.tool_call_regex.findall(model_output)
        logger.debug("function_call_tuples: %s", function_call_tuples)
        tool_calls: list[ToolCall] = []
        for function_id, function_args in function_call_tuples:
            function_name, full_id = self._parse_tool_id(function_id)
            tool_calls.append(
                ToolCall(
                    id=full_id,
                    type="function",
                    function=FunctionCall(
                        name=function_name,
                        arguments=function_args,
                    ),
                )
            )
        # Anything before the section marker is assistant content.
        content = model_output[:section_start] if section_start > 0 else None
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=content if content else None,
        )
    except Exception:
        logger.exception("Error in extracting tool call from response.")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )
# ------------------------------------------------------------------
# Streaming helpers — re-parse-and-diff
# ------------------------------------------------------------------
def _reset_streaming_state(self) -> None:
    """Clear all per-request streaming bookkeeping for a fresh stream."""
    self._sent_content_idx = 0
    self.current_tool_id = -1
    for bucket in (
        self._tool_call_ids,
        self.streamed_args_for_tool,
        self.prev_tool_call_arr,
    ):
        bucket.clear()
def _get_earliest_section_start_tag(self) -> str:
    """Return the *longest* section-start variant.

    Used as the reference tag for ``partial_tag_overlap`` hold-back in
    content extraction.
    NOTE(review): the original docstring said "shortest", contradicting
    the ``max(..., key=len)`` below. Also, testing only one variant can
    miss a partial prefix of a *different* variant (e.g. the singular
    spelling) — callers should consider checking every variant.
    """
    return max(
        self.tool_calls_section_start_variants, key=len
    )
def _extract_content(self, current_text: str) -> str | None:
    """Return any non-tool-section text that hasn't been sent yet.

    Walks *current_text* from ``_sent_content_idx``, collecting
    text outside ``<|tool_calls_section_begin|>`` …
    ``<|tool_calls_section_end|>`` regions. Uses
    ``partial_tag_overlap`` to avoid emitting bytes that might be
    the start of a section tag split across chunks.

    Fix vs. upstream: the trailing-overlap hold-back now checks
    *every* section-start variant instead of only the longest one, so
    a suffix that is a partial prefix of the singular variant (e.g.
    ``"<|tool_call_s"``, which is not a prefix of the plural form) is
    no longer emitted as content prematurely.
    """
    content_segments: list[str] = []
    pos = self._sent_content_idx
    while pos < len(current_text):
        # Find the earliest section-start marker from pos.
        best_start = -1
        best_variant_len = 0
        for variant in self.tool_calls_section_start_variants:
            idx = current_text.find(variant, pos)
            if idx != -1 and (best_start == -1 or idx < best_start):
                best_start = idx
                best_variant_len = len(variant)
        if best_start == -1:
            # No more section regions — emit the tail, minus any suffix
            # that could be the beginning of ANY start-tag variant.
            tail = current_text[pos:]
            overlap = max(
                partial_tag_overlap(tail, variant)
                for variant in self.tool_calls_section_start_variants
            )
            sendable = tail[: len(tail) - overlap] if overlap else tail
            if sendable:
                content_segments.append(sendable)
            pos = len(current_text) - overlap
            break
        # Text before the section start is content.
        if best_start > pos:
            content_segments.append(current_text[pos:best_start])
        # Skip past the section region.
        inner_start = best_start + best_variant_len
        # Find the earliest matching section-end marker.
        best_end = -1
        best_end_variant_len = 0
        for variant in self.tool_calls_section_end_variants:
            idx = current_text.find(variant, inner_start)
            if idx != -1 and (best_end == -1 or idx < best_end):
                best_end = idx
                best_end_variant_len = len(variant)
        if best_end != -1:
            pos = best_end + best_end_variant_len
        else:
            # Section still open — park the cursor at the marker, stop.
            pos = best_start
            break
    if content_segments:
        self._sent_content_idx = pos
        return "".join(content_segments)
    if pos > self._sent_content_idx:
        self._sent_content_idx = pos
    return None
def _extract_tool_call_regions(
    self, text: str
) -> list[tuple[str, bool]]:
    """Find all ``<|tool_call_begin|>`` … ``<|tool_call_end|>`` blocks
    inside the tool-calls section.

    Returns ``(inner_text, is_complete)`` tuples. *inner_text* is
    everything between the open and close tags, or everything up to
    the end of the available text for a trailing partial block.
    """
    inner_start, inner_end = self._find_section_start_end(text)
    if inner_start == -1:
        return []
    region = (
        text[inner_start:] if inner_end == -1 else text[inner_start:inner_end]
    )
    open_tag = self.tool_call_start_token
    close_tag = self.tool_call_end_token
    blocks: list[tuple[str, bool]] = []
    cursor = 0
    while True:
        opened = region.find(open_tag, cursor)
        if opened == -1:
            break
        body_start = opened + len(open_tag)
        closed = region.find(close_tag, body_start)
        if closed == -1:
            # Trailing block still being generated — drop any suffix
            # that could be a partially-arrived close tag.
            body = region[body_start:]
            held = partial_tag_overlap(body, close_tag)
            blocks.append((body[:-held] if held else body, False))
            break
        blocks.append((region[body_start:closed], True))
        cursor = closed + len(close_tag)
    return blocks
def _parse_tool_call_body(
    self, body: str, is_complete: bool
) -> tuple[str | None, str | None, str]:
    """Parse a tool-call body into (func_name, tool_id, args_so_far).

    The body looks like::

        functions.get_weather:0
        <|tool_call_argument_begin|>
        {"location": "杭州"}

    Returns ``(None, None, "")`` if the body doesn't contain enough
    information yet (e.g. the tool ID is still arriving).

    Fixes vs. upstream:
    1. In a partial block where ``<|tool_call_argument_begin|>`` has
       not arrived yet, the tool ID is only trusted once a terminator
       character (whitespace or ``<``) follows its trailing index
       digits. Previously ``functions.f:1`` could be accepted — and a
       wrong id emitted to the client — while ``functions.f:12`` was
       still streaming in under an MTP multi-token delta.
    2. Removed a dead ``partial_tag_overlap`` computation whose result
       was never used.
    """
    # Split the body at the argument-begin tag, if present.
    arg_begin_idx = body.find(self.tool_call_arg_begin)
    if arg_begin_idx != -1:
        id_portion = body[:arg_begin_idx]
        args_portion = body[arg_begin_idx + len(self.tool_call_arg_begin):]
    else:
        id_portion = body
        args_portion = ""
    if not is_complete and arg_begin_idx == -1:
        # The id's trailing ":<index>" digits may still be growing;
        # require a terminator after them before emitting anything.
        id_match = re.match(
            r"\s*(?P<tool_id>[^\s<]+:\d+)(?=[\s<])", id_portion
        )
        if not id_match:
            # Not enough tokens yet to identify the tool.
            return None, None, ""
        func_name, full_id = self._parse_tool_id(id_match.group("tool_id"))
        return func_name, full_id, ""
    id_match = self.tool_id_regex.match(id_portion)
    if not id_match:
        # Not enough tokens yet to identify the tool.
        return None, None, ""
    func_name, full_id = self._parse_tool_id(id_match.group("tool_id"))
    # Strip the wrapping whitespace the format places around the JSON
    # arguments (mirrors the non-streaming regex's ``\s*`` trimming).
    return func_name, full_id, args_portion.strip()
def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
    """Return the characters of *args_so_far* not yet streamed for tool
    *index* (updating bookkeeping), or ``None`` when nothing is new."""
    already_sent = self.streamed_args_for_tool[index]
    if len(args_so_far) <= len(already_sent):
        return None
    self.streamed_args_for_tool[index] = args_so_far
    self.prev_tool_call_arr[index]["arguments"] = args_so_far
    return args_so_far[len(already_sent):]
def _ensure_tool_state_for(self, index: int) -> None:
    """Grow the streaming-state arrays so *index* is a valid slot."""
    needed = index + 1
    if len(self._tool_call_ids) < needed:
        self._tool_call_ids.extend(
            [""] * (needed - len(self._tool_call_ids))
        )
    if len(self.streamed_args_for_tool) < needed:
        self.streamed_args_for_tool.extend(
            [""] * (needed - len(self.streamed_args_for_tool))
        )
    if len(self.prev_tool_call_arr) < needed:
        # Generator keeps each appended dict a distinct object.
        self.prev_tool_call_arr.extend(
            {} for _ in range(needed - len(self.prev_tool_call_arr))
        )
# ------------------------------------------------------------------
# Main streaming entry point
# ------------------------------------------------------------------
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Extract tool calls from streaming output via re-parse-and-diff.

    Each call re-scans *current_text* for content outside tool-call
    sections, locates every ``<|tool_call_begin|>`` region (complete
    and partial), parses each region's tool ID and arguments, and
    emits only what has not been sent before. Because the whole text
    is re-parsed every step, the result is correct regardless of how
    many tokens arrived in this delta (MTP / EAGLE speculative
    decoding).
    """
    logger.debug("delta_text: %s", delta_text)
    logger.debug("delta_token_ids: %s", delta_token_ids)

    # An empty previous_text marks the first chunk of a new stream.
    if not previous_text:
        self._reset_streaming_state()

    # Without tools, act as a plain content passthrough.
    if not self._tools_enabled(request):
        return DeltaMessage(content=delta_text) if delta_text else None

    # 1. Content outside any tool-calls section.
    content = self._extract_content(current_text)

    # 2/3. Parse every tool-call region and diff against prior state.
    regions = self._extract_tool_call_regions(current_text)
    deltas: list[DeltaToolCall] = []
    for index, (body, complete) in enumerate(regions):
        self._ensure_tool_state_for(index)
        func_name, tool_id, args_so_far = self._parse_tool_call_body(
            body, complete
        )
        if func_name is None:
            # Tool ID still arriving — later regions can't be ready either.
            break
        if "name" not in self.prev_tool_call_arr[index]:
            # First sighting of this tool call: announce its name once.
            self.prev_tool_call_arr[index]["name"] = func_name
            self._tool_call_ids[index] = tool_id or ""
            deltas.append(
                DeltaToolCall(
                    index=index,
                    id=tool_id,
                    type="function",
                    function=DeltaFunctionCall(
                        name=func_name,
                        arguments="",
                    ),
                )
            )
        # 4. Stream only the newly-added argument characters.
        new_chars = self._compute_args_diff(index, args_so_far)
        if new_chars:
            deltas.append(
                DeltaToolCall(
                    index=index,
                    function=DeltaFunctionCall(arguments=new_chars),
                )
            )
    if regions:
        self.current_tool_id = len(regions) - 1

    # Emit a delta when there is content or tool-call updates.
    if content or deltas:
        return DeltaMessage(
            content=content,
            tool_calls=deltas if deltas else None,
        )
    # An empty text delta that still carried token ids usually means
    # EOS or a closing tag — return non-None so the serving framework
    # can finalize finish_reason.
    if not delta_text and delta_token_ids and self.prev_tool_call_arr:
        return DeltaMessage(content="")
    return None