[CI] Fix mypy for vllm/reasoning (#35742)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
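Most of the hunks in this commit apply one pattern: imports of the OpenAI protocol types (`ChatCompletionRequest`, `DeltaMessage`, `ResponsesRequest`) and of `TokenizerLike` move under `if TYPE_CHECKING`, the `else: Name = Any` fallbacks are deleted, and annotations that mention those names are quoted so they are never evaluated at runtime. A minimal sketch of the pattern, with illustrative module and class names rather than the real vLLM ones:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only the type checker evaluates this import; at runtime the
    # module is never loaded, avoiding import cycles and startup cost.
    from myapp.protocol import Request  # illustrative path


class Parser:
    # The quoted annotation is a forward reference, so `Request`
    # does not need to exist at runtime.
    def extract(self, output: str, request: "Request") -> str:
        return output
```

The remaining hunks narrow `Optional` attribute types, annotate empty containers, and widen `list[int]` parameters to `Sequence[int]`.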
@@ -23,7 +23,7 @@ class TestGptOssStructuralTagsIntegration:
         """Create a mock tokenizer."""
         tokenizer = Mock()
         tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
-        tokenizer.vocab = {"<|end|>": 6}
+        tokenizer.get_vocab = Mock(return_value={"<|end|>": 6})
         return tokenizer

     @pytest.fixture
@@ -25,7 +25,7 @@ class TestGptOssReasoningParser:
         """Create a mock tokenizer for testing."""
         tokenizer = Mock()
         tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
-        tokenizer.vocab = {"<|end|>": 6}
+        tokenizer.get_vocab = Mock(return_value={"<|end|>": 6})
         return tokenizer

     @pytest.fixture
@@ -41,7 +41,6 @@ EXCLUDE = [
     # TODO: Remove these entries after fixing mypy errors.
     "vllm/benchmarks",
     "vllm/config",
-    "vllm/reasoning",
 ]
@@ -6,7 +6,7 @@ import os
 from abc import abstractmethod
 from collections.abc import Callable, Iterable, Sequence
 from functools import cached_property
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

 from vllm.entrypoints.mcp.tool_server import ToolServer
 from vllm.logger import init_logger
@@ -14,21 +14,10 @@ from vllm.utils.collection_utils import is_list_of
 from vllm.utils.import_utils import import_from_path

 if TYPE_CHECKING:
-    from vllm.entrypoints.openai.chat_completion.protocol import (
-        ChatCompletionRequest,
-    )
-    from vllm.entrypoints.openai.engine.protocol import (
-        DeltaMessage,
-    )
-    from vllm.entrypoints.openai.responses.protocol import (
-        ResponsesRequest,
-    )
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.engine.protocol import DeltaMessage
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
     from vllm.tokenizers import TokenizerLike
-else:
-    ChatCompletionRequest = Any
-    DeltaMessage = Any
-    ResponsesRequest = Any
-    TokenizerLike = Any

 logger = init_logger(__name__)

@@ -41,7 +30,7 @@ class ReasoningParser:
     It is used to extract reasoning content from the model output.
     """

-    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
+    def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
         self.model_tokenizer = tokenizer

     @cached_property
@@ -127,7 +116,7 @@ class ReasoningParser:
     def extract_reasoning(
         self,
         model_output: str,
-        request: ChatCompletionRequest | ResponsesRequest,
+        request: "ChatCompletionRequest | ResponsesRequest",
     ) -> tuple[str | None, str | None]:
         """
         Extract reasoning content from a complete model-generated string.
@@ -136,14 +125,10 @@ class ReasoningParser:
         available before sending to the client.

         Parameters:
-        model_output: str
-            The model-generated string to extract reasoning content from.
-
-        request: ChatCompletionRequest
-            The request object that was used to generate the model_output.
-
+        model_output: The model-generated string to extract reasoning content from.
+        request: The request object that was used to generate the model_output.

         Returns:
         tuple[Optional[str], Optional[str]]
             A tuple containing the reasoning content and the content.
         """

@@ -156,7 +141,7 @@ class ReasoningParser:
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
-    ) -> DeltaMessage | None:
+    ) -> "DeltaMessage | None":
         """
         Instance method that should be implemented for extracting reasoning
         from an incomplete response; for use when handling reasoning calls and
@@ -4,22 +4,15 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Sequence
 from itertools import islice
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

 from vllm.entrypoints.openai.engine.protocol import DeltaMessage
 from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
 from vllm.tokenizers import TokenizerLike

 if TYPE_CHECKING:
-    from vllm.entrypoints.openai.chat_completion.protocol import (
-        ChatCompletionRequest,
-    )
-    from vllm.entrypoints.openai.responses.protocol import (
-        ResponsesRequest,
-    )
-else:
-    ChatCompletionRequest = Any
-    ResponsesRequest = Any
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest


 class BaseThinkingReasoningParser(ReasoningParser):

@@ -58,13 +51,15 @@ class BaseThinkingReasoningParser(ReasoningParser):
         if not self.start_token or not self.end_token:
             raise ValueError("start_token and end_token must be defined in subclasses")

-        self.start_token_id = self.vocab.get(self.start_token)
-        self.end_token_id = self.vocab.get(self.end_token)
-        if self.start_token_id is None or self.end_token_id is None:
+        start_token_id = self.vocab.get(self.start_token)
+        end_token_id = self.vocab.get(self.end_token)
+        if start_token_id is None or end_token_id is None:
             raise RuntimeError(
                 f"{self.__class__.__name__} reasoning parser could not locate "
                 "think start/end tokens in the tokenizer!"
             )
+        self.start_token_id: int = start_token_id
+        self.end_token_id: int = end_token_id

     def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         start_token_id = self.start_token_id

@@ -152,7 +147,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
         return DeltaMessage(content=delta_text)

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         """
         Extract reasoning content from the model output.
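The `BaseThinkingReasoningParser.__init__` change above is the standard mypy idiom for narrowing an `Optional` lookup: `vocab.get(...)` is typed `int | None`, and assigning that straight to `self.start_token_id` pins the attribute type to `int | None` forever, forcing `None` checks at every later use. Routing the value through a local first lets the `is None` guard narrow it to `int` before the explicitly annotated attribute assignment. A self-contained sketch of the idiom:

```python
class ThinkParser:
    def __init__(self, vocab: dict[str, int]) -> None:
        # vocab.get(...) is typed `int | None`; bind it to locals first.
        start_token_id = vocab.get("<think>")
        end_token_id = vocab.get("</think>")
        if start_token_id is None or end_token_id is None:
            raise RuntimeError("think start/end tokens missing from vocab")
        # After the guard, mypy has narrowed both locals to `int`,
        # so the attributes can be declared as plain `int`.
        self.start_token_id: int = start_token_id
        self.end_token_id: int = end_token_id
```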
@@ -2,19 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Iterable, Sequence
+from typing import TYPE_CHECKING

 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
-from vllm.entrypoints.openai.engine.protocol import DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser
 from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser

 from .identity_reasoning_parser import IdentityReasoningParser

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.engine.protocol import DeltaMessage
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
 logger = init_logger(__name__)

@@ -32,6 +34,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
         enable_thinking = bool(chat_kwargs.get("enable_thinking", False))
         thinking = thinking or enable_thinking

+        self._parser: ReasoningParser
         if thinking:
             self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
         else:

@@ -49,7 +52,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
         return self._parser.extract_content_ids(input_ids)

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         return self._parser.extract_reasoning(model_output, request)

@@ -61,7 +64,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
         previous_token_ids: Sequence[int],
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
-    ) -> DeltaMessage | None:
+    ) -> "DeltaMessage | None":
         return self._parser.extract_reasoning_streaming(
             previous_text,
             current_text,
@@ -2,16 +2,18 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Sequence
+from typing import TYPE_CHECKING

 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
 from vllm.entrypoints.openai.engine.protocol import DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
 logger = init_logger(__name__)

@@ -46,20 +48,12 @@ class Ernie45ReasoningParser(BaseThinkingReasoningParser):
             "constructor during construction."
         )

-        self.start_token_id = self.vocab.get(self.start_token)
-        self.end_token_id = self.vocab.get(self.end_token)
         self.response_start_token_id = self.vocab.get(self.response_start_token)
         self.response_end_token_id = self.vocab.get(self.response_end_token)
         self.newline_token_id = self.vocab.get(self.newline_token)

         self.parser_token_ids = [self.end_token_id, self.response_end_token_id]

-        if self.start_token_id is None or self.end_token_id is None:
-            raise RuntimeError(
-                "Ernie45 reasoning parser could not locate think start/end "
-                "tokens in the tokenizer!"
-            )
-
     def extract_reasoning_streaming(
         self,
         previous_text: str,

@@ -144,7 +138,7 @@ class Ernie45ReasoningParser(BaseThinkingReasoningParser):
         return DeltaMessage(reasoning=delta_text)

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         """
         Extract reasoning content from the model output.
@@ -2,18 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import json
 from collections.abc import Sequence
+from typing import TYPE_CHECKING

 from transformers import PreTrainedTokenizerBase

 from vllm.entrypoints.mcp.tool_server import ToolServer
-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
 from vllm.entrypoints.openai.engine.protocol import DeltaMessage
 from vllm.entrypoints.openai.parser.harmony_utils import parse_chat_output
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
 logger = init_logger(__name__)

 no_func_reaonsing_tag = {

@@ -78,7 +80,7 @@ class GptOssReasoningParser(ReasoningParser):
         self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>")
         # We also need to check for the <|end|> token to avoid false positives from
         # previous messages in multi-turn conversations.
-        self.eom_token_id = self.model_tokenizer.vocab["<|end|>"]
+        self.eom_token_id = self.vocab["<|end|>"]
         self.reasoning_max_num_between_tokens = 20

     def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:

@@ -148,7 +150,7 @@ class GptOssReasoningParser(ReasoningParser):
     def extract_reasoning(
         self,
         model_output: str,
-        request: ChatCompletionRequest,
+        request: "ChatCompletionRequest | ResponsesRequest",
     ) -> tuple[str | None, str | None]:
         raise NotImplementedError(
             "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used."  # noqa: E501
@@ -2,17 +2,19 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Sequence
+from typing import TYPE_CHECKING

 import regex as re
 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
 from vllm.entrypoints.openai.engine.protocol import DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
 logger = init_logger(__name__)

@@ -53,7 +55,7 @@ class GraniteReasoningParser(ReasoningParser):
         )

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         """Extract the reasoning content & content sections, respectively.
         If the sequence doesn't match what we expect, i.e., the model generates
@@ -2,17 +2,19 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Sequence
+from typing import TYPE_CHECKING

 import regex as re
 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
 from vllm.entrypoints.openai.engine.protocol import DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
 logger = init_logger(__name__)

@@ -65,8 +67,8 @@ class HunyuanA13BReasoningParser(ReasoningParser):
         self.fast_think_ids = [14023, 771, 1363, 524, 27963, 397, 27, 9399, 397]

         # when state change, send out all the buffered text in last state
-        self.buffered_text = []
-        self.buffered_ids = []
+        self.buffered_text: list[str] = []
+        self.buffered_ids: list[int] = []

         self.current_state = "reasoning"
         self.all_states = ["reasoning", "response"]

@@ -76,7 +78,7 @@ class HunyuanA13BReasoningParser(ReasoningParser):
         # this sequence only for the think start, it has two way to start.
         self.expected_sequence_side = self.think_start_ids_fast
         self.sequence_index = 0
-        self.token_buffer = []
+        self.token_buffer: list[int] = []
         self.text_buffer = ""

     def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:

@@ -90,7 +92,7 @@ class HunyuanA13BReasoningParser(ReasoningParser):
         return []

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         """Extract the reasoning content & content sections, respectively.
         If the sequence doesn't match what we expect, i.e., the model generates
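The buffer annotations above fix a different class of error: mypy cannot infer an element type for a bare `[]`, so under strict-ish settings `self.buffered_text = []` fails with a "Need type annotation" error. Spelling out the element type resolves it:

```python
class StreamBuffers:
    def __init__(self) -> None:
        # Without the annotations, mypy reports
        # 'Need type annotation for "buffered_text"'.
        self.buffered_text: list[str] = []
        self.buffered_ids: list[int] = []
```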
@@ -2,16 +2,18 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Iterable, Sequence
+from typing import TYPE_CHECKING

 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
 from vllm.entrypoints.openai.engine.protocol import DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
 logger = init_logger(__name__)

@@ -59,7 +61,7 @@ class IdentityReasoningParser(ReasoningParser):
         return None

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         # No reasoning separation: return None for reasoning,
         # and full model_output as content
@@ -1,17 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
+from typing import TYPE_CHECKING

 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
 from vllm.entrypoints.openai.engine.protocol import DeltaMessage
 from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
 from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+

 class KimiK2ReasoningParser(ReasoningParser):
     """

@@ -39,6 +41,7 @@ class KimiK2ReasoningParser(ReasoningParser):
         thinking = bool(chat_kwargs.get("thinking", True))

         # If thinking is not enabled, use identity parser to fall through
+        self._identity_parser: IdentityReasoningParser | None
         if not thinking:
             self._identity_parser = IdentityReasoningParser(tokenizer, *args, **kwargs)
         else:

@@ -62,10 +65,6 @@ class KimiK2ReasoningParser(ReasoningParser):
                 "tokens in the tokenizer!"
             )

-    def _is_identity_mode(self) -> bool:
-        """Check if parser is in identity mode (no reasoning extraction)."""
-        return self._identity_parser is not None
-
     def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         """
         Check if the reasoning content ends in the input_ids.

@@ -74,7 +73,7 @@ class KimiK2ReasoningParser(ReasoningParser):
         1. The end token (</think>)
         2. The tool section start token (<|tool_calls_section_begin|>)
         """
-        if self._is_identity_mode():
+        if self._identity_parser is not None:
             return self._identity_parser.is_reasoning_end(input_ids)

         start_token_id = self._start_token_id

@@ -95,29 +94,32 @@ class KimiK2ReasoningParser(ReasoningParser):
         return False

     def is_reasoning_end_streaming(
-        self, input_ids: Sequence[int], delta_ids: Sequence[int]
+        self, input_ids: Sequence[int], delta_ids: Iterable[int]
     ) -> bool:
         """
         Check if the reasoning content ends in the input_ids on a decode step.
         """
-        if self._is_identity_mode():
+        if self._identity_parser is not None:
             return self._identity_parser.is_reasoning_end_streaming(
                 input_ids, delta_ids
             )

+        # Materialize iterable for membership checks
+        delta_ids_set = set(delta_ids)
+
         # Check for explicit end token or implicit tool section start in delta
-        if self._end_token_id in delta_ids:
+        if self._end_token_id in delta_ids_set:
             return True
         return (
             self._tool_section_start_token_id is not None
-            and self._tool_section_start_token_id in delta_ids
+            and self._tool_section_start_token_id in delta_ids_set
         )

     def extract_content_ids(self, input_ids: list[int]) -> list[int]:
         """
         Extract content token ids from the input_ids.
         """
-        if self._is_identity_mode():
+        if self._identity_parser is not None:
             return self._identity_parser.extract_content_ids(input_ids)

         if self._end_token_id in input_ids:

@@ -145,12 +147,12 @@ class KimiK2ReasoningParser(ReasoningParser):
         return []

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         """
         Extract reasoning content from the model output.
         """
-        if self._is_identity_mode():
+        if self._identity_parser is not None:
             return self._identity_parser.extract_reasoning(model_output, request)

         # thinking does not require a think start token but consume it if present

@@ -189,7 +191,7 @@ class KimiK2ReasoningParser(ReasoningParser):
         """
         Extract reasoning content from a delta message during streaming.
         """
-        if self._is_identity_mode():
+        if self._identity_parser is not None:
             return self._identity_parser.extract_reasoning_streaming(
                 previous_text,
                 current_text,
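Inlining `self._identity_parser is not None` in place of the `_is_identity_mode()` helper is not a style regression; it is what makes the calls that follow type-check. mypy's narrowing does not propagate through a helper method: after `_is_identity_mode()` returns `True`, the checker still sees `self._identity_parser` as `IdentityReasoningParser | None`, while a direct `is not None` test narrows the attribute within that branch. A sketch of the difference, with illustrative names:

```python
class Inner:
    def run(self) -> str:
        return "ran"


class Wrapper:
    def __init__(self, inner: Inner | None) -> None:
        self.inner = inner

    def has_inner(self) -> bool:
        return self.inner is not None

    def via_helper(self) -> str | None:
        if self.has_inner():
            # mypy error here: `self.inner` may still be None --
            # the narrowing inside has_inner() does not carry over.
            return self.inner.run()
        return None

    def direct_check(self) -> str | None:
        if self.inner is not None:
            return self.inner.run()  # OK: narrowed to Inner in this branch
        return None
```

The same Kimi-K2 hunk also materializes `delta_ids` with `set(delta_ids)` before the membership tests, since the parameter is now an `Iterable[int]` that may be a one-shot generator.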
@@ -2,21 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Sequence
+from typing import TYPE_CHECKING

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
 from vllm.entrypoints.openai.engine.protocol import (
     DeltaMessage,
 )
-from vllm.entrypoints.openai.responses.protocol import (
-    ResponsesRequest,
-)
 from vllm.logger import init_logger
 from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
 from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
 from vllm.tokenizers import TokenizerLike

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
 logger = init_logger(__name__)

@@ -114,6 +113,6 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
         return DeltaMessage(content=delta_text)

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         return None, "<think>" + model_output
@@ -3,18 +3,17 @@

 from collections.abc import Sequence
 from functools import cached_property
+from typing import TYPE_CHECKING

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
-from vllm.entrypoints.openai.responses.protocol import (
-    ResponsesRequest,
-)
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser
 from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
 from vllm.tokenizers.mistral import MistralTokenizer

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
 logger = init_logger(__name__)

@@ -113,7 +112,7 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
         return input_ids[:eot_token_index] + input_ids[eot_token_index + 1 :]

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         """
         Extract reasoning content from the model output.
@@ -8,20 +8,15 @@ from typing import TYPE_CHECKING

 import regex as re

-if TYPE_CHECKING:
-    from vllm.tokenizers import TokenizerLike
-    from vllm.entrypoints.openai.chat_completion.protocol import (
-        ChatCompletionRequest,
-    )
-    from vllm.entrypoints.openai.engine.protocol import (
-        DeltaMessage,
-    )
-    from vllm.entrypoints.openai.responses.protocol import (
-        ResponsesRequest,
-    )
+from vllm.entrypoints.openai.engine.protocol import DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+    from vllm.tokenizers import TokenizerLike
+
 logger = init_logger(__name__)

@@ -256,15 +251,15 @@ class Olmo3ReasoningParser(ReasoningParser):
     def extract_reasoning(
         self,
         model_output: str,
-        request: ChatCompletionRequest | ResponsesRequest,
+        request: "ChatCompletionRequest | ResponsesRequest",
     ) -> tuple[str | None, str | None]:
         """Extract the reasoning content & content sections, respectively.
         If the sequence doesn't match what we expect, i.e., the model generates
         something else, all content is considered non-reasoning content.

         Args:
-            model_output (str): Output of the model to be parsed.
-            request (ChatCompletionRequest | ResponsesRequest): Request being
+            model_output: Output of the model to be parsed.
+            request: Request being
                 processed.

         Returns:
@@ -2,16 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Sequence
+from typing import TYPE_CHECKING

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
 from vllm.entrypoints.openai.engine.protocol import DeltaMessage
-from vllm.entrypoints.openai.responses.protocol import (
-    ResponsesRequest,
-)
 from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
-from vllm.tokenizers import TokenizerLike
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+    from vllm.tokenizers import TokenizerLike


 class Qwen3ReasoningParser(BaseThinkingReasoningParser):

@@ -34,7 +33,7 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
     it is stripped before extraction (non-streaming) or skipped (streaming).
     """

-    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
+    def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
         super().__init__(tokenizer, *args, **kwargs)

         chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}

@@ -53,7 +52,7 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
         return "</think>"

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         """
         Extract reasoning content from the model output.
@@ -3,17 +3,19 @@

 from collections.abc import Iterable, Sequence
 from itertools import islice
+from typing import TYPE_CHECKING

 import regex as re
 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
 from vllm.entrypoints.openai.engine.protocol import DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
 logger = init_logger(__name__)

@@ -37,12 +39,13 @@ class Step3ReasoningParser(ReasoningParser):
             "constructor during construction."
         )

-        self.think_end_token_id = self.vocab.get(self.think_end_token)
-        if self.think_end_token_id is None:
+        think_end_token_id = self.vocab.get(self.think_end_token)
+        if think_end_token_id is None:
             raise RuntimeError(
                 "Step3 reasoning parser could not locate think end "
                 "token in the tokenizer!"
             )
+        self.think_end_token_id: int = think_end_token_id

     def extract_reasoning_streaming(
         self,

@@ -82,7 +85,7 @@ class Step3ReasoningParser(ReasoningParser):
         return DeltaMessage(reasoning=delta_text)

     def extract_reasoning(
-        self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> tuple[str | None, str | None]:
         # Check if the model output contains the </think> token
         if self.think_end_token not in model_output:

@@ -94,10 +97,7 @@ class Step3ReasoningParser(ReasoningParser):
         reasoning = model_output[:end_index]

         # Content after </think> token
-        content = model_output[end_index + len(self.think_end_token) :]
-
-        if len(content) == 0:
-            content = None
+        content = model_output[end_index + len(self.think_end_token) :] or None

         return reasoning, content
@@ -2,17 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Iterable, Sequence
+from typing import TYPE_CHECKING

-from vllm.entrypoints.openai.chat_completion.protocol import (
-    ChatCompletionRequest,
-)
 from vllm.entrypoints.openai.engine.protocol import DeltaMessage
-from vllm.entrypoints.openai.responses.protocol import (
-    ResponsesRequest,
-)
 from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
 from vllm.tokenizers import TokenizerLike

+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+

 class Step3p5ReasoningParser(BaseThinkingReasoningParser):
     """

@@ -50,7 +49,7 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser):
         self, input_ids: Sequence[int], delta_ids: Iterable[int]
     ) -> bool:
         # Only examine newly generated tokens; they may contain multiple ids.
-        return self._is_reasoning_end_from_ids(delta_ids)
+        return self._is_reasoning_end_from_ids(tuple(delta_ids))

     def _is_reasoning_end_from_ids(self, input_ids: Sequence[int]) -> bool:
         # Scan backwards to find the last special token, <think> or </think>.

@@ -96,7 +95,7 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser):
     def extract_reasoning(
         self,
         model_output: str,
-        request: ChatCompletionRequest | ResponsesRequest,
+        request: "ChatCompletionRequest | ResponsesRequest",
     ) -> tuple[str | None, str | None]:
         reasoning, content = super().extract_reasoning(model_output, request)
         if reasoning is not None:
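`tuple(delta_ids)` in the Step3p5 hunk serves the same purpose as `set(delta_ids)` in the Kimi-K2 hunk earlier: the streaming hooks now accept `delta_ids` as `Iterable[int]`, and an iterable is neither indexable nor guaranteed to survive a second pass, so it must be materialized before being handed to a `Sequence[int]` helper or scanned repeatedly. Roughly:

```python
from collections.abc import Iterable, Sequence


def ends_with(ids: Sequence[int], token_id: int) -> bool:
    # Needs indexing, hence a Sequence.
    return len(ids) > 0 and ids[-1] == token_id


def ends_with_streaming(delta_ids: Iterable[int], token_id: int) -> bool:
    # An Iterable may be a one-shot generator and is not indexable;
    # materialize it into a Sequence first.
    return ends_with(tuple(delta_ids), token_id)
```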
@@ -4,7 +4,7 @@

 import functools
 import json
-from collections.abc import Collection, Set
+from collections.abc import Collection, Sequence, Set
 from pathlib import Path
 from typing import Any, Literal, overload

@@ -348,7 +348,9 @@ class Grok2Tokenizer(TokenizerLike):
         tokens = self._maybe_truncate(tokens, max_length)
         return tokens

-    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
+    def decode(
+        self, ids: Sequence[int] | int, skip_special_tokens: bool = False
+    ) -> str:
         if isinstance(ids, int):
             ids = [ids]
         if skip_special_tokens:

@@ -371,7 +373,7 @@ class Grok2Tokenizer(TokenizerLike):
         return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]

     def convert_ids_to_tokens(
-        self, ids: list[int], skip_special_tokens: bool = False
+        self, ids: Sequence[int], skip_special_tokens: bool = False
     ) -> list[str]:
         tokens = []
         for token_id in ids:
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Sequence
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast, overload

@@ -434,7 +435,9 @@ class MistralTokenizer(TokenizerLike):
             return_dict=False,
         )

-    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
+    def decode(
+        self, ids: Sequence[int] | int, skip_special_tokens: bool = False
+    ) -> str:
         # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
         # is in, directly call self.transformers_tokenizer.decode(...).
         if isinstance(ids, int):

@@ -512,7 +515,7 @@ class MistralTokenizer(TokenizerLike):

     def convert_ids_to_tokens(
         self,
-        ids: list[int],
+        ids: Sequence[int],
         skip_special_tokens: bool = False,
     ) -> list[str]:
         if not skip_special_tokens:
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Sequence
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Protocol, overload

@@ -116,12 +117,14 @@ class TokenizerLike(Protocol):
     def convert_tokens_to_string(self, tokens: list[str]) -> str:
         raise NotImplementedError

-    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
+    def decode(
+        self, ids: Sequence[int] | int, skip_special_tokens: bool = False
+    ) -> str:
         raise NotImplementedError

     def convert_ids_to_tokens(
         self,
-        ids: list[int],
+        ids: Sequence[int],
         skip_special_tokens: bool = False,
     ) -> list[str]:
         raise NotImplementedError
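The tokenizer changes widen read-only parameters from `list[int]` to `Sequence[int]`. `list` is invariant and over-specific for arguments that are only iterated or indexed, so the old signatures rejected tuples and other sequence views that callers (including the reasoning parsers above, which pass `Sequence[int]` token ids) legitimately hold. For example:

```python
from collections.abc import Sequence


def decode_ids(ids: Sequence[int]) -> str:
    # Only read access is needed, so any Sequence is acceptable.
    return " ".join(str(i) for i in ids)


decode_ids([1, 2, 3])   # a list is a Sequence
decode_ids((1, 2, 3))   # a tuple too; rejected by a list[int] signature
```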