# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Kimi-K2 Tool Call Parser — re-parse-and-diff version.

Adapted from the GLM-4/DeepSeek-V3.2 streaming fix to make the streaming
path robust against multi-token deltas produced by MTP speculative
decoding.

Instead of counting start/end tokens to maintain an incremental state
machine, the streaming path re-parses the *entire* current_text on every
call, finds all <|tool_call_begin|> regions (complete and in-progress),
extracts the JSON arguments for each, and diffs against what was
previously sent.

Key changes vs. the upstream token-count parser:

1. No token-count state machine — the parser is stateless w.r.t. how
   many tokens arrived per step.
2. Content forwarding uses delta_text (not re-parsed current_text) so
   reasoning text is never re-emitted as content.
3. _extract_tool_call_regions() finds both complete and incomplete
   tool-call blocks, enabling argument streaming.
4. _compute_args_diff() emits only newly-added characters.
5. Handles singular/plural section marker variants.
6. Returns empty deltas inside open sections to keep the stream alive
   while tool call tokens are still arriving.

Drop-in replacement: same class name, same interface.

Example tool call format::

    <|tool_calls_section_begin|>
    <|tool_call_begin|> functions.get_weather:0
    <|tool_call_argument_begin|> {"location": "杭州", "date": "2024-01-16"}
    <|tool_call_end|>
    <|tool_call_begin|> functions.get_time:1
    <|tool_call_argument_begin|> {"timezone": "Asia/Shanghai"}
    <|tool_call_end|>
    <|tool_calls_section_end|>
