more speculative decoding fixes
This commit is contained in:
286
kimi_k2_reasoning_parser.py
Normal file
286
kimi_k2_reasoning_parser.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Kimi-K2 Reasoning Parser — MTP-compatible version.
|
||||
|
||||
Fixes applied over the upstream parser:
|
||||
|
||||
1. **<think>/</think> tag suppression no longer requires single-token
|
||||
deltas.** The original used ``len(delta_token_ids) == 1`` to detect
|
||||
and suppress think tags. With MTP speculative decoding, these tokens
|
||||
arrive fused with reasoning text, so the guard fails and raw tags
|
||||
leak into the reasoning or content output.
|
||||
|
||||
2. **Text-based detection replaces token-ID-only detection** in
|
||||
``extract_reasoning_streaming``. Since ``<think>`` and ``</think>``
|
||||
are single tokens, they always appear as complete strings in
|
||||
``delta_text`` (the detokenizer never splits a single token across
|
||||
deltas). Text-based stripping is therefore safe and MTP-agnostic.
|
||||
|
||||
3. **Handles ``</think>`` + ``<|tool_calls_section_begin|>`` arriving
|
||||
in the same delta** — the reasoning portion is correctly terminated
|
||||
and the tool-call content is forwarded so the tool parser can
|
||||
detect it on the same or next call.
|
||||
|
||||
Drop-in replacement: same class name, same interface.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
|
||||
|
||||
class KimiK2ReasoningParser(ReasoningParser):
    """
    Reasoning parser for Kimi K2 model — MTP-compatible.

    Uses ``<think>...</think>`` to denote reasoning text. Reasoning
    may also end implicitly when ``<|tool_calls_section_begin|>``
    (or its singular variant ``<|tool_call_section_begin|>``) appears.

    Text extraction uses text-based matching so the parser is agnostic
    to how many tokens arrive per streaming step (robust against MTP
    and EAGLE speculative decoding). The ``is_reasoning_end*`` and
    ``extract_content_ids`` methods operate on token IDs instead.

    When ``chat_template_kwargs["thinking"]`` is falsy, every public
    method delegates to an ``IdentityReasoningParser``.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        """Resolve the think / tool-section tokens against the tokenizer vocab.

        Args:
            tokenizer: Tokenizer forwarded to the base ``ReasoningParser``.
            *args: Forwarded unchanged to the base class (and to the
                identity parser when thinking is disabled).
            **kwargs: Forwarded likewise; ``chat_template_kwargs`` is read
                for an optional ``thinking`` flag (default ``True``).

        Raises:
            ValueError: If ``self.model_tokenizer`` is falsy after the base
                constructor ran.
            RuntimeError: If ``<think>`` or ``</think>`` is missing from
                ``self.vocab``.
        """
        super().__init__(tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        # Check if thinking is disabled via chat_template_kwargs.
        # ``or {}`` guards against an explicit ``chat_template_kwargs=None``.
        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        thinking = bool(chat_kwargs.get("thinking", True))

        # If thinking is not enabled, use identity parser to fall through.
        self._identity_parser: IdentityReasoningParser | None
        if not thinking:
            self._identity_parser = IdentityReasoningParser(
                tokenizer, *args, **kwargs
            )
        else:
            self._identity_parser = None

        # Token definitions
        self._start_token = "<think>"
        self._end_token = "</think>"
        self._tool_section_start_token = "<|tool_calls_section_begin|>"

        # Also support singular variant for tool section
        self._tool_section_start_variants = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",
        ]

        # Get token IDs (used by the ID-based methods below).
        self._start_token_id = self.vocab.get(self._start_token)
        self._end_token_id = self.vocab.get(self._end_token)
        self._tool_section_start_token_id = self.vocab.get(
            self._tool_section_start_token
        )

        # Collect all tool section start token IDs that exist in the vocab.
        self._tool_section_start_token_ids: set[int] = set()
        for variant in self._tool_section_start_variants:
            tid = self.vocab.get(variant)
            if tid is not None:
                self._tool_section_start_token_ids.add(tid)

        if self._start_token_id is None or self._end_token_id is None:
            raise RuntimeError(
                "KimiK2ReasoningParser could not locate think start/end "
                "tokens in the tokenizer!"
            )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find_tool_section_start(self, text: str) -> int:
        """Return the index of the earliest tool-section-start marker in
        *text*, or -1 if no variant occurs."""
        hits = [
            idx
            for idx in (text.find(v) for v in self._tool_section_start_variants)
            if idx != -1
        ]
        return min(hits, default=-1)

    def _strip_think_tags(self, text: str) -> str:
        """Remove every ``<think>`` and ``</think>`` tag from *text*."""
        return text.replace(self._start_token, "").replace(self._end_token, "")

    @staticmethod
    def _last_index_of(input_ids: Sequence[int], targets: set[int]) -> int:
        """Return the index of the last element of *input_ids* that is in
        *targets*, or -1 if there is none."""
        for idx in range(len(input_ids) - 1, -1, -1):
            if input_ids[idx] in targets:
                return idx
        return -1

    @staticmethod
    def _build_delta(reasoning: str, content: str) -> DeltaMessage | None:
        """Build a ``DeltaMessage`` from the non-empty parts.

        Returns ``None`` when both *reasoning* and *content* are empty so
        that no empty delta is emitted.
        """
        fields: dict[str, str] = {}
        if reasoning:
            fields["reasoning"] = reasoning
        if content:
            fields["content"] = content
        return DeltaMessage(**fields) if fields else None

    # ------------------------------------------------------------------
    # Full-sequence methods (these scan all IDs — MTP-safe as-is)
    # ------------------------------------------------------------------

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Check if reasoning has ended by scanning the full token sequence.

        The sequence is walked from the end: reasoning has ended when
        ``</think>`` or a tool-section start token is met before any
        ``<think>``. Returns ``False`` when ``<think>`` is met first or
        when none of these tokens occurs.
        """
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end(input_ids)

        for token_id in reversed(input_ids):
            if token_id == self._start_token_id:
                return False
            if (
                token_id == self._end_token_id
                or token_id in self._tool_section_start_token_ids
            ):
                return True
        return False

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Check if reasoning ends in this delta.

        Returns ``True`` when *delta_ids* contains ``</think>`` or any
        tool-section start token. *input_ids* is only used when delegating
        to the identity parser.
        """
        if self._identity_parser is not None:
            return self._identity_parser.is_reasoning_end_streaming(
                input_ids, delta_ids
            )

        # Materialise once: delta_ids may be a one-shot iterable.
        delta_ids_set = set(delta_ids)
        if self._end_token_id in delta_ids_set:
            return True
        return bool(delta_ids_set & self._tool_section_start_token_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Extract content token IDs (everything after reasoning ends).

        An explicit ``</think>`` takes precedence: the IDs after its last
        occurrence are returned. Otherwise, the IDs from the last
        tool-section start token (inclusive, whichever variant) onward are
        returned. Returns ``[]`` when neither is present.
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_content_ids(input_ids)

        end_idx = self._last_index_of(input_ids, {self._end_token_id})
        if end_idx != -1:
            return input_ids[end_idx + 1:]

        # Implicit reasoning end via tool section. A single backward scan
        # over all variants keeps the result independent of set ordering.
        tool_idx = self._last_index_of(
            input_ids, self._tool_section_start_token_ids
        )
        if tool_idx != -1:
            return input_ids[tool_idx:]

        return []

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract (reasoning, content) from complete model output.

        Reasoning runs up to the first ``</think>``, or, failing that, up
        to the earliest tool-section marker (which is kept at the start of
        the content). Without either, the whole output is reasoning and
        content is ``None``. ``<think>`` tags are removed from the
        reasoning wherever they occur, matching the streaming path.
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning(
                model_output, request
            )

        # Look for explicit </think>
        end_idx = model_output.find(self._end_token)
        if end_idx != -1:
            reasoning = self._strip_think_tags(model_output[:end_idx])
            content = model_output[end_idx + len(self._end_token):]
            return reasoning, content or None

        # Look for implicit reasoning end via tool section
        tool_idx = self._find_tool_section_start(model_output)
        if tool_idx != -1:
            reasoning = self._strip_think_tags(model_output[:tool_idx])
            content = model_output[tool_idx:]
            return reasoning, content or None

        # Still reasoning (no content yet)
        return self._strip_think_tags(model_output), None

    # ------------------------------------------------------------------
    # Streaming extraction — MTP-compatible
    # ------------------------------------------------------------------

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract reasoning from a streaming delta.

        Uses **text-based** detection on *delta_text* to strip
        ``<think>``/``</think>`` tags, so the result does not depend on
        how many tokens arrive per step (MTP-compatible). This relies on
        the tags arriving as complete strings within one delta.

        Returns a ``DeltaMessage`` carrying ``reasoning`` and/or
        ``content``, or ``None`` when the delta contains nothing to emit
        after tag stripping.
        """
        if self._identity_parser is not None:
            return self._identity_parser.extract_reasoning_streaming(
                previous_text, current_text, delta_text,
                previous_token_ids, current_token_ids, delta_token_ids,
            )

        # ── Already past reasoning → everything is content ──
        if self.is_reasoning_end(previous_token_ids):
            # Strip any residual think tags that might appear in content
            return self._build_delta("", self._strip_think_tags(delta_text))

        # ── Check for </think> in this delta ──
        end_idx = delta_text.find(self._end_token)
        if end_idx != -1:
            # Everything before </think> is reasoning (strip <think> if
            # present); everything after it is content.
            return self._build_delta(
                self._strip_think_tags(delta_text[:end_idx]),
                delta_text[end_idx + len(self._end_token):],
            )

        # ── Check for implicit reasoning end via tool section ──
        tool_idx = self._find_tool_section_start(delta_text)
        if tool_idx != -1:
            # Forward the tool section marker as content so the tool
            # parser can detect it.
            return self._build_delta(
                self._strip_think_tags(delta_text[:tool_idx]),
                delta_text[tool_idx:],
            )

        # ── Still in reasoning — strip <think> tag if present ──
        return self._build_delta(self._strip_think_tags(delta_text), "")
|
||||
Reference in New Issue
Block a user