diff --git a/Dockerfile b/Dockerfile
index fb53386..adaca80 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -6,4 +6,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends unzip && rm -rf
 ADD https://ewr1.vultrobjects.com/artifacts/models--nvidia--Kimi-K2.5-Thinking-Eagle3.zip /tmp/eagle3.zip
 RUN unzip /tmp/eagle3.zip -d /opt/nvidia-Kimi-K2.5-Thinking-Eagle3 && \
     rm /tmp/eagle3.zip && \
-    apt-get remove -y unzip && apt-get autoremove -y
\ No newline at end of file
+    apt-get remove -y unzip && apt-get autoremove -y
+
+# Patch tool parser for MTP
+COPY kimi_k2_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/kimi_k2_tool_parser.py
\ No newline at end of file
diff --git a/kimi_k2_tool_parser.py b/kimi_k2_tool_parser.py
new file mode 100644
index 0000000..ac26f60
--- /dev/null
+++ b/kimi_k2_tool_parser.py
@@ -0,0 +1,577 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Kimi-K2 Tool Call Parser — re-parse-and-diff version.
+
+Adapted from the GLM-4/DeepSeek-V3.2 streaming fix to make the
+streaming path robust against multi-token deltas produced by MTP
+speculative decoding.
+
+Instead of counting start/end tokens to maintain an incremental state
+machine, the streaming path re-parses the *entire* current_text on
+every call, finds all <|tool_call_begin|> regions (complete and
+in-progress), extracts the JSON arguments for each, and diffs against
+what was previously sent.
+
+Key changes vs. the upstream token-count parser:
+  1. No token-count state machine — the parser is stateless w.r.t.
+     how many tokens arrived per step.
+  2. _extract_content() uses partial_tag_overlap to safely handle
+     section-start tags split across chunk boundaries.
+  3. _extract_tool_call_regions() finds both complete and incomplete
+     tool-call blocks, enabling argument streaming.
+  4. _compute_args_diff() emits only newly-added characters.
+  5. Handles singular/plural section marker variants.
+
+Drop-in replacement: same class name, same interface.
+
+Example tool call format::
+
+    <|tool_calls_section_begin|>
+    <|tool_call_begin|>
+    functions.get_weather:0
+    <|tool_call_argument_begin|>
+    {"location": "杭州", "date": "2024-01-16"}
+    <|tool_call_end|>
+    <|tool_call_begin|>
+    functions.get_time:1
+    <|tool_call_argument_begin|>
+    {"timezone": "Asia/Shanghai"}
+    <|tool_call_end|>
+    <|tool_calls_section_end|>
+"""
+
+from collections.abc import Sequence
+from typing import Any
+
+import regex as re
+
+from vllm.entrypoints.openai.chat_completion.protocol import (
+    ChatCompletionRequest,
+)
+from vllm.entrypoints.openai.engine.protocol import (
+    DeltaFunctionCall,
+    DeltaMessage,
+    DeltaToolCall,
+    ExtractedToolCallInformation,
+    FunctionCall,
+    ToolCall,
+)
+from vllm.logger import init_logger
+from vllm.tokenizers import TokenizerLike
+from vllm.tool_parsers.abstract_tool_parser import (
+    Tool,
+    ToolParser,
+)
+
+logger = init_logger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Utility — inlined to avoid import issues across vLLM versions
+# ---------------------------------------------------------------------------
+
+
+def partial_tag_overlap(text: str, tag: str) -> int:
+    """Return the length of the longest proper prefix of *tag* ending *text*.
+
+    Used to hold back characters that may be the beginning of a tag
+    split across streaming chunks.  E.g. text ending in
+    ``"<|tool_call"`` returns 11 when tag is ``"<|tool_call_begin|>"``.
+    A complete tag is never reported (at most ``len(tag) - 1``
+    characters are considered).  Returns 0 when there is no overlap.
+    """
+    max_check = min(len(tag) - 1, len(text))
+    for k in range(max_check, 0, -1):
+        if text.endswith(tag[:k]):
+            return k
+    return 0
+
+
+class KimiK2ToolParser(ToolParser):
+    """Re-parse-and-diff tool parser for Kimi-K2 format.
+
+    On every streaming call the parser re-parses ``current_text`` to
+    find tool-call regions, extracts the JSON arguments for each, and
+    diffs against what was previously sent.  This is robust against
+    multi-token deltas from MTP / EAGLE speculative decoding.
+    """
+
+    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
+        super().__init__(tokenizer, tools)
+
+        # Fail fast, before any other state is set up.
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ToolParser "
+                "constructor during construction."
+            )
+
+        # ----- Tag constants -----
+        # Section wrappers (support singular & plural variants)
+        self.tool_calls_section_start_variants: list[str] = [
+            "<|tool_calls_section_begin|>",
+            "<|tool_call_section_begin|>",
+        ]
+        self.tool_calls_section_end_variants: list[str] = [
+            "<|tool_calls_section_end|>",
+            "<|tool_call_section_end|>",
+        ]
+        # Primary variant for ToolParser base class / adjust_request
+        self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
+        self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
+
+        # Individual tool-call markers
+        self.tool_call_start_token: str = "<|tool_call_begin|>"
+        self.tool_call_end_token: str = "<|tool_call_end|>"
+        self.tool_call_arg_begin: str = "<|tool_call_argument_begin|>"
+
+        # ----- Compiled regexes -----
+        # Complete tool call block.  findall() yields
+        # (tool_call_id, function_arguments) tuples.
+        self.tool_call_regex = re.compile(
+            r"<\|tool_call_begin\|>\s*"
+            r"(?P<tool_call_id>[^<]+:\d+)\s*"
+            r"<\|tool_call_argument_begin\|>\s*"
+            r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
+            r"<\|tool_call_end\|>",
+            re.DOTALL,
+        )
+
+        # For extracting tool ID from the start of a tool-call region.
+        self.tool_id_regex = re.compile(
+            r"\s*(?P<tool_id>[^\s<]+:\d+)\s*", re.DOTALL
+        )
+
+        # ----- Streaming state (reset per request) -----
+        # Index into current_text up to which content has been emitted.
+        self._sent_content_idx: int = 0
+        # Full tool-call id per tool index, e.g. "functions.get_weather:0".
+        self._tool_call_ids: list[str] = []
+        # Argument text already streamed, per tool index.
+        self.streamed_args_for_tool: list[str] = []
+        # Per tool index: {"name": ..., "arguments": ...} as seen so far.
+        self.prev_tool_call_arr: list[dict[str, Any]] = []
+        self.current_tool_id: int = -1
+
+        # Validate that the primary section tokens exist in vocab.
+        self.tool_calls_start_token_id = self.vocab.get(
+            self.tool_calls_start_token
+        )
+        self.tool_calls_end_token_id = self.vocab.get(
+            self.tool_calls_end_token
+        )
+        self.tool_call_start_token_id = self.vocab.get(
+            self.tool_call_start_token
+        )
+        self.tool_call_end_token_id = self.vocab.get(
+            self.tool_call_end_token
+        )
+
+        if (
+            self.tool_calls_start_token_id is None
+            or self.tool_calls_end_token_id is None
+        ):
+            raise RuntimeError(
+                "Kimi-K2 Tool parser could not locate tool call start/end "
+                "tokens in the tokenizer!"
+            )
+
+        logger.debug(
+            "Successfully initialized %s", self.__class__.__name__
+        )
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def _tools_enabled(request: ChatCompletionRequest) -> bool:
+        """Return True if *request* has tools and tool_choice is not "none"."""
+        try:
+            tools = getattr(request, "tools", None)
+            tool_choice = getattr(request, "tool_choice", None)
+            return bool(tools) and tool_choice != "none"
+        except Exception:
+            logger.exception("Failed to determine if tools are enabled.")
+            return False
+
+    @staticmethod
+    def _parse_tool_id(raw_id: str) -> tuple[str, str]:
+        """Split a raw tool id into ``(function_name, full_id)``.
+
+        ``'functions.get_weather:0'`` →
+        ``('get_weather', 'functions.get_weather:0')``.
+        """
+        raw_id = raw_id.strip()
+        function_name = raw_id.split(":")[0].split(".")[-1]
+        return function_name, raw_id
+
+    @staticmethod
+    def _find_first_marker(
+        text: str, variants: Sequence[str], start: int = 0
+    ) -> tuple[int, int]:
+        """Locate the earliest occurrence of any of *variants* in *text*.
+
+        Returns ``(index, marker_length)``, or ``(-1, 0)`` when no
+        variant occurs at or after *start*.
+        """
+        best_idx, best_len = -1, 0
+        for variant in variants:
+            idx = text.find(variant, start)
+            if idx != -1 and (best_idx == -1 or idx < best_idx):
+                best_idx, best_len = idx, len(variant)
+        return best_idx, best_len
+
+    def _find_section_start(self, text: str) -> int:
+        """Return the index of the first section-start marker, or -1."""
+        idx, _ = self._find_first_marker(
+            text, self.tool_calls_section_start_variants
+        )
+        return idx
+
+    def _find_section_start_end(self, text: str) -> tuple[int, int]:
+        """Return (start_of_inner, end_of_inner) for the section region.
+
+        *start_of_inner* points just past the earliest section-start
+        marker (``-1`` if there is none).  *end_of_inner* is the index
+        of the earliest section-end marker after it, or ``-1`` if the
+        section is still open.
+        """
+        start_idx, start_len = self._find_first_marker(
+            text, self.tool_calls_section_start_variants
+        )
+        if start_idx == -1:
+            return -1, -1
+        inner_start = start_idx + start_len
+        end_idx, _ = self._find_first_marker(
+            text, self.tool_calls_section_end_variants, inner_start
+        )
+        return inner_start, end_idx
+
+    # ------------------------------------------------------------------
+    # Non-streaming extraction
+    # ------------------------------------------------------------------
+
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+        """Extract all complete tool calls from a finished *model_output*.
+
+        Text before the first section-start marker (either variant) is
+        returned as ``content``.  If no marker is present, or parsing
+        raises, the whole output is returned as plain content.
+        """
+        section_start = self._find_section_start(model_output)
+        if section_start == -1:
+            return ExtractedToolCallInformation(
+                tools_called=False, tool_calls=[], content=model_output
+            )
+
+        try:
+            function_call_tuples = self.tool_call_regex.findall(model_output)
+            logger.debug("function_call_tuples: %s", function_call_tuples)
+
+            tool_calls: list[ToolCall] = []
+            for function_id, function_args in function_call_tuples:
+                function_name, full_id = self._parse_tool_id(function_id)
+                tool_calls.append(
+                    ToolCall(
+                        id=full_id,
+                        type="function",
+                        function=FunctionCall(
+                            name=function_name,
+                            arguments=function_args,
+                        ),
+                    )
+                )
+
+            return ExtractedToolCallInformation(
+                tools_called=True,
+                tool_calls=tool_calls,
+                content=model_output[:section_start] or None,
+            )
+
+        except Exception:
+            logger.exception("Error in extracting tool call from response.")
+            return ExtractedToolCallInformation(
+                tools_called=False, tool_calls=[], content=model_output
+            )
+
+    # ------------------------------------------------------------------
+    # Streaming helpers — re-parse-and-diff
+    # ------------------------------------------------------------------
+
+    def _reset_streaming_state(self) -> None:
+        """Clear all per-request streaming state."""
+        self._sent_content_idx = 0
+        self._tool_call_ids.clear()
+        self.streamed_args_for_tool.clear()
+        self.prev_tool_call_arr.clear()
+        self.current_tool_id = -1
+
+    def _section_start_overlap(self, text: str) -> int:
+        """Return how many trailing characters of *text* could be the
+        beginning of a section-start marker.
+
+        Every variant is checked: the singular and plural markers
+        diverge after ``<|tool_call``, so testing only one of them
+        would leak a partially received marker of the other variant
+        into the content stream.
+        """
+        return max(
+            (
+                partial_tag_overlap(text, variant)
+                for variant in self.tool_calls_section_start_variants
+            ),
+            default=0,
+        )
+
+    def _extract_content(self, current_text: str) -> str | None:
+        """Return any non-tool-section text that hasn't been sent yet.
+
+        Walks *current_text* from ``_sent_content_idx``, collecting
+        text outside ``<|tool_calls_section_begin|>`` …
+        ``<|tool_calls_section_end|>`` regions.  Trailing characters
+        that might be the start of a section tag split across chunks
+        are held back until later text disambiguates them.
+        """
+        content_segments: list[str] = []
+        pos = self._sent_content_idx
+
+        while pos < len(current_text):
+            section_start, start_len = self._find_first_marker(
+                current_text, self.tool_calls_section_start_variants, pos
+            )
+
+            if section_start == -1:
+                # No more section regions — emit tail minus overlap.
+                tail = current_text[pos:]
+                overlap = self._section_start_overlap(tail)
+                sendable = tail[: len(tail) - overlap]
+                if sendable:
+                    content_segments.append(sendable)
+                pos = len(current_text) - overlap
+                break
+
+            # Text before the section start is content.
+            if section_start > pos:
+                content_segments.append(current_text[pos:section_start])
+
+            # Skip past the section region.
+            section_end, end_len = self._find_first_marker(
+                current_text,
+                self.tool_calls_section_end_variants,
+                section_start + start_len,
+            )
+            if section_end == -1:
+                # Section still open — park cursor, stop.
+                pos = section_start
+                break
+            pos = section_end + end_len
+
+        # pos never moves backwards, so this only ever advances the cursor.
+        self._sent_content_idx = pos
+        return "".join(content_segments) or None
+
+    def _extract_tool_call_regions(
+        self, text: str
+    ) -> list[tuple[str, bool]]:
+        """Find all ``<|tool_call_begin|>`` … ``<|tool_call_end|>``
+        blocks inside the tool-calls section.
+
+        Returns a list of ``(inner_text, is_complete)`` tuples.
+        *inner_text* is everything between the tool-call open and close
+        tags (or end-of-available-text for the last partial block, minus
+        any trailing fragment of a ``<|tool_call_end|>`` tag).
+        """
+        results: list[tuple[str, bool]] = []
+
+        # Find the section region.
+        inner_start, inner_end = self._find_section_start_end(text)
+        if inner_start == -1:
+            return results
+
+        region = text[inner_start:inner_end] if inner_end != -1 else text[inner_start:]
+
+        pos = 0
+        while pos < len(region):
+            tc_start = region.find(self.tool_call_start_token, pos)
+            if tc_start == -1:
+                break
+
+            body_start = tc_start + len(self.tool_call_start_token)
+            tc_end = region.find(self.tool_call_end_token, body_start)
+
+            if tc_end == -1:
+                # Incomplete — still being generated.  Hold back a
+                # trailing fragment of the end tag so it is never
+                # streamed as argument text.
+                body = region[body_start:]
+                overlap = partial_tag_overlap(body, self.tool_call_end_token)
+                if overlap:
+                    body = body[:-overlap]
+                results.append((body, False))
+                break
+
+            results.append((region[body_start:tc_end], True))
+            pos = tc_end + len(self.tool_call_end_token)
+
+        return results
+
+    def _parse_tool_call_body(
+        self, body: str, is_complete: bool
+    ) -> tuple[str | None, str | None, str]:
+        """Parse a tool-call body into (func_name, tool_id, args_so_far).
+
+        The body looks like::
+
+            functions.get_weather:0
+            <|tool_call_argument_begin|>
+            {"location": "杭州"}
+
+        Returns ``(None, None, "")`` if the body doesn't contain enough
+        information yet.  For an in-progress block that means the
+        ``<|tool_call_argument_begin|>`` tag has not fully arrived: until
+        then the numeric suffix of the ID may still be growing (``:1``
+        could become ``:12``), and an ID that has been emitted cannot be
+        corrected afterwards.
+        """
+        id_portion, arg_tag, args_portion = body.partition(
+            self.tool_call_arg_begin
+        )
+
+        if not arg_tag and not is_complete:
+            # ID may still be arriving — wait for the argument tag.
+            return None, None, ""
+
+        id_match = self.tool_id_regex.match(id_portion)
+        if not id_match:
+            # Not enough tokens yet to identify the tool.
+            return None, None, ""
+
+        func_name, full_id = self._parse_tool_id(id_match.group("tool_id"))
+        return func_name, full_id, args_portion.strip()
+
+    def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
+        """Return only the characters in *args_so_far* that haven't been
+        sent yet, or ``None`` if there's nothing new.
+
+        Also records *args_so_far* as the streamed state for *index*.
+        """
+        prev = self.streamed_args_for_tool[index]
+        if not args_so_far or len(args_so_far) <= len(prev):
+            return None
+        diff = args_so_far[len(prev):]
+        self.streamed_args_for_tool[index] = args_so_far
+        self.prev_tool_call_arr[index]["arguments"] = args_so_far
+        return diff
+
+    def _ensure_tool_state_for(self, index: int) -> None:
+        """Grow the streaming-state arrays so *index* is valid."""
+        while len(self._tool_call_ids) <= index:
+            self._tool_call_ids.append("")
+        while len(self.streamed_args_for_tool) <= index:
+            self.streamed_args_for_tool.append("")
+        while len(self.prev_tool_call_arr) <= index:
+            self.prev_tool_call_arr.append({})
+
+    # ------------------------------------------------------------------
+    # Main streaming entry point
+    # ------------------------------------------------------------------
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> DeltaMessage | None:
+        """Extract tool calls from streaming output using re-parse-and-diff.
+
+        On every call we:
+        1. Re-scan *current_text* for content outside tool-call sections.
+        2. Find all ``<|tool_call_begin|>`` regions (complete + partial).
+        3. Parse each region for tool ID and arguments.
+        4. Diff arguments against previous state, emit deltas.
+
+        Because the entire text is re-parsed each time, the result is
+        correct regardless of how many tokens arrived in this step.
+        """
+        logger.debug("delta_text: %s", delta_text)
+        logger.debug("delta_token_ids: %s", delta_token_ids)
+
+        # First chunk of a new stream — reset state.
+        if not previous_text:
+            self._reset_streaming_state()
+
+        # If tools aren't enabled, just forward content.
+        if not self._tools_enabled(request):
+            return DeltaMessage(content=delta_text) if delta_text else None
+
+        # 1. Extract any content outside tool-call sections.
+        content = self._extract_content(current_text)
+
+        # 2. Find all tool-call regions.
+        regions = self._extract_tool_call_regions(current_text)
+        tool_call_deltas: list[DeltaToolCall] = []
+
+        for i, (body, is_complete) in enumerate(regions):
+            func_name, tool_id, args_so_far = self._parse_tool_call_body(
+                body, is_complete
+            )
+            if func_name is None:
+                # Not enough data to identify the tool yet.
+                break
+
+            # Allocate state only for identified tools so that
+            # prev_tool_call_arr never holds empty placeholder entries.
+            self._ensure_tool_state_for(i)
+
+            # Emit the tool name (once per tool call).
+            if "name" not in self.prev_tool_call_arr[i]:
+                self.prev_tool_call_arr[i]["name"] = func_name
+                self._tool_call_ids[i] = tool_id or ""
+                tool_call_deltas.append(
+                    DeltaToolCall(
+                        index=i,
+                        id=tool_id,
+                        type="function",
+                        function=DeltaFunctionCall(
+                            name=func_name,
+                            arguments="",
+                        ),
+                    )
+                )
+
+            # Diff the arguments and emit any new characters.
+            diff = self._compute_args_diff(i, args_so_far)
+            if diff:
+                tool_call_deltas.append(
+                    DeltaToolCall(
+                        index=i,
+                        function=DeltaFunctionCall(arguments=diff),
+                    )
+                )
+
+        if regions:
+            self.current_tool_id = len(regions) - 1
+
+        # 3. Return a delta if we have content or tool-call updates.
+        if content or tool_call_deltas:
+            return DeltaMessage(
+                content=content,
+                tool_calls=tool_call_deltas if tool_call_deltas else None,
+            )
+
+        # Empty delta with token ids means EOS or closing tag — return
+        # non-None so the serving framework can finalize finish_reason.
+        if not delta_text and delta_token_ids and self.prev_tool_call_arr:
+            return DeltaMessage(content="")
+
+        return None