"""

from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    Tool,
    ToolParser,
)

logger = init_logger(__name__)


# ---------------------------------------------------------------------------
# Utility — inlined to avoid import issues across vLLM versions
# ---------------------------------------------------------------------------
def partial_tag_overlap(text: str, tag: str) -> int:
    """Length of the longest prefix of *tag* that matches a suffix of *text*.

    E.g. text ending in ``"<|tool_call"`` returns 11 when tag is
    ``"<|tool_call_begin|>"``. Returns 0 when there is no overlap.
    """
    # Only proper prefixes are interesting: a full tag would have been
    # found by a normal substring search.
    max_check = min(len(tag) - 1, len(text))
    for k in range(max_check, 0, -1):
        if text.endswith(tag[:k]):
            return k
    return 0


class KimiK2ToolParser(ToolParser):
    """Re-parse-and-diff tool parser for Kimi-K2 format.

    On every streaming call the parser re-parses ``current_text`` to find
    tool-call regions, extracts the JSON arguments for each, and diffs
    against what was previously sent. This is robust against multi-token
    deltas from MTP / EAGLE speculative decoding.
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        super().__init__(tokenizer, tools)

        # ----- Tag constants -----
        # Section wrappers (support singular & plural variants)
        self.tool_calls_section_start_variants: list[str] = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",
        ]
        self.tool_calls_section_end_variants: list[str] = [
            "<|tool_calls_section_end|>",
            "<|tool_call_section_end|>",
        ]
        # Primary variant for ToolParser base class / adjust_request
        self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
        self.tool_calls_end_token: str = "<|tool_calls_section_end|>"

        # Individual tool-call markers
        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"
        self.tool_call_arg_begin: str = "<|tool_call_argument_begin|>"

        # ----- Compiled regexes -----
        # Complete tool call block. NOTE: the group names are required —
        # ``(?P...)`` without ``<name>`` is a regex syntax error, and
        # findall() unpacking below relies on exactly these two groups.
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*"
            r"(?P<tool_call_id>[^<]+:\d+)\s*"
            r"<\|tool_call_argument_begin\|>\s*"
            r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
            r"<\|tool_call_end\|>",
            re.DOTALL,
        )
        # For extracting tool ID from the start of a tool-call region.
        # _parse_tool_call_body() reads this via ``.group("tool_id")``.
        self.tool_id_regex = re.compile(
            r"\s*(?P<tool_id>[^\s<]+:\d+)\s*", re.DOTALL
        )

        # ----- Streaming state (reset per request) -----
        self._tool_call_ids: list[str] = []
        self.streamed_args_for_tool: list[str] = []
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        self.current_tool_id: int = -1

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Validate that the primary section tokens exist in vocab.
        self.tool_calls_start_token_id = self.vocab.get(
            self.tool_calls_start_token
        )
        self.tool_calls_end_token_id = self.vocab.get(
            self.tool_calls_end_token
        )
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token
        )
        self.tool_call_end_token_id = self.vocab.get(
            self.tool_call_end_token
        )
        if (
            self.tool_calls_start_token_id is None
            or self.tool_calls_end_token_id is None
        ):
            raise RuntimeError(
                "Kimi-K2 Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )
        logger.debug(
            "Successfully initialized %s", self.__class__.__name__
        )

    # ------------------------------------------------------------------
    # Request adjustment
    # ------------------------------------------------------------------
    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # Ensure tool-call tokens (<|tool_calls_section_begin|>,
            # <|tool_call_begin|>, etc.) are not stripped during decoding.
            request.skip_special_tokens = False
        return request

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _tools_enabled(request: ChatCompletionRequest) -> bool:
        """Return True when the request carries tools and tool_choice != 'none'."""
        try:
            tools = getattr(request, "tools", None)
            tool_choice = getattr(request, "tool_choice", None)
            return bool(tools) and tool_choice != "none"
        except Exception:
            logger.exception("Failed to determine if tools are enabled.")
            return False

    @staticmethod
    def _parse_tool_id(raw_id: str) -> tuple[str, str]:
        """Parse ``'functions.get_weather:0'`` →
        ``('get_weather', 'functions.get_weather:0')``."""
        raw_id = raw_id.strip()
        # Drop the ':<index>' suffix, then the 'functions.' namespace.
        function_name = raw_id.split(":")[0].split(".")[-1]
        return function_name, raw_id

    def _find_section_start(self, text: str) -> int:
        """Return the index of the first section-start marker, or -1."""
        best = -1
        for variant in self.tool_calls_section_start_variants:
            idx = text.find(variant)
            if idx != -1 and (best == -1 or idx < best):
                best = idx
        return best

    def _find_section_start_end(self, text: str) -> tuple[int, int]:
        """Return (start_of_inner, end_of_inner) for the section region.

        *start_of_inner* points just past the section-start marker.
        *end_of_inner* is the index of the section-end marker, or -1 if the
        section is still open.
        """
        for variant in self.tool_calls_section_start_variants:
            idx = text.find(variant)
            if idx != -1:
                inner_start = idx + len(variant)
                # Look for end marker
                for end_variant in self.tool_calls_section_end_variants:
                    end_idx = text.find(end_variant, inner_start)
                    if end_idx != -1:
                        return inner_start, end_idx
                return inner_start, -1
        return -1, -1

    # ------------------------------------------------------------------
    # Non-streaming extraction (logic preserved from original)
    # ------------------------------------------------------------------
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all complete tool calls from a finished generation.

        Returns the parsed tool calls plus any content that preceded the
        tool-calls section; on any parse failure the whole output is
        returned as plain content.
        """
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
        try:
            # Each match is a (tool_call_id, function_arguments) tuple —
            # positional unpack matches the two named groups in the regex.
            function_call_tuples = self.tool_call_regex.findall(model_output)
            logger.debug("function_call_tuples: %s", function_call_tuples)
            tool_calls: list[ToolCall] = []
            for match in function_call_tuples:
                function_id, function_args = match
                function_name, full_id = self._parse_tool_id(function_id)
                tool_calls.append(
                    ToolCall(
                        id=full_id,
                        type="function",
                        function=FunctionCall(
                            name=function_name,
                            arguments=function_args,
                        ),
                    )
                )
            # Anything before the section marker is normal content.
            content_end = self._find_section_start(model_output)
            content = model_output[:content_end] if content_end > 0 else None
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    # ------------------------------------------------------------------
    # Streaming helpers — re-parse-and-diff
    # ------------------------------------------------------------------
    def _reset_streaming_state(self) -> None:
        """Clear all per-request streaming state (called on first chunk)."""
        self._tool_call_ids.clear()
        self.streamed_args_for_tool.clear()
        self.prev_tool_call_arr.clear()
        self.current_tool_id = -1

    def _extract_tool_call_regions(
        self, text: str
    ) -> list[tuple[str, bool]]:
        """Find all ``<|tool_call_begin|>`` … ``<|tool_call_end|>`` blocks
        inside the tool-calls section.

        Returns a list of ``(inner_text, is_complete)`` tuples.
        *inner_text* is everything between the tool-call open and close
        tags (or end-of-available-text for the last partial block).
        """
        results: list[tuple[str, bool]] = []
        # Find the section region.
        inner_start, inner_end = self._find_section_start_end(text)
        if inner_start == -1:
            return results
        region = text[inner_start:inner_end] if inner_end != -1 else text[inner_start:]
        pos = 0
        while pos < len(region):
            tc_start = region.find(self.tool_call_start_token, pos)
            if tc_start == -1:
                break
            body_start = tc_start + len(self.tool_call_start_token)
            tc_end = region.find(self.tool_call_end_token, body_start)
            if tc_end != -1:
                body = region[body_start:tc_end]
                results.append((body, True))
                pos = tc_end + len(self.tool_call_end_token)
            else:
                # Incomplete — still being generated. Trim a partially
                # generated end tag so it is never streamed as arguments.
                body = region[body_start:]
                overlap = partial_tag_overlap(body, self.tool_call_end_token)
                if overlap:
                    body = body[:-overlap]
                results.append((body, False))
                break
        return results

    def _parse_tool_call_body(
        self, body: str, is_complete: bool
    ) -> tuple[str | None, str | None, str]:
        """Parse a tool-call body into (func_name, tool_id, args_so_far).

        The body looks like::

            functions.get_weather:0 <|tool_call_argument_begin|> {"location": "杭州"}

        Returns ``(None, None, "")`` if the body doesn't contain enough
        information yet (e.g. the tool ID is still arriving).
        """
        # Extract tool ID (everything before <|tool_call_argument_begin|>
        # or end of string).
        arg_begin_idx = body.find(self.tool_call_arg_begin)
        if arg_begin_idx != -1:
            id_portion = body[:arg_begin_idx]
            args_portion = body[arg_begin_idx + len(self.tool_call_arg_begin):]
        else:
            id_portion = body
            args_portion = ""

        # Try to extract the tool ID.
        id_match = self.tool_id_regex.match(id_portion)
        if not id_match:
            # Not enough tokens yet to identify the tool.
            return None, None, ""
        raw_id = id_match.group("tool_id")
        func_name, full_id = self._parse_tool_id(raw_id)

        # Build args string.
        args = args_portion.strip()
        if is_complete:
            # For a complete block, args is the final JSON.
            return func_name, full_id, args
        # Partial block without an argument section yet (the
        # <|tool_call_argument_begin|> tag may still be arriving):
        # we know the name but have no args to stream.
        if arg_begin_idx == -1:
            return func_name, full_id, ""
        return func_name, full_id, args

    def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
        """Return only the characters in *args_so_far* that haven't been
        sent yet, or ``None`` if there's nothing new."""
        prev = self.streamed_args_for_tool[index]
        if not args_so_far or len(args_so_far) <= len(prev):
            return None
        diff = args_so_far[len(prev):]
        self.streamed_args_for_tool[index] = args_so_far
        self.prev_tool_call_arr[index]["arguments"] = args_so_far
        return diff

    def _ensure_tool_state_for(self, index: int) -> None:
        """Grow the streaming-state arrays so *index* is valid."""
        while len(self._tool_call_ids) <= index:
            self._tool_call_ids.append("")
        while len(self.streamed_args_for_tool) <= index:
            self.streamed_args_for_tool.append("")
        while len(self.prev_tool_call_arr) <= index:
            self.prev_tool_call_arr.append({})

    # ------------------------------------------------------------------
    # Main streaming entry point
    # ------------------------------------------------------------------
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming output.

        Hybrid approach:
        - **Content forwarding** uses ``delta_text`` (same as the original
          parser) so we never re-emit text that the reasoning parser
          already handled.
        - **Tool call detection** re-parses ``current_text`` on every call
          (the re-parse-and-diff approach) so it's agnostic to how many
          tokens arrived per step — robust against MTP.
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)

        # First chunk of a new stream — reset state.
        if not previous_text:
            self._reset_streaming_state()

        # If tools aren't enabled, just forward content.
        if not self._tools_enabled(request):
            return DeltaMessage(content=delta_text) if delta_text else None

        # ── Determine section state from full text (MTP-safe) ──
        inner_start, inner_end = self._find_section_start_end(current_text)
        in_open_section = inner_start != -1 and inner_end == -1
        # Was the section already open in previous_text?
        prev_inner_start, _ = self._find_section_start_end(previous_text)
        section_existed_before = prev_inner_start != -1

        # ── Re-parse tool calls from current_text (MTP-safe) ──
        regions = self._extract_tool_call_regions(current_text)
        tool_call_deltas: list[DeltaToolCall] = []
        for i, (body, is_complete) in enumerate(regions):
            self._ensure_tool_state_for(i)
            func_name, tool_id, args_so_far = self._parse_tool_call_body(
                body, is_complete
            )
            if func_name is None:
                # Not enough data to identify the tool yet.
                break
            # Emit the tool name (once per tool call).
            if "name" not in self.prev_tool_call_arr[i]:
                self.prev_tool_call_arr[i]["name"] = func_name
                self._tool_call_ids[i] = tool_id or ""
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        id=tool_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments="",
                        ),
                    )
                )
            # Diff the arguments and emit any new characters.
            diff = self._compute_args_diff(i, args_so_far)
            if diff:
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        function=DeltaFunctionCall(arguments=diff),
                    )
                )
        if regions:
            self.current_tool_id = len(regions) - 1

        # ── Emit results ──
        # Case 1: We have tool call updates — emit them.
        if tool_call_deltas:
            return DeltaMessage(tool_calls=tool_call_deltas)

        # Case 2: No tool section has started yet — forward delta_text
        # as content. The reasoning parser handles the reasoning/content
        # split; we just pass through whatever delta the serving layer
        # gave us.
        if inner_start == -1:
            return DeltaMessage(content=delta_text) if delta_text else None

        # Case 3: The section just appeared in this delta. Extract any
        # content that came before the section marker in this delta
        # (e.g. "Let me check.<|tool_calls_section_begin|>").
        if not section_existed_before:
            section_start_in_text = self._find_section_start(current_text)
            pre_section = current_text[len(previous_text):section_start_in_text]
            if pre_section.strip():
                return DeltaMessage(content=pre_section)
            return DeltaMessage(content="")

        # Case 4: Inside an open tool section but tool calls aren't
        # parseable yet — emit empty delta to keep the stream alive.
        if in_open_section:
            return DeltaMessage(content="")

        # Case 5: Section is closed and we're past it — forward any
        # new content that appeared after the section end marker.
        if inner_end != -1:
            for variant in self.tool_calls_section_end_variants:
                end_marker_pos = current_text.find(variant, inner_start)
                if end_marker_pos != -1:
                    after_section = current_text[
                        end_marker_pos + len(variant):
                    ]
                    # Only emit what's new (not previously seen)
                    prev_after_len = 0
                    prev_end_pos = previous_text.find(variant)
                    if prev_end_pos != -1:
                        prev_after_len = len(
                            previous_text[prev_end_pos + len(variant):]
                        )
                    new_after = after_section[prev_after_len:]
                    if new_after:
                        return DeltaMessage(content=new_after)
                    break
            return DeltaMessage(content="") if delta_text else None
        return None