The streaming path was using is_reasoning_end(previous_token_ids) to check if reasoning had ended. On multi-turn conversations, previous_token_ids includes the entire chat history, including think-end tokens from prior assistant messages. This caused the parser to incorrectly think reasoning was already over before the model generated anything, routing all thinking text to content instead of reasoning. Fix: Replace the token-ID-based check with a text-based state variable (_reasoning_ended) that tracks reasoning end based solely on what the model has generated in the current turn. Reset on each new generation. Also includes the chat template for reference.
314 lines
13 KiB
Python
314 lines
13 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
Kimi-K2 Reasoning Parser — MTP-compatible version.
|
|
|
|
Fixes applied over the upstream parser:
|
|
|
|
1. **<think>/</think> tag suppression no longer requires single-token
|
|
deltas.** The original used ``len(delta_token_ids) == 1`` to detect
|
|
and suppress think tags. With MTP speculative decoding, these tokens
|
|
arrive fused with reasoning text, so the guard fails and raw tags
|
|
leak into the reasoning or content output.
|
|
|
|
2. **Text-based detection replaces token-ID-only detection** in
|
|
``extract_reasoning_streaming``. Since ``<think>`` and ``</think>``
|
|
are single tokens, they always appear as complete strings in
|
|
``delta_text`` (the detokenizer never splits a single token across
|
|
deltas). Text-based stripping is therefore safe and MTP-agnostic.
|
|
|
|
3. **Handles ``</think>`` + ``<|tool_calls_section_begin|>`` arriving
|
|
in the same delta** — the reasoning portion is correctly terminated
|
|
and the tool-call content is forwarded so the tool parser can
|
|
detect it on the same or next call.
|
|
|
|
Drop-in replacement: same class name, same interface.
|
|
"""
|
|
|
|
from collections.abc import Iterable, Sequence
|
|
from typing import TYPE_CHECKING
|
|
|
|
from transformers import PreTrainedTokenizerBase
|
|
|
|
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
|
|
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
|
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
|
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
|
|
|
|
|
class KimiK2ReasoningParser(ReasoningParser):
    """
    Reasoning parser for Kimi K2 model — MTP-compatible.

    Uses ``<think>...</think>`` to denote reasoning text. Reasoning
    may also end implicitly when ``<|tool_calls_section_begin|>`` (or
    its singular ``<|tool_call_section_begin|>`` variant) appears.

    Streaming detection uses text-based matching so the parser is
    agnostic to how many tokens arrive per streaming step (robust
    against MTP and EAGLE speculative decoding).
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        # Thinking can be switched off per request through
        # chat_template_kwargs={"thinking": False}; it defaults to on.
        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        thinking = bool(chat_kwargs.get("thinking", True))

        # When thinking is disabled, every public method below delegates
        # to an IdentityReasoningParser instead of parsing think tags.
        self._identity_parser: IdentityReasoningParser | None
        if not thinking:
            self._identity_parser = IdentityReasoningParser(
                tokenizer, *args, **kwargs
            )
        else:
            self._identity_parser = None

        # Token strings
        self._start_token = "<think>"
        self._end_token = "</think>"
        self._tool_section_start_token = "<|tool_calls_section_begin|>"

        # Both spellings of the tool-section marker end reasoning.
        self._tool_section_start_variants = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",
        ]

        # Token IDs, used by the ID-based methods (is_reasoning_end,
        # is_reasoning_end_streaming, extract_content_ids).
        self._start_token_id = self.vocab.get(self._start_token)
        self._end_token_id = self.vocab.get(self._end_token)
        self._tool_section_start_token_id = self.vocab.get(
            self._tool_section_start_token
        )

        # IDs of whichever tool-section variants exist in this vocab.
        self._tool_section_start_token_ids: set[int] = set()
        for variant in self._tool_section_start_variants:
            tid = self.vocab.get(variant)
            if tid is not None:
                self._tool_section_start_token_ids.add(tid)

        if self._start_token_id is None or self._end_token_id is None:
            raise RuntimeError(
                "KimiK2ReasoningParser could not locate think start/end "
                "tokens in the tokenizer!"
            )

        # Streaming state: has the model's *generated* reasoning ended?
        # It is driven by the generated delta text only, never by token ID
        # lists that may also carry a think-end from a prior turn.
        # NOTE(review): this is per-instance state. If a single parser
        # instance is shared by several interleaved streams (e.g. n > 1
        # choices), they would share this flag — confirm the caller
        # creates one parser per stream.
        self._reasoning_ended: bool = False

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find_tool_section_marker(self, text: str) -> tuple[int, int]:
        """Locate the earliest tool-section-start marker in *text*.

        Returns ``(index, marker_length)``, or ``(-1, 0)`` when no
        variant occurs. The length is needed because the two variants
        are spelled differently.
        """
        best_idx, best_len = -1, 0
        for variant in self._tool_section_start_variants:
            idx = text.find(variant)
            if idx != -1 and (best_idx == -1 or idx < best_idx):
                best_idx, best_len = idx, len(variant)
        return best_idx, best_len

    def _find_tool_section_start(self, text: str) -> int:
        """Return the index of the earliest tool-section-start marker,
        or -1 if none found."""
        return self._find_tool_section_marker(text)[0]

    def _strip_think_tags(self, text: str) -> str:
        """Remove ``<think>`` and ``</think>`` tag text from *text*."""
        return text.replace(self._start_token, "").replace(self._end_token, "")

    def _strip_tool_section_markers(self, text: str) -> str:
        """Remove every tool-section-start variant from *text*."""
        for variant in self._tool_section_start_variants:
            text = text.replace(variant, "")
        return text

    @staticmethod
    def _make_delta(reasoning: str, content: str) -> DeltaMessage | None:
        """Build a DeltaMessage from the non-empty parts, or None if both
        are empty."""
        fields: dict = {}
        if reasoning:
            fields["reasoning"] = reasoning
        if content:
            fields["content"] = content
        return DeltaMessage(**fields) if fields else None

    # ------------------------------------------------------------------
    # Full-sequence methods (these scan all IDs — MTP-safe as-is)
    # ------------------------------------------------------------------

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Check if reasoning has ended by scanning the full token sequence.

        Reasoning ends when we see either ``</think>`` or a tool-section
        start token after the last ``<think>``.

        Note: the scan stops at the most recent think/tool token, so a
        ``</think>`` from an earlier assistant turn makes this return True
        for a multi-turn prompt that does not end in ``<think>``.
        """
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end(input_ids)

        for i in range(len(input_ids) - 1, -1, -1):
            if input_ids[i] == self._start_token_id:
                return False
            if input_ids[i] == self._end_token_id:
                return True
            if input_ids[i] in self._tool_section_start_token_ids:
                return True
        return False

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Check if reasoning ends in this delta (``</think>`` or any
        tool-section-start token among *delta_ids*)."""
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end_streaming(
                input_ids, delta_ids
            )

        delta_ids_set = set(delta_ids)
        if self._end_token_id in delta_ids_set:
            return True
        return bool(delta_ids_set & self._tool_section_start_token_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Extract content token IDs (everything after reasoning ends).

        Returns the IDs after the last ``</think>``. If there is no
        ``</think>``, returns the IDs starting at the last tool-section
        start token (the marker is kept for the tool parser). If neither
        occurs, reasoning is still in progress and ``[]`` is returned.
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_content_ids(input_ids)

        if self._end_token_id in input_ids:
            end_idx = len(input_ids) - 1 - input_ids[::-1].index(self._end_token_id)
            return input_ids[end_idx + 1:]

        # Implicit reasoning end via tool section: one backward scan over
        # all variants, so the result does not depend on set ordering.
        for idx in range(len(input_ids) - 1, -1, -1):
            if input_ids[idx] in self._tool_section_start_token_ids:
                return input_ids[idx:]

        return []

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract (reasoning, content) from complete model output.

        A leading ``<think>`` is consumed if present. The first
        ``</think>`` splits reasoning from content; without it, the
        earliest tool-section marker does (the marker stays in content).
        With neither, the whole output is reasoning and content is None.
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning(
                model_output, request
            )

        # Consume <think> only when it is the very first text.
        start_idx = len(self._start_token) if model_output.startswith(
            self._start_token
        ) else 0

        # Explicit </think>
        end_idx = model_output.find(self._end_token)
        if end_idx != -1:
            reasoning = model_output[start_idx:end_idx]
            content = model_output[end_idx + len(self._end_token):]
            return reasoning, content or None

        # Implicit reasoning end via tool section
        tool_idx = self._find_tool_section_start(model_output)
        if tool_idx != -1:
            reasoning = model_output[start_idx:tool_idx]
            content = model_output[tool_idx:]
            return reasoning, content or None

        # Still reasoning (no content yet)
        return model_output[start_idx:], None

    # ------------------------------------------------------------------
    # Streaming extraction — MTP-compatible
    # ------------------------------------------------------------------

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract reasoning from a streaming delta.

        Uses **text-based** detection to strip ``<think>``/``</think>``
        tags. This relies on these tags being single tokens, so the
        detokenizer produces them as complete strings in ``delta_text``.
        As a result, the method is agnostic to how many tokens arrive
        per step (MTP-compatible).

        When a fused delta contains several end markers, the earliest one
        (``</think>`` or a tool-section marker) ends reasoning. This gives
        the same split as single-token delivery would.
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning_streaming(
                previous_text, current_text, delta_text,
                previous_token_ids, current_token_ids, delta_token_ids,
            )

        # First chunk of a new generation — reset state.
        if not previous_text:
            self._reasoning_ended = False

        # ── Already past reasoning → everything is content ──
        # Reasoning state comes from self._reasoning_ended (set below from
        # generated text), not is_reasoning_end(previous_token_ids), which
        # may see a think-end from a prior assistant turn.
        if self._reasoning_ended:
            # Strip residual think tags and tool-section markers; the
            # markers are not forwarded as content.
            cleaned = self._strip_tool_section_markers(
                self._strip_think_tags(delta_text)
            )
            return DeltaMessage(content=cleaned) if cleaned else None

        end_idx = delta_text.find(self._end_token)
        tool_idx, tool_len = self._find_tool_section_marker(delta_text)

        # ── </think> ends reasoning (it precedes any tool marker) ──
        if end_idx != -1 and (tool_idx == -1 or end_idx < tool_idx):
            self._reasoning_ended = True
            # Before </think> is reasoning (strip <think> if present);
            # after it is content, forwarded as-is.
            reasoning = self._strip_think_tags(delta_text[:end_idx])
            content = delta_text[end_idx + len(self._end_token):]
            return self._make_delta(reasoning, content)

        # ── Implicit reasoning end via tool section ──
        if tool_idx != -1:
            self._reasoning_ended = True
            reasoning = self._strip_think_tags(delta_text[:tool_idx])
            # The marker itself is not forwarded as content. Any text fused
            # after it in this delta (e.g. the first tool call under MTP) IS
            # forwarded; previously it was silently dropped.
            content = delta_text[tool_idx + tool_len:]
            return self._make_delta(reasoning, content)

        # ── Still in reasoning — strip <think> tag if present ──
        cleaned = self._strip_think_tags(delta_text)
        return DeltaMessage(reasoning=cleaned) if cleaned else None