# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Kimi-K2 Reasoning Parser — MTP-compatible version.

Fixes applied over the upstream parser:

1. **Think-tag suppression no longer requires single-token deltas.**
   The original used ``len(delta_token_ids) == 1`` to detect and suppress
   think tags.  With MTP speculative decoding, these tokens arrive fused
   with reasoning text, so the guard fails and raw tags leak into the
   reasoning or content output.

2. **Text-based detection replaces token-ID-only detection** in
   ``extract_reasoning_streaming``.  Since ``<think>`` and ``</think>``
   are single tokens, they always appear as complete strings in
   ``delta_text`` (the detokenizer never splits a single token across
   deltas).  Text-based stripping is therefore safe and MTP-agnostic.

3. **Handles ``</think>`` + ``<|tool_calls_section_begin|>`` arriving in
   the same delta** — the reasoning portion is correctly terminated and
   the tool-call content is forwarded so the tool parser can detect it on
   the same or next call.

Drop-in replacement: same class name, same interface.
"""

from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser

if TYPE_CHECKING:
    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest


class KimiK2ReasoningParser(ReasoningParser):
    """Reasoning parser for the Kimi K2 model — MTP-compatible.

    Uses ``<think>`` ... ``</think>`` to delimit reasoning text.  Reasoning
    may also end implicitly when ``<|tool_calls_section_begin|>`` appears.
    All detection uses text-based matching, so the parser is agnostic to how
    many tokens arrive per streaming step (robust against MTP and EAGLE
    speculative decoding).
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        # Check if thinking is disabled via chat_template_kwargs.
        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        thinking = bool(chat_kwargs.get("thinking", True))

        # If thinking is not enabled, delegate to the identity parser so all
        # output falls through untouched as content.
        self._identity_parser: IdentityReasoningParser | None
        if not thinking:
            self._identity_parser = IdentityReasoningParser(
                tokenizer, *args, **kwargs
            )
        else:
            self._identity_parser = None

        # Token definitions.
        # NOTE(review): the two think-tag literals were empty strings in the
        # reviewed copy — almost certainly stripped by markup-unaware tooling
        # (the ``<|...|>`` tool markers survived because ``<|`` is not a valid
        # tag start).  Empty strings would make the vocab lookups below return
        # None and trip the RuntimeError on every construction.  Restored to
        # the standard think tags — confirm against the model tokenizer's
        # special tokens.
        self._start_token = "<think>"
        self._end_token = "</think>"
        self._tool_section_start_token = "<|tool_calls_section_begin|>"
        # Also support the singular variant of the tool-section marker.
        self._tool_section_start_variants = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",
        ]

        # Get token IDs (used by is_reasoning_end for non-streaming,
        # and is_reasoning_end_streaming for delta checks).
        self._start_token_id = self.vocab.get(self._start_token)
        self._end_token_id = self.vocab.get(self._end_token)
        self._tool_section_start_token_id = self.vocab.get(
            self._tool_section_start_token
        )

        # Collect all tool-section start token IDs (for ID-based checks).
        # Variants missing from the vocab are simply skipped.
        self._tool_section_start_token_ids: set[int] = set()
        for variant in self._tool_section_start_variants:
            tid = self.vocab.get(variant)
            if tid is not None:
                self._tool_section_start_token_ids.add(tid)

        if self._start_token_id is None or self._end_token_id is None:
            raise RuntimeError(
                "KimiK2ReasoningParser could not locate think start/end "
                "tokens in the tokenizer!"
            )

        # Streaming state — tracks reasoning within the CURRENT generation
        # only, avoiding false positives from prior turns' tokens that
        # appear in the prompt token IDs.
        self._reasoning_ended: bool = False

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find_tool_section_start(self, text: str) -> int:
        """Return the index of the earliest tool-section-start marker,
        or -1 if none found."""
        best = -1
        for variant in self._tool_section_start_variants:
            idx = text.find(variant)
            if idx != -1 and (best == -1 or idx < best):
                best = idx
        return best

    def _strip_think_tags(self, text: str) -> str:
        """Remove ``<think>`` and ``</think>`` tag text from *text*."""
        return text.replace(self._start_token, "").replace(self._end_token, "")

    def _strip_tool_section_markers(self, text: str) -> str:
        """Remove all tool-section start markers from *text*.

        The tool parser finds these in ``current_text`` independently;
        forwarding them as content causes double-handling.
        """
        for variant in self._tool_section_start_variants:
            text = text.replace(variant, "")
        return text

    # ------------------------------------------------------------------
    # Full-sequence methods (these scan all IDs — MTP-safe as-is)
    # ------------------------------------------------------------------

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Check if reasoning has ended based on the token ID sequence.

        Scans backward to find the last think-start or think-end token.
        Returns True only if the last relevant token is a think-end or a
        tool-section-start, AND there is no think-start after it.

        CRITICAL: when called with prompt_token_ids (as the vLLM serving
        layer does), the input contains the full chat history.  On
        multi-turn conversations, the prompt ends with tokens from the
        prior assistant message, which may include think-end.  That
        think-end belongs to the PRIOR generation — the new generation will
        start its own reasoning with think-start.

        Heuristic: if a think-end appears but is followed by several more
        tokens (chat-template wrapping such as end-of-message and new
        user/assistant markers), it is from the prompt and reasoning has
        not started in the current generation yet — return False.  If the
        think-end is at (or within a few tokens of) the very end, it came
        from generated tokens and reasoning has ended — return True.
        """
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end(input_ids)

        # Scan backward to find the last think-start, think-end, or
        # tool-section-start token.
        last_start = -1
        last_end = -1
        last_tool_section = -1
        for i in range(len(input_ids) - 1, -1, -1):
            tid = input_ids[i]
            if tid == self._start_token_id and last_start == -1:
                last_start = i
            if tid == self._end_token_id and last_end == -1:
                last_end = i
            if tid in self._tool_section_start_token_ids and last_tool_section == -1:
                last_tool_section = i
            # Stop early if we found think-start — it's the boundary.
            if last_start != -1:
                break

        # No think tokens at all — not a reasoning model output.
        if last_start == -1 and last_end == -1 and last_tool_section == -1:
            return False

        # think-start is the last relevant token — reasoning is in progress.
        if last_start != -1 and (last_end == -1 or last_start > last_end):
            return False

        # think-end or tool-section is the last relevant token.  Decide
        # whether it belongs to the current generation or a prior turn.
        last_relevant = max(last_end, last_tool_section)
        tokens_after = len(input_ids) - 1 - last_relevant

        # If there are more than a few tokens after the last think-end,
        # those are prompt tokens (chat-template wrapping), meaning the
        # think-end is from a prior turn.  Return False.
        if tokens_after > 3:
            return False

        return True

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Check if reasoning ends in this delta."""
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end_streaming(
                input_ids, delta_ids
            )
        delta_ids_set = set(delta_ids)
        if self._end_token_id in delta_ids_set:
            return True
        return bool(delta_ids_set & self._tool_section_start_token_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Extract content token IDs (everything after reasoning ends)."""
        if self._identity_parser is not None:
            return self._identity_parser.extract_content_ids(input_ids)

        # Explicit think-end: content is everything after the LAST one.
        if self._end_token_id in input_ids:
            end_idx = (
                len(input_ids) - 1 - input_ids[::-1].index(self._end_token_id)
            )
            return input_ids[end_idx + 1:]

        # Implicit reasoning end via tool section: content INCLUDES the
        # marker so the tool parser can detect it.
        for tid in self._tool_section_start_token_ids:
            if tid in input_ids:
                tool_idx = len(input_ids) - 1 - input_ids[::-1].index(tid)
                return input_ids[tool_idx:]

        return []

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract (reasoning, content) from complete model output."""
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning(
                model_output, request
            )

        # Consume a leading <think> tag if present; otherwise start at 0.
        if model_output.startswith(self._start_token):
            start_idx = len(self._start_token)
        else:
            start_idx = 0

        # Look for an explicit </think>.
        end_idx = model_output.find(self._end_token)
        if end_idx != -1:
            reasoning = model_output[start_idx:end_idx]
            content = model_output[end_idx + len(self._end_token):]
            return reasoning, content or None

        # Look for implicit reasoning end via tool section.
        tool_idx = self._find_tool_section_start(model_output)
        if tool_idx != -1:
            reasoning = model_output[start_idx:tool_idx]
            content = model_output[tool_idx:]
            return reasoning, content or None

        # Still reasoning (no content yet).
        return model_output[start_idx:], None

    # ------------------------------------------------------------------
    # Streaming extraction — MTP-compatible
    # ------------------------------------------------------------------

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract reasoning from a streaming delta.

        Uses **text-based** detection to strip ``<think>``/``</think>``
        tags.  This is safe because these are single tokens — the
        detokenizer always produces them as complete strings, never split
        across deltas.  This makes the method agnostic to how many tokens
        arrive per step (MTP-compatible).
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning_streaming(
                previous_text,
                current_text,
                delta_text,
                previous_token_ids,
                current_token_ids,
                delta_token_ids,
            )

        # Reset state on new stream — previous_text is empty on the first
        # delta of each generation.
        if not previous_text:
            self._reasoning_ended = False

        # ── Already past reasoning → everything is content ──
        # Uses our own _reasoning_ended flag instead of scanning
        # previous_token_ids, which may contain think tokens from prior
        # assistant turns in the prompt and cause false positives.
        if self._reasoning_ended:
            cleaned = self._strip_tool_section_markers(
                self._strip_think_tags(delta_text)
            )
            return DeltaMessage(content=cleaned) if cleaned else None

        # ── Check for </think> in this delta ──
        if self._end_token in delta_text:
            self._reasoning_ended = True
            end_idx = delta_text.find(self._end_token)
            reasoning = self._strip_think_tags(delta_text[:end_idx])
            content = self._strip_tool_section_markers(
                delta_text[end_idx + len(self._end_token):]
            )
            kwargs: dict = {}
            if reasoning:
                kwargs["reasoning"] = reasoning
            if content:
                kwargs["content"] = content
            return DeltaMessage(**kwargs) if kwargs else None

        # ── Check for implicit reasoning end via tool section ──
        tool_idx = self._find_tool_section_start(delta_text)
        if tool_idx != -1:
            self._reasoning_ended = True
            reasoning = self._strip_think_tags(delta_text[:tool_idx])
            kwargs = {}
            if reasoning:
                kwargs["reasoning"] = reasoning
            return DeltaMessage(**kwargs) if kwargs else None

        # ── Still in reasoning — strip <think> tag if present ──
        cleaned = self._strip_think_tags(delta_text)
        return DeltaMessage(reasoning=cleaned) if cleaned else None