[Frontend] Support reasoning content for deepseek r1 (#12473)

Signed-off-by: Ce Gao <cegao@tensorchord.ai>
Co-authored-by: Rafael Vasquez <rafvasq21@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
Ce Gao
2025-01-29 11:38:08 +08:00
committed by GitHub
parent fbb5bd4cef
commit a7e3eba66f
16 changed files with 977 additions and 5 deletions

View File

@@ -61,6 +61,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeRequest,
TokenizeResponse,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -771,6 +772,8 @@ async def init_app_state(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
enable_reasoning=args.enable_reasoning,
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.runner_type == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
@@ -844,6 +847,13 @@ async def run_server(args, **uvicorn_kwargs) -> None:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valid_tool_parses)} }})")
valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
if args.enable_reasoning \
and args.reasoning_parser not in valid_reasoning_parses:
raise KeyError(
f"invalid reasoning parser: {args.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204

View File

@@ -12,6 +12,7 @@ from typing import List, Optional, Sequence, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -208,6 +209,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=False,
help="Enable auto tool choice for supported models. Use "
"``--tool-call-parser`` to specify which parser to use.")
parser.add_argument(
"--enable-reasoning",
action="store_true",
default=False,
help="Whether to enable reasoning_content for the model. "
"If enabled, the model will be able to generate reasoning content.")
valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
parser.add_argument(
"--reasoning-parser",
type=str,
metavar="{" + ",".join(valid_reasoning_parsers) + "}",
default=None,
help=
"Select the reasoning parser depending on the model that you're using."
" This is used to parse the reasoning content into OpenAI API "
"format. Required for ``--enable-reasoning``.")
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(
@@ -267,6 +285,18 @@ def validate_parsed_serve_args(args: argparse.Namespace):
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
# Enable reasoning needs a reasoning parser to be valid
if args.enable_reasoning and not args.reasoning_parser:
raise TypeError("Error: --enable-reasoning requires "
"--reasoning-parser")
# Ref https://api-docs.deepseek.com/guides/reasoning_model
# tool call and reasoning cannot be enabled at the same time.
if args.enable_auto_tool_choice and args.enable_reasoning:
raise TypeError(
"Error: --enable-auto-tool-choice and "
"--enable-reasoning cannot be enabled at the same time")
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(

View File

@@ -1202,6 +1202,7 @@ class ExtractedToolCallInformation(BaseModel):
class ChatMessage(OpenAIBaseModel):
role: str
reasoning_content: Optional[str] = None
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
@@ -1243,6 +1244,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: List[DeltaToolCall] = Field(default_factory=list)

View File

@@ -0,0 +1,6 @@
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
__all__ = [
"ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser"
]

View File

@@ -0,0 +1,158 @@
import os
from functools import cached_property
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import import_from_path, is_list_of
logger = init_logger(__name__)
class ReasoningParser:
    """
    Abstract reasoning parser class that should not be used directly.

    Derived classes reuse the provided properties and implement the two
    extraction methods, which are used to extract reasoning content from
    the model output.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> Dict[str, int]:
        """Token-string to token-id mapping, computed once per parser."""
        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
        # whereas all tokenizers have .get_vocab()
        return self.model_tokenizer.get_vocab()

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses where we have the entire model
        response available before sending to the client.

        Parameters:
        model_output: str
            The model-generated string to extract reasoning content from.

        request: ChatCompletionRequest
            The request object that was used to generate the model_output.

        Returns:
        Tuple[Optional[str], Optional[str]]
            A tuple containing the reasoning content and the content.

        Raises:
        NotImplementedError
            Always, in this base class; derived classes must override.
        """
        # The message names this very method so that a missing override is
        # easy to locate from the traceback text alone.
        raise NotImplementedError(
            "AbstractReasoningParser.extract_reasoning_content "
            "has not been implemented!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from an incomplete (streamed) response.

        Has to be an instance method because it requires state - the current
        tokens/diffs, but also the information about what has previously
        been parsed and extracted (see constructor).

        Raises:
        NotImplementedError
            Always, in this base class; derived classes must override.
        """
        raise NotImplementedError(
            "AbstractReasoningParser.extract_reasoning_content_streaming "
            "has not been implemented!")
class ReasoningParserManager:
    """Registry mapping reasoning parser names to parser classes."""

    # Shared, class-level registry: name -> registered parser class.
    reasoning_parsers: Dict[str, Type] = {}

    @classmethod
    def get_reasoning_parser(cls, name) -> Type:
        """
        Get reasoning parser by name which is registered by `register_module`.

        Raise a KeyError exception if the name is not registered.
        """
        try:
            return cls.reasoning_parsers[name]
        except KeyError:
            raise KeyError(f"reasoning helper: '{name}' not found in "
                           "reasoning_parsers") from None

    @classmethod
    def _register_module(cls,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = True) -> None:
        """
        Store `module` in the registry under one or several names.

        When `module_name` is None the class name is used. With
        `force=False` an already registered name raises KeyError; otherwise
        the previous entry is overwritten.
        """
        # Check for "is a class" first: issubclass() itself raises an opaque
        # TypeError when handed a non-class object.
        if not (isinstance(module, type)
                and issubclass(module, ReasoningParser)):
            raise TypeError("module must be subclass of ReasoningParser, "
                            f"but got {module!r}")
        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in cls.reasoning_parsers:
                existed_module = cls.reasoning_parsers[name]
                raise KeyError(f"{name} is already registered "
                               f"at {existed_module.__module__}")
            cls.reasoning_parsers[name] = module

    @classmethod
    def register_module(
            cls,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = True,
            module: Union[Type, None] = None) -> Union[type, Callable]:
        """
        Register module with the given name or name list. It can be used as
        a decorator (with module as None) or as a normal function (with
        module as not None).

        Returns the registered module when `module` is given, otherwise a
        decorator that registers the class it is applied to and returns it.
        Raises TypeError when `force` or `name` has an unsupported type.
        """
        if not isinstance(force, bool):
            raise TypeError(f"force must be a boolean, but got {type(force)}")

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)
                or is_list_of(name, str)):
            raise TypeError(
                "name must be None, an instance of str, or a sequence of str, "
                f"but got {type(name)}")

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            cls._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            cls._register_module(module=module, module_name=name, force=force)
            return module

        return _register

    @classmethod
    def import_reasoning_parser(cls, plugin_path: str) -> None:
        """
        Import a user-defined reasoning parser by the path
        of the reasoning parser define file.

        Import failures are logged with their traceback and otherwise
        ignored (best effort), so a broken plugin does not raise here.
        """
        # The module name is the plugin file name without its extension.
        module_name = os.path.splitext(os.path.basename(plugin_path))[0]

        try:
            import_from_path(module_name, plugin_path)
        except Exception:
            logger.exception("Failed to load module '%s' from %s.",
                             module_name, plugin_path)

View File

@@ -0,0 +1,133 @@
import re
from typing import Optional, Sequence, Tuple, Union
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger
logger = init_logger(__name__)
@ReasoningParserManager.register_module("deepseek_r1")
class DeepSeekR1ReasoningParser(ReasoningParser):
    """
    Reasoning parser for DeepSeek R1 model.

    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
    text. This parser extracts the reasoning content from the model output.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """
        Resolve the think start/end markers against the tokenizer.

        Raises ValueError when no tokenizer is given and RuntimeError when
        either marker is missing from the tokenizer vocabulary.
        """
        super().__init__(tokenizer)
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"

        # Non-greedy + DOTALL: capture everything (newlines included) between
        # a start marker and the nearest following end marker.
        self.reasoning_regex = re.compile(
            rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if (self.think_start_token_id is None
                or self.think_end_token_id is None):
            raise RuntimeError(
                "DeepSeek R1 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns None when the delta is exactly one think start/end token,
        so that the markers themselves are never emitted.
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
                self.think_start_token_id, self.think_end_token_id
        ]):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta:
                # the reasoning ends inside this delta, so split it into
                # the reasoning tail and the beginning of the content.
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            elif self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous:
                # reasoning is already finished, the delta is normal content.
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[start_index +
                                               len(self.think_start_token
                                                   ):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # No <think> in previous or delta: the delta is normal content.
            return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Split a complete model output into (reasoning_content, content).

        Returns (None, model_output) when the output holds no
        <think>...</think> pair. Otherwise the first pair is cut out of the
        output; the text between the markers becomes the reasoning content
        and whatever remains becomes the content (None when nothing
        remains). `request` is not consulted by this implementation.
        """
        # A single search covers every "no reasoning" case: a missing
        # marker, and also an end marker that only occurs before the start
        # marker (which previously raised IndexError).
        match = self.reasoning_regex.search(model_output)
        if match is None:
            return None, model_output

        reasoning_content = match.group(1)
        # Although deepseek's <think> token is always at the beginning of
        # the line, we cannot guarantee that other models will follow this
        # convention, so any text preceding the start marker is preserved.
        remaining = model_output[:match.start()] + model_output[match.end():]
        return reasoning_content, remaining or None

View File

@@ -21,6 +21,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
@@ -47,6 +49,8 @@ class OpenAIServingChat(OpenAIServing):
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
return_tokens_as_token_ids: bool = False,
enable_reasoning: bool = False,
reasoning_parser: Optional[str] = None,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False,
@@ -69,6 +73,18 @@ class OpenAIServingChat(OpenAIServing):
" the parallel_tool_calls client option is preset for "
"compatibility reasons, it will be ignored.")
self.enable_reasoning: bool = enable_reasoning
self.reasoning_parser: Optional[Callable[[AnyTokenizer],
ReasoningParser]] = None
if self.enable_reasoning:
try:
self.reasoning_parser = (
ReasoningParserManager.get_reasoning_parser(
reasoning_parser))
except Exception as e:
raise TypeError("Error: --enable-reasoning requires "
f"reasoning_parser:'{reasoning_parser}' "
"which has not been registered") from e
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools:
try:
@@ -285,14 +301,35 @@ class OpenAIServingChat(OpenAIServing):
not tool_choice_function_name
and self._should_stream_with_auto_tool_parsing(request))
should_stream_with_reasoning_parsing = (
self._should_stream_with_reasoning_parsing(request))
all_previous_token_ids: Optional[List[List[int]]]
if tool_choice_auto:
# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
if tool_choice_auto or should_stream_with_reasoning_parsing:
# These are required for "auto" tool choice and for reasoning parsing
previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices
else:
previous_texts, all_previous_token_ids = None, None
try:
# There is no need to check if the reasoning_parser is None
# because the should_stream_with_reasoning_parsing check
# already ensures that the reasoning_parser is not None.
# but the pre-commit hook requires it.
if should_stream_with_reasoning_parsing and \
self.reasoning_parser is not None:
reasoning_parser = self.reasoning_parser(tokenizer)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
return
# Prepare the tool parser if it's needed
try:
if tool_choice_auto and self.tool_parser:
@@ -456,6 +493,32 @@ class OpenAIServingChat(OpenAIServing):
# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# reasoning_content cannot be enabled with tool_choice.
# If it is, the tool_choice will be used instead.
elif self.enable_reasoning:
# handle reasoning_content delta
assert reasoning_parser is not None
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)
delta_message = (reasoning_parser.
extract_reasoning_content_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output.token_ids,
))
# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# handle streaming just a content delta
else:
@@ -642,17 +705,38 @@ class OpenAIServingChat(OpenAIServing):
else:
logprobs = None
should_stream_with_reasoning_parsing = (
self._should_stream_with_reasoning_parsing(request))
# In the OpenAI API the finish_reason is "tool_calls"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called = False
if should_stream_with_reasoning_parsing and \
self.reasoning_parser is not None:
try:
reasoning_parser = self.reasoning_parser(tokenizer)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
reasoning_content, content = (
reasoning_parser.extract_reasoning_content(
output.text, request=request))
if reasoning_content:
message = ChatMessage(role=role,
content=content,
reasoning_content=reasoning_content)
else:
message = ChatMessage(role=role, content=output.text)
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
if (not self.enable_auto_tools
or not self.tool_parser) and not isinstance(
request.tool_choice,
ChatCompletionNamedToolChoiceParam):
elif (not self.enable_auto_tools
or not self.tool_parser) and not isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam):
message = ChatMessage(role=role, content=output.text)
# if the request uses tools and specified a tool choice
@@ -835,6 +919,17 @@ class OpenAIServingChat(OpenAIServing):
return (request.tools and self.tool_parser and self.enable_auto_tools
and request.tool_choice in ['auto', None])
def _should_stream_with_reasoning_parsing(self,
                                          request: ChatCompletionRequest):
    """
    Tell whether generated output has to be routed through the configured
    reasoning parser.

    That is the case only when reasoning is enabled AND a reasoning parser
    is configured on this instance. The `request` argument is accepted but
    not consulted.
    """
    # Guard clause: with reasoning disabled the flag itself is the answer.
    if not self.enable_reasoning:
        return self.enable_reasoning
    return self.reasoning_parser is not None
def _should_check_for_unstreamed_tool_arg_tokens(
self,
delta_message: Optional[DeltaMessage],