[Misc] Provide a DeepSeek ReasoningParser with thinking enabled by default (#33221)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey
2026-01-28 21:16:53 +08:00
committed by GitHub
parent 2e8de86777
commit 8e5e40daf4
6 changed files with 25 additions and 192 deletions

View File

@@ -6,7 +6,9 @@ from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.reasoning.holo2_reasoning_parser import Holo2ReasoningParser
from vllm.reasoning.deepseek_v3_reasoning_parser import (
DeepSeekV3ReasoningWithThinkingParser as Holo2ReasoningParser,
)
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
REASONING_MODEL_NAME = "HCompany/Holo2-4B"

View File

@@ -33,8 +33,8 @@ _REASONING_PARSERS_TO_REGISTER = {
"Ernie45ReasoningParser",
),
"glm45": (
"glm4_moe_reasoning_parser",
"Glm4MoeModelReasoningParser",
"deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningWithThinkingParser",
),
"openai_gptoss": (
"gptoss_reasoning_parser",
@@ -45,16 +45,16 @@ _REASONING_PARSERS_TO_REGISTER = {
"GraniteReasoningParser",
),
"holo2": (
"holo2_reasoning_parser",
"Holo2ReasoningParser",
"deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningWithThinkingParser",
),
"hunyuan_a13b": (
"hunyuan_a13b_reasoning_parser",
"HunyuanA13BReasoningParser",
),
"kimi_k2": (
"kimi_k2_reasoning_parser",
"KimiK2ReasoningParser",
"deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningWithThinkingParser",
),
"minimax_m2": (
"minimax_m2_reasoning_parser",

View File

@@ -70,3 +70,19 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
current_token_ids,
delta_token_ids,
)
class DeepSeekV3ReasoningWithThinkingParser(DeepSeekV3ReasoningParser):
"""
DeepSeekV3ReasoningParser that defaults to thinking mode.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
thinking = chat_kwargs.get("thinking", None)
enable_thinking = chat_kwargs.get("enable_thinking", None)
if thinking is None and enable_thinking is None:
chat_kwargs["thinking"] = True
chat_kwargs["enable_thinking"] = True
kwargs["chat_template_kwargs"] = chat_kwargs
super().__init__(tokenizer, *args, **kwargs)

View File

@@ -1,13 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.reasoning.holo2_reasoning_parser import Holo2ReasoningParser
class Glm4MoeModelReasoningParser(Holo2ReasoningParser):
"""
Reasoning parser for the Glm4MoeModel model,which inherits from
`Holo2ReasoningParser`.
"""
pass

View File

@@ -1,92 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import (
ReasoningParser,
)
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Holo2ReasoningParser(ReasoningParser):
"""
Reasoning parser for the Holo2 models which are based on Qwen3.
The Holo2 model uses <think>...</think> tokens to denote reasoning text but <think>
is part of the chat template. This parser extracts the reasoning content until
</think> in the model's output.
The model provides a switch to enable or disable reasoning
output via the 'thinking=False' parameter.
Chat template args:
- thinking: Whether to enable reasoning output (default: True)
Parsing rules on model output:
- thinking == False
-> Model output is treated as purely the content |content|
- thinking == True
-> Model output is |reasoning_content|</think>|content|
"""
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
# Deepseek V3 and Holo2 are similar. However, Holo2 models think by default.
# this parser without user specified chat template args is initiated once for
# all requests in the structured output manager. So it is important that without
# user specified chat template args, the default thinking is True.
thinking = bool(chat_kwargs.get("thinking", True))
enable_thinking = bool(chat_kwargs.get("enable_thinking", True))
thinking = thinking and enable_thinking
if thinking:
self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
else:
self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs)
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
return self._parser.is_reasoning_end(input_ids)
def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool:
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
return self._parser.extract_content_ids(input_ids)
def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[str | None, str | None]:
return self._parser.extract_reasoning(model_output, request)
def extract_reasoning_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
return self._parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
)

View File

@@ -1,80 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .identity_reasoning_parser import IdentityReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
else:
ChatCompletionRequest = Any
logger = init_logger(__name__)
class KimiK2ReasoningParser(ReasoningParser):
"""
Kimi K2 parser that delegates to either DeepSeekR1ReasoningParser or
IdentityReasoningParser based on `thinking` and `separate_reasoning`.
Unlike DeepSeekV3ReasoningParser which defaults to NOT thinking,
KimiK2ReasoningParser defaults to thinking mode (uses DeepSeekR1ReasoningParser).
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {}
# Key difference: default to True instead of False
thinking = bool(chat_kwargs.pop("thinking", True))
if thinking:
self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
else:
self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs)
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
return self._parser.is_reasoning_end(input_ids)
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
) -> bool:
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
return self._parser.extract_content_ids(input_ids)
def extract_reasoning(
self, model_output: str, request: "ChatCompletionRequest"
) -> tuple[str | None, str | None]:
return self._parser.extract_reasoning(model_output, request)
def extract_reasoning_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
return self._parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
)