# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Kimi-K2 Tool Call Parser — re-parse-and-diff version.

Adapted from the GLM-4/DeepSeek-V3.2 streaming fix to make the streaming
path robust against multi-token deltas produced by MTP speculative decoding.

Instead of counting start/end tokens to maintain an incremental state
machine, the streaming path re-parses the *entire* current_text on every
call, finds all <|tool_call_begin|> regions (complete and in-progress),
extracts the JSON arguments for each, and diffs against what was
previously sent.

Key changes vs. the upstream token-count parser:
  1. No token-count state machine — the parser is stateless w.r.t.
     how many tokens arrived per step.
  2. _extract_content() uses partial_tag_overlap to safely handle
     section-start tags split across chunk boundaries.
  3. _extract_tool_call_regions() finds both complete and incomplete
     tool-call blocks, enabling argument streaming.
  4. _compute_args_diff() emits only newly-added characters.
  5. Handles singular/plural section marker variants.

Drop-in replacement: same class name, same interface.

Example tool call format::

    <|tool_calls_section_begin|>
    <|tool_call_begin|>
    functions.get_weather:0
    <|tool_call_argument_begin|>
    {"location": "杭州", "date": "2024-01-16"}
    <|tool_call_end|>
    <|tool_call_begin|>
    functions.get_time:1
    <|tool_call_argument_begin|>
    {"timezone": "Asia/Shanghai"}
    <|tool_call_end|>
    <|tool_calls_section_end|>
