Files
vllm-kimi25-eagle/kimi_k2_reasoning_parser.py
biondizzle f5266646eb Make is_reasoning_end() always return False
The vLLM serving layer calls is_reasoning_end() with prompt_token_ids
to pre-compute whether reasoning has ended before streaming starts. On
multi-turn conversations, prompt_token_ids contains think-end tokens
from prior assistant messages in the chat history. This causes a false
positive — the serving layer sets reasoning_end_arr[i] = True, skips
extract_reasoning_streaming entirely, and routes all thinking text to
content.

By returning False, the serving layer always calls
extract_reasoning_streaming, which correctly tracks reasoning state
via _reasoning_ended based only on the model's generated text.
2026-04-14 08:21:14 +00:00

313 lines
13 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Kimi-K2 Reasoning Parser — MTP-compatible version.
Fixes applied over the upstream parser:
1. **<think>/</think> tag suppression no longer requires single-token
deltas.** The original used ``len(delta_token_ids) == 1`` to detect
and suppress think tags. With MTP speculative decoding, these tokens
arrive fused with reasoning text, so the guard fails and raw tags
leak into the reasoning or content output.
2. **Text-based detection replaces token-ID-only detection** in
``extract_reasoning_streaming``. Since ``<think>`` and ``</think>``
are single tokens, they always appear as complete strings in
``delta_text`` (the detokenizer never splits a single token across
deltas). Text-based stripping is therefore safe and MTP-agnostic.
3. **Handles ``</think>`` + ``<|tool_calls_section_begin|>`` arriving
in the same delta** — the reasoning portion is correctly terminated
and the tool-call content is forwarded so the tool parser can
detect it on the same or next call.
Drop-in replacement: same class name, same interface.
"""
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
class KimiK2ReasoningParser(ReasoningParser):
    """
    Reasoning parser for Kimi K2 model — MTP-compatible.

    Uses ``<think>...</think>`` to denote reasoning text. Reasoning
    may also end implicitly when ``<|tool_calls_section_begin|>``
    appears.

    All detection uses text-based matching so the parser is agnostic
    to how many tokens arrive per streaming step (robust against MTP
    and EAGLE speculative decoding).
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        """Initialize the parser and resolve think/tool token IDs.

        Raises:
            ValueError: if no tokenizer was supplied.
            RuntimeError: if the think start/end tokens are missing
                from the tokenizer vocabulary.
        """
        super().__init__(tokenizer, *args, **kwargs)
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        # Check if thinking is disabled via chat_template_kwargs.
        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        thinking = bool(chat_kwargs.get("thinking", True))

        # If thinking is not enabled, delegate everything to the
        # identity parser (pass-through behavior).
        self._identity_parser: IdentityReasoningParser | None = None
        if not thinking:
            self._identity_parser = IdentityReasoningParser(
                tokenizer, *args, **kwargs
            )

        # Token definitions.
        self._start_token = "<think>"
        self._end_token = "</think>"
        self._tool_section_start_token = "<|tool_calls_section_begin|>"
        # Also support singular variant for tool section.
        self._tool_section_start_variants = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",
        ]

        # Get token IDs (used by is_reasoning_end for non-streaming,
        # and is_reasoning_end_streaming for delta checks).
        self._start_token_id = self.vocab.get(self._start_token)
        self._end_token_id = self.vocab.get(self._end_token)
        self._tool_section_start_token_id = self.vocab.get(
            self._tool_section_start_token
        )

        # Collect all tool section start token IDs (for ID-based checks).
        # A variant may be absent from the vocab; only real IDs are kept.
        self._tool_section_start_token_ids: set[int] = {
            tid
            for tid in (
                self.vocab.get(v) for v in self._tool_section_start_variants
            )
            if tid is not None
        }

        if self._start_token_id is None or self._end_token_id is None:
            raise RuntimeError(
                "KimiK2ReasoningParser could not locate think start/end "
                "tokens in the tokenizer!"
            )

        # Streaming state — tracks reasoning within the CURRENT
        # generation only, avoiding false positives from prior turns'
        # </think> tokens that appear in the prompt token IDs.
        self._reasoning_ended: bool = False

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _find_tool_section_start(self, text: str) -> int:
        """Return the index of the earliest tool-section-start marker,
        or -1 if none found."""
        best = -1
        for variant in self._tool_section_start_variants:
            idx = text.find(variant)
            if idx != -1 and (best == -1 or idx < best):
                best = idx
        return best

    def _strip_think_tags(self, text: str) -> str:
        """Remove ``<think>`` and ``</think>`` tag text from *text*."""
        return text.replace(self._start_token, "").replace(self._end_token, "")

    def _strip_tool_section_markers(self, text: str) -> str:
        """Remove all tool-section start markers from *text*.

        The tool parser finds these in ``current_text`` independently;
        forwarding them as content causes double-handling.
        """
        for variant in self._tool_section_start_variants:
            text = text.replace(variant, "")
        return text

    # ------------------------------------------------------------------
    # Full-sequence methods (these scan all IDs — MTP-safe as-is)
    # ------------------------------------------------------------------
    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Check if reasoning has ended.

        IMPORTANT: Always returns False for this parser. The reasoning
        state is tracked internally by ``_reasoning_ended`` which is
        updated only when ``extract_reasoning_streaming`` detects
        think-end or a tool-section marker in the model's *generated*
        text.

        The vLLM serving layer calls this with ``prompt_token_ids`` to
        pre-compute whether reasoning has ended. On multi-turn
        conversations, the prompt contains think-end tokens from prior
        assistant messages, which would cause a false positive — the
        serving layer would skip ``extract_reasoning_streaming`` entirely
        and route all thinking text to content.

        Returning False ensures the serving layer always calls
        ``extract_reasoning_streaming``, which correctly handles the
        transition using generated text only.
        """
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end(input_ids)
        return False

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Check if reasoning ends in this delta (by token ID)."""
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end_streaming(
                input_ids, delta_ids
            )
        delta_ids_set = set(delta_ids)
        if self._end_token_id in delta_ids_set:
            return True
        # Implicit end: any tool-section start marker in the delta.
        return bool(delta_ids_set & self._tool_section_start_token_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Extract content token IDs (everything after reasoning ends).

        Content starts after the LAST ``</think>`` token; if none is
        present, content starts AT the last tool-section start marker
        (the tool parser needs to see the marker itself). Returns an
        empty list when reasoning has not ended.
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_content_ids(input_ids)

        if self._end_token_id in input_ids:
            # rindex via reversed search — membership was checked above,
            # so this always succeeds.
            end_idx = (
                len(input_ids) - 1 - input_ids[::-1].index(self._end_token_id)
            )
            return input_ids[end_idx + 1:]

        # Check for implicit reasoning end via tool section. Consider
        # every variant and take the latest occurrence so the result
        # does not depend on set-iteration order of the variants.
        tool_idx = -1
        for tid in self._tool_section_start_token_ids:
            if tid in input_ids:
                idx = len(input_ids) - 1 - input_ids[::-1].index(tid)
                tool_idx = max(tool_idx, idx)
        if tool_idx != -1:
            return input_ids[tool_idx:]
        return []

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------
    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract (reasoning, content) from complete model output."""
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning(
                model_output, request
            )

        # Consume a leading <think> tag if present. A tag appearing
        # later in the output is NOT consumed here.
        if model_output.startswith(self._start_token):
            start_idx = len(self._start_token)
        else:
            start_idx = 0

        # Look for explicit </think>.
        end_idx = model_output.find(self._end_token)
        if end_idx != -1:
            reasoning = model_output[start_idx:end_idx]
            content = model_output[end_idx + len(self._end_token):]
            return reasoning, content or None

        # Look for implicit reasoning end via tool section. Content
        # starts AT the marker so the tool parser can detect it.
        tool_idx = self._find_tool_section_start(model_output)
        if tool_idx != -1:
            reasoning = model_output[start_idx:tool_idx]
            content = model_output[tool_idx:]
            return reasoning, content or None

        # Still reasoning (no content yet).
        return model_output[start_idx:], None

    # ------------------------------------------------------------------
    # Streaming extraction — MTP-compatible
    # ------------------------------------------------------------------
    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract reasoning from a streaming delta.

        Uses **text-based** detection to strip ``<think>``/``</think>``
        tags. This is safe because these are single tokens — the
        detokenizer always produces them as complete strings, never
        split across deltas. This makes the method agnostic to how
        many tokens arrive per step (MTP-compatible).
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning_streaming(
                previous_text, current_text, delta_text,
                previous_token_ids, current_token_ids, delta_token_ids,
            )

        # Reset state on new stream — previous_text is empty on the
        # first delta of each generation.
        if not previous_text:
            self._reasoning_ended = False

        # ── Already past reasoning → everything is content ──
        # Uses our own _reasoning_ended flag instead of scanning
        # previous_token_ids, which may contain </think> from prior
        # assistant turns in the prompt and cause false positives.
        if self._reasoning_ended:
            cleaned = self._strip_tool_section_markers(
                self._strip_think_tags(delta_text)
            )
            return DeltaMessage(content=cleaned) if cleaned else None

        # ── Check for </think> in this delta ──
        if self._end_token in delta_text:
            self._reasoning_ended = True
            end_idx = delta_text.find(self._end_token)
            reasoning = self._strip_think_tags(delta_text[:end_idx])
            content = self._strip_tool_section_markers(
                delta_text[end_idx + len(self._end_token):]
            )
            kwargs: dict = {}
            if reasoning:
                kwargs["reasoning"] = reasoning
            if content:
                kwargs["content"] = content
            return DeltaMessage(**kwargs) if kwargs else None

        # ── Check for implicit reasoning end via tool section ──
        # The tool-section text itself is deliberately NOT forwarded:
        # the tool parser finds it in current_text independently.
        tool_idx = self._find_tool_section_start(delta_text)
        if tool_idx != -1:
            self._reasoning_ended = True
            reasoning = self._strip_think_tags(delta_text[:tool_idx])
            kwargs = {}
            if reasoning:
                kwargs["reasoning"] = reasoning
            return DeltaMessage(**kwargs) if kwargs else None

        # ── Still in reasoning — strip <think> tag if present ──
        cleaned = self._strip_think_tags(delta_text)
        return DeltaMessage(reasoning=cleaned) if cleaned else None