diff --git a/Dockerfile b/Dockerfile
index adaca80..17f12cd 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -8,5 +8,7 @@
 RUN unzip /tmp/eagle3.zip -d /opt/nvidia-Kimi-K2.5-Thinking-Eagle3 && \
     rm /tmp/eagle3.zip && \
     apt-get remove -y unzip && apt-get autoremove -y
-# Patch tool parser for MTP
-COPY kimi_k2_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/kimi_k2_tool_parser.py
\ No newline at end of file
+# Patch tool and reasoning parsers for Eagle
+COPY kimi_k2_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/kimi_k2_tool_parser.py
+
+COPY kimi_k2_reasoning_parser.py /usr/local/lib/python3.12/dist-packages/vllm/reasoning/kimi_k2_reasoning_parser.py
\ No newline at end of file
diff --git a/kimi_k2_reasoning_parser.py b/kimi_k2_reasoning_parser.py
new file mode 100644
index 0000000..1a43eaf
--- /dev/null
+++ b/kimi_k2_reasoning_parser.py
@@ -0,0 +1,286 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Kimi-K2 Reasoning Parser — MTP-compatible version.
+
+Fixes applied over the upstream parser:
+
+1. **``<think>``/``</think>`` tag suppression no longer requires single-token
+   deltas.** The original used ``len(delta_token_ids) == 1`` to detect
+   and suppress think tags. With MTP speculative decoding, these tokens
+   arrive fused with reasoning text, so the guard fails and raw tags
+   leak into the reasoning or content output.
+
+2. **Text-based detection replaces token-ID-only detection** in
+   ``extract_reasoning_streaming``. Since ``<think>`` and ``</think>``
+   are single tokens, they always appear as complete strings in
+   ``delta_text`` (the detokenizer never splits a single token across
+   deltas). Text-based stripping is therefore safe and MTP-agnostic.
+
+3. **Handles ``</think>`` + ``<|tool_calls_section_begin|>`` arriving
+   in the same delta** — the reasoning portion is correctly terminated
+   and the tool-call content is forwarded so the tool parser can
+   detect it on the same or next call.
+
+Drop-in replacement: same class name, same interface.
+"""
+
+from collections.abc import Iterable, Sequence
+from typing import TYPE_CHECKING
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.engine.protocol import DeltaMessage
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
+from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
+
+class KimiK2ReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for Kimi K2 model — MTP-compatible.
+
+    Uses ``<think>...</think>`` to denote reasoning text. Reasoning
+    may also end implicitly when ``<|tool_calls_section_begin|>``
+    appears.
+
+    All detection uses text-based matching so the parser is agnostic
+    to how many tokens arrive per streaming step (robust against MTP
+    and EAGLE speculative decoding).
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
+        super().__init__(tokenizer, *args, **kwargs)
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ReasoningParser "
+                "constructor during construction."
+            )
+
+        # Check if thinking is disabled via chat_template_kwargs
+        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
+        thinking = bool(chat_kwargs.get("thinking", True))
+
+        # If thinking is not enabled, use identity parser to fall through
+        self._identity_parser: IdentityReasoningParser | None
+        if not thinking:
+            self._identity_parser = IdentityReasoningParser(
+                tokenizer, *args, **kwargs
+            )
+        else:
+            self._identity_parser = None
+
+        # Token definitions
+        self._start_token = "<think>"
+        self._end_token = "</think>"
+        self._tool_section_start_token = "<|tool_calls_section_begin|>"
+
+        # Also support singular variant for tool section
+        self._tool_section_start_variants = [
+            "<|tool_calls_section_begin|>",
+            "<|tool_call_section_begin|>",
+        ]
+
+        # Get token IDs (used by is_reasoning_end which scans full ID lists)
+        self._start_token_id = self.vocab.get(self._start_token)
+        self._end_token_id = self.vocab.get(self._end_token)
+        self._tool_section_start_token_id = self.vocab.get(
+            self._tool_section_start_token
+        )
+
+        # Collect all tool section start token IDs (for ID-based checks)
+        self._tool_section_start_token_ids: set[int] = set()
+        for variant in self._tool_section_start_variants:
+            tid = self.vocab.get(variant)
+            if tid is not None:
+                self._tool_section_start_token_ids.add(tid)
+
+        if self._start_token_id is None or self._end_token_id is None:
+            raise RuntimeError(
+                "KimiK2ReasoningParser could not locate think start/end "
+                "tokens in the tokenizer!"
+            )
+
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
+    def _find_tool_section_start(self, text: str) -> int:
+        """Return the index of the earliest tool-section-start marker,
+        or -1 if none found."""
+        best = -1
+        for variant in self._tool_section_start_variants:
+            idx = text.find(variant)
+            if idx != -1 and (best == -1 or idx < best):
+                best = idx
+        return best
+
+    def _strip_think_tags(self, text: str) -> str:
+        """Remove ``<think>`` and ``</think>`` tag text from *text*."""
+        return text.replace(self._start_token, "").replace(self._end_token, "")
+
+    # ------------------------------------------------------------------
+    # Full-sequence methods (these scan all IDs — MTP-safe as-is)
+    # ------------------------------------------------------------------
+
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+        """Check if reasoning has ended by scanning the full token sequence.
+
+        Reasoning ends when we see either ``</think>`` or a tool-section
+        start token after the last ``<think>``.
+        """
+        if self._identity_parser is not None:
+            return self._identity_parser.is_reasoning_end(input_ids)
+
+        for i in range(len(input_ids) - 1, -1, -1):
+            if input_ids[i] == self._start_token_id:
+                return False
+            if input_ids[i] == self._end_token_id:
+                return True
+            if input_ids[i] in self._tool_section_start_token_ids:
+                return True
+        return False
+
+    def is_reasoning_end_streaming(
+        self, input_ids: Sequence[int], delta_ids: Iterable[int]
+    ) -> bool:
+        """Check if reasoning ends in this delta."""
+        if self._identity_parser is not None:
+            return self._identity_parser.is_reasoning_end_streaming(
+                input_ids, delta_ids
+            )
+
+        delta_ids_set = set(delta_ids)
+        if self._end_token_id in delta_ids_set:
+            return True
+        return bool(delta_ids_set & self._tool_section_start_token_ids)
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        """Extract content token IDs (everything after reasoning ends)."""
+        if self._identity_parser is not None:
+            return self._identity_parser.extract_content_ids(input_ids)
+
+        if self._end_token_id in input_ids:
+            end_idx = (
+                len(input_ids) - 1 - input_ids[::-1].index(self._end_token_id)
+            )
+            if end_idx != -1:
+                return input_ids[end_idx + 1:]
+
+        # Check for implicit reasoning end via tool section
+        for tid in self._tool_section_start_token_ids:
+            if tid in input_ids:
+                tool_idx = (
+                    len(input_ids) - 1 - input_ids[::-1].index(tid)
+                )
+                if tool_idx != -1:
+                    return input_ids[tool_idx:]
+
+        return []
+
+    # ------------------------------------------------------------------
+    # Non-streaming extraction
+    # ------------------------------------------------------------------
+
+    def extract_reasoning(
+        self,
+        model_output: str,
+        request: "ChatCompletionRequest | ResponsesRequest",
+    ) -> tuple[str | None, str | None]:
+        """Extract (reasoning, content) from complete model output."""
+        if self._identity_parser is not None:
+            return self._identity_parser.extract_reasoning(
+                model_output, request
+            )
+
+        # Consume <think> at the start if present
+        start_idx = model_output.find(self._start_token)
+        start_idx = 0 if start_idx != 0 else len(self._start_token)
+
+        # Look for explicit </think>
+        end_idx = model_output.find(self._end_token)
+        if end_idx != -1:
+            reasoning = model_output[start_idx:end_idx]
+            content = model_output[end_idx + len(self._end_token):]
+            return reasoning, content or None
+
+        # Look for implicit reasoning end via tool section
+        tool_idx = self._find_tool_section_start(model_output)
+        if tool_idx != -1:
+            reasoning = model_output[start_idx:tool_idx]
+            content = model_output[tool_idx:]
+            return reasoning, content or None
+
+        # Still reasoning (no content yet)
+        return model_output[start_idx:], None
+
+    # ------------------------------------------------------------------
+    # Streaming extraction — MTP-compatible
+    # ------------------------------------------------------------------
+
+    def extract_reasoning_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> DeltaMessage | None:
+        """Extract reasoning from a streaming delta.
+
+        Uses **text-based** detection to strip ``<think>``/``</think>``
+        tags. This is safe because these are single tokens — the
+        detokenizer always produces them as complete strings, never
+        split across deltas. This makes the method agnostic to how
+        many tokens arrive per step (MTP-compatible).
+        """
+        if self._identity_parser is not None:
+            return self._identity_parser.extract_reasoning_streaming(
+                previous_text, current_text, delta_text,
+                previous_token_ids, current_token_ids, delta_token_ids,
+            )
+
+        # ── Already past reasoning → everything is content ──
+        if self.is_reasoning_end(previous_token_ids):
+            # Strip any residual think tags that might appear in content
+            cleaned = self._strip_think_tags(delta_text)
+            return DeltaMessage(content=cleaned) if cleaned else None
+
+        # ── Check for </think> in this delta ──
+        if self._end_token in delta_text:
+            end_idx = delta_text.find(self._end_token)
+            # Everything before is reasoning (strip <think> if present)
+            reasoning = self._strip_think_tags(delta_text[:end_idx])
+            # Everything after is content
+            content = delta_text[end_idx + len(self._end_token):]
+
+            kwargs: dict = {}
+            if reasoning:
+                kwargs["reasoning"] = reasoning
+            if content:
+                kwargs["content"] = content
+            return DeltaMessage(**kwargs) if kwargs else None
+
+        # ── Check for implicit reasoning end via tool section ──
+        tool_idx = self._find_tool_section_start(delta_text)
+        if tool_idx != -1:
+            reasoning = self._strip_think_tags(delta_text[:tool_idx])
+            # Forward the tool section marker as content so the tool
+            # parser can detect it.
+            content = delta_text[tool_idx:]
+
+            kwargs = {}
+            if reasoning:
+                kwargs["reasoning"] = reasoning
+            if content:
+                kwargs["content"] = content
+            return DeltaMessage(**kwargs) if kwargs else None
+
+        # ── Still in reasoning — strip <think> tag if present ──
+        cleaned = self._strip_think_tags(delta_text)
+        return DeltaMessage(reasoning=cleaned) if cleaned else None
\ No newline at end of file
diff --git a/kimi_k2_tool_parser.py b/kimi_k2_tool_parser.py
index b0e8897..3e0065d 100644
--- a/kimi_k2_tool_parser.py
+++ b/kimi_k2_tool_parser.py
@@ -57,6 +57,7 @@ from vllm.entrypoints.openai.engine.protocol import (
     FunctionCall,
     ToolCall,
 )
+from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
 from vllm.logger import init_logger
 from vllm.tokenizers import TokenizerLike
 from vllm.tool_parsers.abstract_tool_parser import (
@@ -171,6 +172,20 @@ class KimiK2ToolParser(ToolParser):
             "Successfully initialized %s", self.__class__.__name__
         )
 
+    # ------------------------------------------------------------------
+    # Request adjustment
+    # ------------------------------------------------------------------
+
+    def adjust_request(
+        self, request: ChatCompletionRequest | ResponsesRequest
+    ) -> ChatCompletionRequest | ResponsesRequest:
+        request = super().adjust_request(request)
+        if request.tools and request.tool_choice != "none":
+            # Ensure tool-call tokens (<|tool_calls_section_begin|>,
+            # <|tool_call_begin|>, etc.) are not stripped during decoding.
+            request.skip_special_tokens = False
+        return request
+
     # ------------------------------------------------------------------
     # Helpers
     # ------------------------------------------------------------------