"""

from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    Tool,
    ToolParser,
)

logger = init_logger(__name__)


# ---------------------------------------------------------------------------
# Utility — inlined to avoid import issues across vLLM versions
# ---------------------------------------------------------------------------


def partial_tag_overlap(text: str, tag: str) -> int:
    """Return the length of the longest proper prefix of *tag* ending *text*.

    Only proper prefixes are considered (at most ``len(tag) - 1``
    characters), so a *text* that ends with the complete tag is not
    reported through its full length.

    E.g. text ending in ``"<|tool_call"`` returns 11 when tag is
    ``"<|tool_call_begin|>"``. Returns 0 when there is no overlap.
    """
    max_check = min(len(tag) - 1, len(text))
    # Try the longest candidate first so the first hit is the answer.
    for k in range(max_check, 0, -1):
        if text.endswith(tag[:k]):
            return k
    return 0


class KimiK2ToolParser(ToolParser):
    """Re-parse-and-diff tool parser for Kimi-K2 format.

    On every streaming call the parser re-parses ``current_text`` to find
    tool-call regions, extracts the JSON arguments for each, and diffs
    against what was previously sent. This is robust against multi-token
    deltas from MTP / EAGLE speculative decoding.
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        """Set up tag constants, regexes and per-request streaming state.

        Raises:
            ValueError: if no tokenizer is available on the parser.
            RuntimeError: if the primary section start/end tokens are not
                present in the tokenizer vocabulary.
        """
        super().__init__(tokenizer, tools)

        # ----- Tag constants -----
        # Section wrappers (support singular & plural variants)
        self.tool_calls_section_start_variants: list[str] = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",
        ]
        self.tool_calls_section_end_variants: list[str] = [
            "<|tool_calls_section_end|>",
            "<|tool_call_section_end|>",
        ]
        # Primary variant for ToolParser base class / adjust_request
        self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
        self.tool_calls_end_token: str = "<|tool_calls_section_end|>"

        # Individual tool-call markers
        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"
        self.tool_call_arg_begin: str = "<|tool_call_argument_begin|>"

        # ----- Compiled regexes -----
        # Complete tool call block. The negative lookahead keeps the
        # argument group from running across the start of another call.
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*"
            r"(?P<tool_call_id>[^<]+:\d+)\s*"
            r"<\|tool_call_argument_begin\|>\s*"
            r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
            r"<\|tool_call_end\|>",
            re.DOTALL,
        )
        # For extracting tool ID from the start of a tool-call region.
        self.tool_id_regex = re.compile(
            r"\s*(?P<tool_id>[^\s<]+:\d+)\s*", re.DOTALL
        )

        # ----- Streaming state (reset per request) -----
        # Index into current_text up to which content has been handled.
        self._sent_content_idx: int = 0
        # Full tool-call id recorded per tool index.
        self._tool_call_ids: list[str] = []
        # Argument text already emitted per tool index.
        self.streamed_args_for_tool: list[str] = []
        # Per tool index: {"name": ..., "arguments": ...} seen so far.
        self.prev_tool_call_arr: list[dict[str, Any]] = []
        self.current_tool_id: int = -1

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Validate that the primary section tokens exist in vocab.
        self.tool_calls_start_token_id = self.vocab.get(
            self.tool_calls_start_token
        )
        self.tool_calls_end_token_id = self.vocab.get(
            self.tool_calls_end_token
        )
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token
        )
        self.tool_call_end_token_id = self.vocab.get(
            self.tool_call_end_token
        )

        if (
            self.tool_calls_start_token_id is None
            or self.tool_calls_end_token_id is None
        ):
            raise RuntimeError(
                "Kimi-K2 Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

        logger.debug(
            "Successfully initialized %s", self.__class__.__name__
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _tools_enabled(request: ChatCompletionRequest) -> bool:
        """Return True when *request* has tools and tool_choice != "none".

        Any exception raised while reading the request attributes is
        logged and treated as "tools disabled".
        """
        try:
            tools = getattr(request, "tools", None)
            tool_choice = getattr(request, "tool_choice", None)
            return bool(tools) and tool_choice != "none"
        except Exception:
            logger.exception("Failed to determine if tools are enabled.")
            return False

    @staticmethod
    def _parse_tool_id(raw_id: str) -> tuple[str, str]:
        """Parse ``'functions.get_weather:0'`` →
        ``('get_weather', 'functions.get_weather:0')``.

        The function name is the last dot-separated component of the text
        before the first ``:``; the second element is the stripped id.
        """
        raw_id = raw_id.strip()
        function_name = raw_id.split(":")[0].split(".")[-1]
        return function_name, raw_id

    @staticmethod
    def _find_earliest(
        text: str, variants: Sequence[str], start: int = 0
    ) -> tuple[int, int]:
        """Locate whichever of *variants* occurs first in *text*.

        Searching begins at index *start*. Returns ``(index, length)`` of
        the earliest match, or ``(-1, 0)`` when none of the variants occur.
        """
        best_idx = -1
        best_len = 0
        for variant in variants:
            idx = text.find(variant, start)
            if idx != -1 and (best_idx == -1 or idx < best_idx):
                best_idx = idx
                best_len = len(variant)
        return best_idx, best_len

    def _find_section_start(self, text: str) -> int:
        """Return the index of the first section-start marker, or -1."""
        idx, _ = self._find_earliest(
            text, self.tool_calls_section_start_variants
        )
        return idx

    def _find_section_start_end(self, text: str) -> tuple[int, int]:
        """Return (start_of_inner, end_of_inner) for the section region.

        *start_of_inner* points just past the earliest section-start
        marker (any variant). *end_of_inner* is the index of the earliest
        section-end marker after it, or -1 if the section is still open.
        Returns ``(-1, -1)`` when no section-start marker is present.
        """
        start_idx, start_len = self._find_earliest(
            text, self.tool_calls_section_start_variants
        )
        if start_idx == -1:
            return -1, -1
        inner_start = start_idx + start_len
        end_idx, _ = self._find_earliest(
            text, self.tool_calls_section_end_variants, inner_start
        )
        return inner_start, end_idx

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all complete tool calls from a finished model output.

        If no section-start marker (singular or plural variant) is present
        the output is returned unchanged as content. Otherwise every block
        matched by ``tool_call_regex`` becomes a ``ToolCall`` and the text
        before the section start is returned as content (``None`` if
        empty). On any exception the output is returned as plain content.
        """
        content_end = self._find_section_start(model_output)
        if content_end == -1:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            # Two groups in the pattern, so findall yields (id, args) tuples.
            function_call_tuples = self.tool_call_regex.findall(model_output)
            logger.debug("function_call_tuples: %s", function_call_tuples)

            tool_calls: list[ToolCall] = []
            for function_id, function_args in function_call_tuples:
                function_name, full_id = self._parse_tool_id(function_id)
                tool_calls.append(
                    ToolCall(
                        id=full_id,
                        type="function",
                        function=FunctionCall(
                            name=function_name,
                            arguments=function_args,
                        ),
                    )
                )

            content = model_output[:content_end] if content_end > 0 else None
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    # ------------------------------------------------------------------
    # Streaming helpers — re-parse-and-diff
    # ------------------------------------------------------------------

    def _reset_streaming_state(self) -> None:
        """Clear all per-request streaming state."""
        self._sent_content_idx = 0
        self._tool_call_ids.clear()
        self.streamed_args_for_tool.clear()
        self.prev_tool_call_arr.clear()
        self.current_tool_id = -1

    def _get_earliest_section_start_tag(self) -> str:
        """Return the longest section-start variant.

        Kept for compatibility; ``_extract_content`` checks the overlap
        against every variant via ``_section_start_overlap`` instead.
        """
        return max(
            self.tool_calls_section_start_variants, key=len
        )

    def _section_start_overlap(self, text: str) -> int:
        """Return the longest partial section-start tag ending *text*.

        The overlap is computed against every section-start variant and
        the maximum is returned, so a partially received singular *or*
        plural marker is held back.
        """
        return max(
            (
                partial_tag_overlap(text, variant)
                for variant in self.tool_calls_section_start_variants
            ),
            default=0,
        )

    def _extract_content(self, current_text: str) -> str | None:
        """Return any non-tool-section text that hasn't been sent yet.

        Walks *current_text* from ``_sent_content_idx``, collecting text
        outside section-begin … section-end regions and advancing
        ``_sent_content_idx`` past what was consumed. Uses
        ``partial_tag_overlap`` to avoid emitting characters that might be
        the start of a section tag split across chunks.

        Returns the newly available content, or ``None`` if there is none.
        """
        content_segments: list[str] = []
        pos = self._sent_content_idx

        while pos < len(current_text):
            section_start, start_len = self._find_earliest(
                current_text, self.tool_calls_section_start_variants, pos
            )

            if section_start == -1:
                # No more section regions — emit tail minus overlap.
                tail = current_text[pos:]
                overlap = self._section_start_overlap(tail)
                sendable = tail[: len(tail) - overlap]
                if sendable:
                    content_segments.append(sendable)
                pos = len(current_text) - overlap
                break

            # Text before the section start is content.
            if section_start > pos:
                content_segments.append(current_text[pos:section_start])

            # Skip past the section region if it is already closed.
            section_end, end_len = self._find_earliest(
                current_text,
                self.tool_calls_section_end_variants,
                section_start + start_len,
            )
            if section_end == -1:
                # Section still open — park cursor on its start, stop.
                pos = section_start
                break
            pos = section_end + end_len

        if pos > self._sent_content_idx:
            self._sent_content_idx = pos
        return "".join(content_segments) or None

    def _extract_tool_call_regions(
        self, text: str
    ) -> list[tuple[str, bool]]:
        """Find all ``<|tool_call_begin|>`` … ``<|tool_call_end|>`` blocks
        inside the first tool-calls section.

        Returns a list of ``(inner_text, is_complete)`` tuples.
        *inner_text* is everything between the tool-call open and close
        tags (or end-of-available-text for the last partial block, minus
        any trailing partial ``<|tool_call_end|>`` tag).
        """
        results: list[tuple[str, bool]] = []

        # Find the section region.
        inner_start, inner_end = self._find_section_start_end(text)
        if inner_start == -1:
            return results

        region = (
            text[inner_start:inner_end] if inner_end != -1 else text[inner_start:]
        )

        pos = 0
        while pos < len(region):
            tc_start = region.find(self.tool_call_start_token, pos)
            if tc_start == -1:
                break

            body_start = tc_start + len(self.tool_call_start_token)
            tc_end = region.find(self.tool_call_end_token, body_start)

            if tc_end != -1:
                results.append((region[body_start:tc_end], True))
                pos = tc_end + len(self.tool_call_end_token)
            else:
                # Incomplete — still being generated. Drop a trailing
                # partial end tag so it is never treated as arguments.
                body = region[body_start:]
                overlap = partial_tag_overlap(body, self.tool_call_end_token)
                if overlap:
                    body = body[:-overlap]
                results.append((body, False))
                break

        return results

    def _parse_tool_call_body(
        self, body: str, is_complete: bool
    ) -> tuple[str | None, str | None, str]:
        """Parse a tool-call body into (func_name, tool_id, args_so_far).

        The body looks like::

            functions.get_weather:0 <|tool_call_argument_begin|> {"location": "杭州"}

        Returns ``(None, None, "")`` if the body doesn't contain enough
        information yet (e.g. the tool ID is still arriving).
        """
        # Split into the ID part and the argument part around
        # <|tool_call_argument_begin|> (if it has arrived).
        arg_begin_idx = body.find(self.tool_call_arg_begin)
        has_args = arg_begin_idx != -1
        if has_args:
            id_portion = body[:arg_begin_idx]
            args_portion = body[arg_begin_idx + len(self.tool_call_arg_begin):]
        else:
            id_portion = body
            args_portion = ""

        id_match = self.tool_id_regex.match(id_portion)
        if not id_match:
            # Not enough tokens yet to identify the tool.
            return None, None, ""

        if (
            not is_complete
            and not has_args
            and id_match.end("tool_id") == len(id_portion)
        ):
            # The ID runs to the very end of the available text, so more
            # index digits may still arrive (":1" now, ":12" next chunk).
            # The ID is emitted only once, so wait until something
            # (whitespace or the start of a tag) follows it.
            return None, None, ""

        func_name, full_id = self._parse_tool_id(id_match.group("tool_id"))

        if not has_args:
            # Name is known but the argument section has not started.
            return func_name, full_id, ""

        # Any trailing partial <|tool_call_end|> was already removed by
        # the caller for partial blocks.
        return func_name, full_id, args_portion.strip()

    def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
        """Return only the characters in *args_so_far* that haven't been
        sent yet, or ``None`` if there's nothing new.

        When new characters are returned, ``streamed_args_for_tool[index]``
        and ``prev_tool_call_arr[index]["arguments"]`` are updated.
        """
        prev = self.streamed_args_for_tool[index]
        if not args_so_far or len(args_so_far) <= len(prev):
            return None

        diff = args_so_far[len(prev):]
        self.streamed_args_for_tool[index] = args_so_far
        self.prev_tool_call_arr[index]["arguments"] = args_so_far
        return diff

    def _ensure_tool_state_for(self, index: int) -> None:
        """Grow the streaming-state arrays so *index* is valid."""
        while len(self._tool_call_ids) <= index:
            self._tool_call_ids.append("")
        while len(self.streamed_args_for_tool) <= index:
            self.streamed_args_for_tool.append("")
        while len(self.prev_tool_call_arr) <= index:
            self.prev_tool_call_arr.append({})

    # ------------------------------------------------------------------
    # Main streaming entry point
    # ------------------------------------------------------------------

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming output using re-parse-and-diff.

        On every call we:
          1. Re-scan *current_text* for content outside tool-call sections.
          2. Find all ``<|tool_call_begin|>`` regions (complete + partial).
          3. Parse each region for tool ID and arguments.
          4. Diff arguments against previous state, emit deltas.

        Because the entire text is re-parsed each time, the result does
        not depend on how many tokens arrived in this step.

        Returns a ``DeltaMessage`` carrying new content and/or tool-call
        deltas, an empty-content ``DeltaMessage`` for an empty text delta
        that still carries token ids once tool state exists, or ``None``
        when there is nothing to emit.
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)

        # First chunk of a new stream — reset state.
        if not previous_text:
            self._reset_streaming_state()

        # If tools aren't enabled, just forward content.
        if not self._tools_enabled(request):
            return DeltaMessage(content=delta_text) if delta_text else None

        # 1. Extract any content outside tool-call sections.
        content = self._extract_content(current_text)

        # 2. Find all tool-call regions.
        regions = self._extract_tool_call_regions(current_text)
        tool_call_deltas: list[DeltaToolCall] = []

        for i, (body, is_complete) in enumerate(regions):
            self._ensure_tool_state_for(i)

            func_name, tool_id, args_so_far = self._parse_tool_call_body(
                body, is_complete
            )
            if func_name is None:
                # Not enough data to identify the tool yet.
                break

            # Emit the tool name (once per tool call).
            if "name" not in self.prev_tool_call_arr[i]:
                self.prev_tool_call_arr[i]["name"] = func_name
                self._tool_call_ids[i] = tool_id or ""
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        id=tool_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments="",
                        ),
                    )
                )

            # Diff the arguments and emit any new characters.
            diff = self._compute_args_diff(i, args_so_far)
            if diff:
                tool_call_deltas.append(
                    DeltaToolCall(
                        index=i,
                        function=DeltaFunctionCall(arguments=diff),
                    )
                )

        if regions:
            self.current_tool_id = len(regions) - 1

        # 3. Return a delta if we have content or tool-call updates.
        if content or tool_call_deltas:
            return DeltaMessage(
                content=content,
                tool_calls=tool_call_deltas if tool_call_deltas else None,
            )

        # Empty delta with token ids means EOS or closing tag — return
        # non-None so the serving framework can finalize finish_reason.
        if not delta_text and delta_token_ids and self.prev_tool_call_arr:
            return DeltaMessage(content="")

        return None