Instead of always returning False (which broke tool call streaming), use a heuristic: if think-end appears in the token IDs but is followed by more than 3 tokens (chat template wrapping like <|im_end|>, user markers, etc.), it's from a prior turn's prompt and reasoning hasn't started in the current generation. Return False. If think-end is at or near the end, it's from generated tokens and reasoning has ended. Return True.
373 lines
16 KiB
Python
373 lines
16 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
Kimi-K2 Reasoning Parser — MTP-compatible version.
|
|
|
|
Fixes applied over the upstream parser:
|
|
|
|
1. **<think>/</think> tag suppression no longer requires single-token
|
|
deltas.** The original used ``len(delta_token_ids) == 1`` to detect
|
|
and suppress think tags. With MTP speculative decoding, these tokens
|
|
arrive fused with reasoning text, so the guard fails and raw tags
|
|
leak into the reasoning or content output.
|
|
|
|
2. **Text-based detection replaces token-ID-only detection** in
|
|
``extract_reasoning_streaming``. Since ``<think>`` and ``</think>``
|
|
are single tokens, they always appear as complete strings in
|
|
``delta_text`` (the detokenizer never splits a single token across
|
|
deltas). Text-based stripping is therefore safe and MTP-agnostic.
|
|
|
|
3. **Handles ``</think>`` + ``<|tool_calls_section_begin|>`` arriving
|
|
in the same delta** — the reasoning portion is correctly terminated
|
|
and the tool-call content is forwarded so the tool parser can
|
|
detect it on the same or next call.
|
|
|
|
Drop-in replacement: same class name, same interface.
|
|
"""
|
|
|
|
from collections.abc import Iterable, Sequence
|
|
from typing import TYPE_CHECKING
|
|
|
|
from transformers import PreTrainedTokenizerBase
|
|
|
|
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
|
|
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
|
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
|
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
|
|
|
|
|
class KimiK2ReasoningParser(ReasoningParser):
    """
    Reasoning parser for Kimi K2 model — MTP-compatible.

    Uses ``<think>...</think>`` to denote reasoning text. Reasoning
    may also end implicitly when ``<|tool_calls_section_begin|>``
    (or the singular variant ``<|tool_call_section_begin|>``) appears.

    All streaming detection uses text-based matching so the parser is
    agnostic to how many tokens arrive per streaming step (robust
    against MTP and EAGLE speculative decoding, where think tags can
    arrive fused with reasoning text in a single delta).
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        """Initialize the parser and resolve think/tool-section token IDs.

        Raises:
            ValueError: if no tokenizer was supplied.
            RuntimeError: if the tokenizer vocabulary does not contain
                the ``<think>``/``</think>`` tokens.
        """
        super().__init__(tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        # Thinking may be disabled via chat_template_kwargs; if so,
        # delegate every method to the identity parser (pass-through).
        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        thinking = bool(chat_kwargs.get("thinking", True))

        self._identity_parser: IdentityReasoningParser | None
        if not thinking:
            self._identity_parser = IdentityReasoningParser(
                tokenizer, *args, **kwargs
            )
        else:
            self._identity_parser = None

        # Tag text. <think> and </think> are single tokens in the Kimi
        # vocabulary (checked below), so the detokenizer always emits
        # them as complete strings within one delta.
        self._start_token = "<think>"
        self._end_token = "</think>"
        self._tool_section_start_token = "<|tool_calls_section_begin|>"

        # Also support the singular variant for the tool section.
        self._tool_section_start_variants = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",
        ]

        # Token IDs (used by is_reasoning_end for non-streaming, and
        # is_reasoning_end_streaming for delta checks).
        self._start_token_id = self.vocab.get(self._start_token)
        self._end_token_id = self.vocab.get(self._end_token)
        self._tool_section_start_token_id = self.vocab.get(
            self._tool_section_start_token
        )

        # Collect all tool-section start token IDs (for ID-based checks).
        self._tool_section_start_token_ids: set[int] = set()
        for variant in self._tool_section_start_variants:
            tid = self.vocab.get(variant)
            if tid is not None:
                self._tool_section_start_token_ids.add(tid)

        if self._start_token_id is None or self._end_token_id is None:
            raise RuntimeError(
                "KimiK2ReasoningParser could not locate think start/end "
                "tokens in the tokenizer!"
            )

        # Streaming state — tracks reasoning within the CURRENT
        # generation only, avoiding false positives from prior turns'
        # </think> tokens that appear in the prompt token IDs.
        self._reasoning_ended: bool = False

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find_tool_section_start(self, text: str) -> int:
        """Return the index of the earliest tool-section-start marker,
        or -1 if none found."""
        best = -1
        for variant in self._tool_section_start_variants:
            idx = text.find(variant)
            if idx != -1 and (best == -1 or idx < best):
                best = idx
        return best

    def _strip_think_tags(self, text: str) -> str:
        """Remove ``<think>`` and ``</think>`` tag text from *text*."""
        return text.replace(self._start_token, "").replace(self._end_token, "")

    def _strip_tool_section_markers(self, text: str) -> str:
        """Remove all tool-section start markers from *text*.

        The tool parser finds these in ``current_text`` independently;
        forwarding them as content causes double-handling.
        """
        for variant in self._tool_section_start_variants:
            text = text.replace(variant, "")
        return text

    # ------------------------------------------------------------------
    # Full-sequence methods (these scan all IDs — MTP-safe as-is)
    # ------------------------------------------------------------------

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Check if reasoning has ended based on the token ID sequence.

        Scans backward to find the last think-start, think-end, or
        tool-section-start token. Reasoning is considered in progress
        when think-start is the most recent of the three.

        CRITICAL (multi-turn): when called with prompt_token_ids, the
        input contains the full chat history, so a trailing </think>
        may belong to a PRIOR assistant turn. In that case the chat
        template appends wrapping tokens (<|im_end|>, user markers,
        assistant generation markers, ...) after it. Heuristic: if the
        last think-end/tool-section token is followed by more than 3
        tokens, it is prompt wrapping from a prior turn — the current
        generation has not started reasoning yet, so return False. If
        it is at or near the end, it came from generated tokens and
        reasoning has ended — return True.
        """
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end(input_ids)

        # Backward scan for the last occurrence of each marker; stop at
        # the first think-start found — it is the boundary and nothing
        # earlier can matter.
        last_start = -1
        last_end = -1
        last_tool_section = -1

        for i in range(len(input_ids) - 1, -1, -1):
            tok = input_ids[i]
            if tok == self._start_token_id and last_start == -1:
                last_start = i
            if tok == self._end_token_id and last_end == -1:
                last_end = i
            if tok in self._tool_section_start_token_ids and last_tool_section == -1:
                last_tool_section = i
            if last_start != -1:
                break

        # Most recent explicit or implicit reasoning terminator.
        last_boundary = max(last_end, last_tool_section)

        # No think/tool tokens at all — not a reasoning-model output.
        if last_start == -1 and last_boundary == -1:
            return False

        # think-start is the most recent marker — reasoning in progress.
        # NOTE: consistency fix over upstream — a tool-section marker
        # appearing AFTER think-start now counts as an implicit end,
        # matching is_reasoning_end_streaming / extract_content_ids /
        # extract_reasoning, which all treat it as terminating reasoning.
        if last_start > last_boundary:
            return False

        # think-end or tool-section is the most recent marker. If more
        # than a few tokens follow it, those are chat-template wrapping
        # from a prior turn's prompt — reasoning has not started in the
        # current generation yet.
        tokens_after = len(input_ids) - 1 - last_boundary
        if tokens_after > 3:
            return False

        return True

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Check if reasoning ends in this delta (explicit </think> or
        implicit tool-section start)."""
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end_streaming(
                input_ids, delta_ids
            )

        delta_ids_set = set(delta_ids)
        if self._end_token_id in delta_ids_set:
            return True
        return bool(delta_ids_set & self._tool_section_start_token_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Extract content token IDs (everything after reasoning ends).

        Explicit end: content starts AFTER the last </think> token.
        Implicit end: content starts AT the tool-section marker so the
        tool parser can still see it. Returns ``[]`` if reasoning never
        ended.
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_content_ids(input_ids)

        # Explicit </think>: slice after its last occurrence. (The
        # upstream dead `if end_idx != -1` guard is removed — membership
        # is already established, so the reverse index cannot fail.)
        if self._end_token_id in input_ids:
            end_idx = (
                len(input_ids) - 1 - input_ids[::-1].index(self._end_token_id)
            )
            return input_ids[end_idx + 1:]

        # Implicit reasoning end via tool section: keep the marker.
        for tid in self._tool_section_start_token_ids:
            if tid in input_ids:
                tool_idx = len(input_ids) - 1 - input_ids[::-1].index(tid)
                return input_ids[tool_idx:]

        return []

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract (reasoning, content) from complete model output.

        Returns ``(reasoning, None)`` while no terminator has appeared;
        content is ``None`` rather than the empty string.
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning(
                model_output, request
            )

        # Consume a leading <think> tag if present. (Rewritten from the
        # upstream inverted one-liner for clarity; same behavior — the
        # tag is only consumed when it is at position 0.)
        if model_output.startswith(self._start_token):
            start_idx = len(self._start_token)
        else:
            start_idx = 0

        # Explicit </think> terminates reasoning; content follows it.
        end_idx = model_output.find(self._end_token)
        if end_idx != -1:
            reasoning = model_output[start_idx:end_idx]
            content = model_output[end_idx + len(self._end_token):]
            return reasoning, content or None

        # Implicit end via tool section; content keeps the marker so
        # the tool parser can detect it.
        tool_idx = self._find_tool_section_start(model_output)
        if tool_idx != -1:
            reasoning = model_output[start_idx:tool_idx]
            content = model_output[tool_idx:]
            return reasoning, content or None

        # Still reasoning (no content yet).
        return model_output[start_idx:], None

    # ------------------------------------------------------------------
    # Streaming extraction — MTP-compatible
    # ------------------------------------------------------------------

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract reasoning from a streaming delta.

        Uses **text-based** detection to strip ``<think>``/``</think>``
        tags. This is safe because these are single tokens — the
        detokenizer always produces them as complete strings, never
        split across deltas. This makes the method agnostic to how
        many tokens arrive per step (MTP-compatible).
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning_streaming(
                previous_text, current_text, delta_text,
                previous_token_ids, current_token_ids, delta_token_ids,
            )

        # Reset state on new stream — previous_text is empty on the
        # first delta of each generation.
        if not previous_text:
            self._reasoning_ended = False

        # ── Already past reasoning → everything is content ──
        # Uses our own _reasoning_ended flag instead of scanning
        # previous_token_ids, which may contain </think> from prior
        # assistant turns in the prompt and cause false positives.
        if self._reasoning_ended:
            cleaned = self._strip_tool_section_markers(
                self._strip_think_tags(delta_text)
            )
            return DeltaMessage(content=cleaned) if cleaned else None

        # ── Check for </think> in this delta ──
        # Handles </think> + tool-section marker arriving fused in one
        # delta: the reasoning part is terminated and the trailing text
        # is forwarded as content.
        if self._end_token in delta_text:
            self._reasoning_ended = True
            end_idx = delta_text.find(self._end_token)
            reasoning = self._strip_think_tags(delta_text[:end_idx])
            content = self._strip_tool_section_markers(
                delta_text[end_idx + len(self._end_token):]
            )

            kwargs: dict = {}
            if reasoning:
                kwargs["reasoning"] = reasoning
            if content:
                kwargs["content"] = content
            return DeltaMessage(**kwargs) if kwargs else None

        # ── Check for implicit reasoning end via tool section ──
        tool_idx = self._find_tool_section_start(delta_text)
        if tool_idx != -1:
            self._reasoning_ended = True
            reasoning = self._strip_think_tags(delta_text[:tool_idx])
            kwargs = {}
            if reasoning:
                kwargs["reasoning"] = reasoning
            return DeltaMessage(**kwargs) if kwargs else None

        # ── Still in reasoning — strip <think> tag if present ──
        cleaned = self._strip_think_tags(delta_text)
        return DeltaMessage(reasoning=cleaned) if cleaned else None