Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
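Every hunk below is a mechanical restyle from yapf + isort to ruff's formatter; no runtime behavior changes. Two patterns account for almost all of the churn: long calls, imports, and conditions are broken after the opening bracket and closed with a trailing comma instead of yapf's hanging alignment, and slices whose bounds are compound expressions gain PEP 8 spacing around the colon. A minimal before/after sketch of both patterns (the helper name is made up, and the short import is shown wrapped purely for illustration):

# Before (yapf + isort): hanging alignment, no space before the slice colon.
from typing import (Optional,
                    Union)

def tail_before(ids: list[int], end_id: int) -> list[int]:
    # everything after the first occurrence of end_id
    return ids[ids.index(end_id) + 1:]

# After (ruff format): break after the bracket with a magic trailing comma,
# and pad the slice colon because "ids.index(end_id) + 1" is a compound bound.
from typing import (
    Optional,
    Union,
)

def tail_after(ids: list[int], end_id: int) -> list[int]:
    return ids[ids.index(end_id) + 1 :]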
@@ -13,9 +13,11 @@ from vllm.logger import init_logger
 from vllm.utils import import_from_path, is_list_of

 if TYPE_CHECKING:
-    from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                                  DeltaMessage,
-                                                  ResponsesRequest)
+    from vllm.entrypoints.openai.protocol import (
+        ChatCompletionRequest,
+        DeltaMessage,
+        ResponsesRequest,
+    )
     from vllm.transformers_utils.tokenizer import AnyTokenizer
 else:
     ChatCompletionRequest = Any
@@ -128,8 +130,7 @@ class ReasoningParserManager:
         if name in cls.reasoning_parsers:
             return cls.reasoning_parsers[name]

-        raise KeyError(
-            f"reasoning helper: '{name}' not found in reasoning_parsers")
+        raise KeyError(f"reasoning helper: '{name}' not found in reasoning_parsers")

     @classmethod
     def _register_module(
@@ -139,8 +140,9 @@ class ReasoningParserManager:
         force: bool = True,
     ) -> None:
         if not issubclass(module, ReasoningParser):
-            raise TypeError("module must be subclass of ReasoningParser, "
-                            f"but got {type(module)}")
+            raise TypeError(
+                f"module must be subclass of ReasoningParser, but got {type(module)}"
+            )
         if module_name is None:
             module_name = module.__name__
         if isinstance(module_name, str):
@@ -148,8 +150,9 @@ class ReasoningParserManager:
         for name in module_name:
             if not force and name in cls.reasoning_parsers:
                 existed_module = cls.reasoning_parsers[name]
-                raise KeyError(f"{name} is already registered "
-                               f"at {existed_module.__module__}")
+                raise KeyError(
+                    f"{name} is already registered at {existed_module.__module__}"
+                )
             cls.reasoning_parsers[name] = module

     @classmethod
@@ -168,11 +171,11 @@ class ReasoningParserManager:
             raise TypeError(f"force must be a boolean, but got {type(force)}")

         # raise the error ahead of time
-        if not (name is None or isinstance(name, str)
-                or is_list_of(name, str)):
+        if not (name is None or isinstance(name, str) or is_list_of(name, str)):
             raise TypeError(
                 "name must be None, an instance of str, or a sequence of str, "
-                f"but got {type(name)}")
+                f"but got {type(name)}"
+            )

         # use it as a normal method: x.register_module(module=SomeClass)
         if module is not None:
@@ -197,6 +200,7 @@ class ReasoningParserManager:
         try:
             import_from_path(module_name, plugin_path)
         except Exception:
-            logger.exception("Failed to load module '%s' from %s.",
-                             module_name, plugin_path)
+            logger.exception(
+                "Failed to load module '%s' from %s.", module_name, plugin_path
+            )
             return

@@ -5,8 +5,11 @@ from abc import abstractmethod
 from collections.abc import Sequence
 from typing import Optional, Union

-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              DeltaMessage, ResponsesRequest)
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    DeltaMessage,
+    ResponsesRequest,
+)
 from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
 from vllm.transformers_utils.tokenizer import AnyTokenizer

@@ -14,11 +17,11 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 class BaseThinkingReasoningParser(ReasoningParser):
     """
     Base class for reasoning parsers that use thinking tokens.
-    
+
     This class provides common functionality for parsers that use start and end
     tokens to delimit reasoning content (
     e.g., <think>...</think>, <seed:think>...</seed:think>).
-    
+
     Subclasses must implement the start and end tokens via abstract
     properties.
     """
@@ -41,18 +44,19 @@ class BaseThinkingReasoningParser(ReasoningParser):
         if not self.model_tokenizer:
             raise ValueError(
                 "The model tokenizer must be passed to the ReasoningParser "
-                "constructor during construction.")
+                "constructor during construction."
+            )

         if not self.start_token or not self.end_token:
-            raise ValueError(
-                "start_token and end_token must be defined in subclasses")
+            raise ValueError("start_token and end_token must be defined in subclasses")

         self.start_token_id = self.vocab.get(self.start_token)
         self.end_token_id = self.vocab.get(self.end_token)
         if self.start_token_id is None or self.end_token_id is None:
             raise RuntimeError(
                 f"{self.__class__.__name__} reasoning parser could not locate "
-                "think start/end tokens in the tokenizer!")
+                "think start/end tokens in the tokenizer!"
+            )

     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         return self.end_token_id in input_ids
@@ -64,7 +68,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
         if self.end_token_id not in input_ids[:-1]:
             return []
         else:
-            return input_ids[input_ids.index(self.end_token_id) + 1:]
+            return input_ids[input_ids.index(self.end_token_id) + 1 :]

     def extract_reasoning_content_streaming(
         self,
@@ -81,9 +85,9 @@ class BaseThinkingReasoningParser(ReasoningParser):
         Uses token IDs for faster processing.
         """
         # Skip single special tokens
-        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
-                self.start_token_id, self.end_token_id
-        ]):
+        if len(delta_token_ids) == 1 and (
+            delta_token_ids[0] in [self.start_token_id, self.end_token_id]
+        ):
             return None

         # Check if start token is present in previous or delta.
@@ -94,7 +98,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
                 # extract reasoning content
                 end_index = delta_text.find(self.end_token)
                 reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.end_token):]
+                content = delta_text[end_index + len(self.end_token) :]
                 return DeltaMessage(
                     reasoning_content=reasoning_content,
                     content=content if content else None,
@@ -113,9 +117,10 @@ class BaseThinkingReasoningParser(ReasoningParser):
                 # extract reasoning content
                 start_index = delta_text.find(self.start_token)
                 end_index = delta_text.find(self.end_token)
-                reasoning_content = delta_text[start_index +
-                                               len(self.start_token):end_index]
-                content = delta_text[end_index + len(self.end_token):]
+                reasoning_content = delta_text[
+                    start_index + len(self.start_token) : end_index
+                ]
+                content = delta_text[end_index + len(self.end_token) :]
                 return DeltaMessage(
                     reasoning_content=reasoning_content,
                     content=content if content else None,
@@ -129,28 +134,27 @@ class BaseThinkingReasoningParser(ReasoningParser):
         return DeltaMessage(content=delta_text)

     def extract_reasoning_content(
-        self, model_output: str, request: Union[ChatCompletionRequest,
-                                                ResponsesRequest]
+        self, model_output: str, request: Union[ChatCompletionRequest, ResponsesRequest]
     ) -> tuple[Optional[str], Optional[str]]:
         """
         Extract reasoning content from the model output.
-        
+
         This is the base implementation that works for most models.
         Subclasses can override this method for specific behavior.
         """
         # Check if the start token is present in the model output, remove it
         # if it is present.
         model_output_parts = model_output.partition(self.start_token)
-        model_output = model_output_parts[2] if model_output_parts[
-            1] else model_output_parts[0]
+        model_output = (
+            model_output_parts[2] if model_output_parts[1] else model_output_parts[0]
+        )

         # For models that may not generate start token,
         # assume the reasoning content is always at the start.
         if self.end_token not in model_output:
             return model_output, None
         else:
-            reasoning_content, _, content = model_output.partition(
-                self.end_token)
+            reasoning_content, _, content = model_output.partition(self.end_token)
             # If generation stops right after end-of-think, return null content
             final_content = content or None
             return reasoning_content, final_content

@@ -45,14 +45,17 @@ class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser):
             current_token_ids,
             delta_token_ids,
         )
-        if (ret is not None and self.start_token_id not in previous_token_ids
-                and self.start_token_id not in delta_token_ids):
+        if (
+            ret is not None
+            and self.start_token_id not in previous_token_ids
+            and self.start_token_id not in delta_token_ids
+        ):
             if self.end_token_id in delta_token_ids:
                 # end token in delta with more tokens,
                 # extract reasoning content and content
                 end_index = delta_text.find(self.end_token)
                 reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.end_token):]
+                content = delta_text[end_index + len(self.end_token) :]
                 return DeltaMessage(
                     reasoning_content=reasoning_content,
                     content=content if content else None,

@@ -6,8 +6,7 @@ from typing import Optional, Union

 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              DeltaMessage)
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser, ReasoningParserManager

@@ -35,17 +34,21 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
         if not self.model_tokenizer:
             raise ValueError(
                 "The model tokenizer must be passed to the ReasoningParser "
-                "constructor during construction.")
+                "constructor during construction."
+            )

         self.think_start_token_id = self.vocab.get(self.think_start_token)
         self.think_end_token_id = self.vocab.get(self.think_end_token)
         self.assistant_token_id = self.vocab.get(self.assistant_token)
-        if (self.think_start_token_id is None
-                or self.think_end_token_id is None
-                or self.assistant_token_id is None):
+        if (
+            self.think_start_token_id is None
+            or self.think_end_token_id is None
+            or self.assistant_token_id is None
+        ):
             raise RuntimeError(
                 "Glm4MoeModel reasoning parser could not locate "
-                "think start/end or assistant tokens in the tokenizer!")
+                "think start/end or assistant tokens in the tokenizer!"
+            )

     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         """
@@ -67,7 +70,7 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
         if self.think_end_token_id not in input_ids[:-1]:
             return []
         else:
-            return input_ids[input_ids.index(self.think_end_token_id) + 1:]
+            return input_ids[input_ids.index(self.think_end_token_id) + 1 :]

     def extract_reasoning_content_streaming(
         self,
@@ -87,9 +90,9 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
         - 'xyz' goes to content
         """
         # Skip single special tokens
-        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
-                self.think_start_token_id, self.think_end_token_id
-        ]):
+        if len(delta_token_ids) == 1 and (
+            delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id]
+        ):
             return None

         if self.think_start_token_id in previous_token_ids:
@@ -98,9 +101,11 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
                 # extract reasoning content
                 end_index = delta_text.find(self.think_end_token)
                 reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content,
-                                    content=content if content else None)
+                content = delta_text[end_index + len(self.think_end_token) :]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
             elif self.think_end_token_id in previous_token_ids:
                 # <think> in previous, </think> in previous,
                 # reasoning content continues
@@ -114,12 +119,14 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
                 # <think> in delta, </think> in delta, extract reasoning content
                 start_index = delta_text.find(self.think_start_token)
                 end_index = delta_text.find(self.think_end_token)
-                reasoning_content = delta_text[start_index +
-                                               len(self.think_start_token
-                                                   ):end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content,
-                                    content=content if content else None)
+                reasoning_content = delta_text[
+                    start_index + len(self.think_start_token) : end_index
+                ]
+                content = delta_text[end_index + len(self.think_end_token) :]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
             else:
                 # <think> in delta, no </think> in delta,
                 # reasoning content continues
@@ -129,7 +136,7 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
         return DeltaMessage(content=delta_text)

     def extract_reasoning_content(
-            self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: ChatCompletionRequest
     ) -> tuple[Optional[str], Optional[str]]:
         """
         Extract reasoning content from the model output.
@@ -143,22 +150,24 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
         """

         # Check if the model output contains the <think> and </think> tokens.
-        if (self.think_start_token not in model_output
-                or self.think_end_token not in model_output):
+        if (
+            self.think_start_token not in model_output
+            or self.think_end_token not in model_output
+        ):
             return None, model_output
         # Check if the <think> is present in the model output, remove it
         # if it is present.
         model_output_parts = model_output.partition(self.think_start_token)
-        model_output = model_output_parts[2] if model_output_parts[
-            1] else model_output_parts[0]
+        model_output = (
+            model_output_parts[2] if model_output_parts[1] else model_output_parts[0]
+        )
         # Check if the model output contains the </think> tokens.
         # If the end token is not found, return the model output as is.
         if self.think_end_token not in model_output:
             return None, model_output

         # Extract reasoning content from the model output.
-        reasoning_content, _, content = model_output.partition(
-            self.think_end_token)
+        reasoning_content, _, content = model_output.partition(self.think_end_token)

         final_content = content or None
         return reasoning_content, final_content

@@ -7,8 +7,7 @@ from typing import Optional, Union
 from transformers import PreTrainedTokenizerBase

 from vllm.entrypoints.harmony_utils import parse_chat_output
-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              DeltaMessage)
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser, ReasoningParserManager

@@ -27,7 +26,8 @@ class GptOssReasoningParser(ReasoningParser):
     def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
         super().__init__(tokenizer, *args, **kwargs)
         self.reasoning_end_token_ids = self.model_tokenizer.encode(
-            "<|start|>assistant<|channel|>final<|message|>")
+            "<|start|>assistant<|channel|>final<|message|>"
+        )

     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         end_token_ids = self.reasoning_end_token_ids
@@ -35,7 +35,7 @@ class GptOssReasoningParser(ReasoningParser):
         # Check if the end sequence is present in the input_ids.
         # We search from the end of input_ids to find the last match.
         for i in range(len(input_ids) - len(end_token_ids), -1, -1):
-            if input_ids[i:i + len(end_token_ids)] == end_token_ids:
+            if input_ids[i : i + len(end_token_ids)] == end_token_ids:
                 return True
         return False

@@ -54,28 +54,25 @@ class GptOssReasoningParser(ReasoningParser):
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
     ) -> Union[DeltaMessage, None]:
-        prev_reasoning, prev_content, _ = parse_chat_output(
-            list(previous_token_ids))
-        cur_reasoning, cur_content, _ = parse_chat_output(
-            list(current_token_ids))
+        prev_reasoning, prev_content, _ = parse_chat_output(list(previous_token_ids))
+        cur_reasoning, cur_content, _ = parse_chat_output(list(current_token_ids))
         reasoning_delta = None
         content_delta = None
         if cur_reasoning is not None:
             prev_r = prev_reasoning or ""
             if cur_reasoning.startswith(prev_r):
-                reasoning_delta = cur_reasoning[len(prev_r):] or None
+                reasoning_delta = cur_reasoning[len(prev_r) :] or None
             else:
                 reasoning_delta = cur_reasoning
         if cur_content is not None:
             prev_c = prev_content or ""
             if cur_content.startswith(prev_c):
-                content_delta = cur_content[len(prev_c):] or None
+                content_delta = cur_content[len(prev_c) :] or None
             else:
                 content_delta = cur_content
         if reasoning_delta is None and content_delta is None:
             return None
-        return DeltaMessage(reasoning_content=reasoning_delta,
-                            content=content_delta)
+        return DeltaMessage(reasoning_content=reasoning_delta, content=content_delta)

     def extract_reasoning_content(
         self,

@@ -7,8 +7,7 @@ from typing import Optional, Union
 import regex as re
 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              DeltaMessage)
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser, ReasoningParserManager

@@ -34,15 +33,14 @@ class GraniteReasoningParser(ReasoningParser):
         self.response_start_expr = r"(?:Here's|Here is) my response:"

         self.reasoning_regex = re.compile(
-            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
-            re.DOTALL)
+            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL
+        )

         self.valid_think_starts = [
-            "Here's my thought process:", "Here is my thought process:"
+            "Here's my thought process:",
+            "Here is my thought process:",
         ]
-        self.valid_response_starts = [
-            "Here's my response:", "Here is my response:"
-        ]
+        self.valid_response_starts = ["Here's my response:", "Here is my response:"]

         # Substrings to match for sequence boundaries on raw text
         self.seq_boundary_end = ":"
@@ -50,10 +48,11 @@ class GraniteReasoningParser(ReasoningParser):

         # The longest any thinking / start of response message can be
         self.longest_think_start = max(
-            len(think_start) for think_start in self.valid_think_starts)
+            len(think_start) for think_start in self.valid_think_starts
+        )

     def extract_reasoning_content(
-            self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: ChatCompletionRequest
     ) -> tuple[Optional[str], Optional[str]]:
         """Extract the reasoning content & content sections, respectively.
         If the sequence doesn't match what we expect, i.e., the model generates
@@ -111,24 +110,27 @@ class GraniteReasoningParser(ReasoningParser):
             DeltaMessage with either reasoning content or content, or None.
         """
         reasoning_content, resp_seq_len, content = self._get_content_sections(
-            current_text)
+            current_text
+        )
         # Either we haven't finished the start of the reasoning sequence,
         # or the model is generating something unexpected.
         if not reasoning_content:
             delta_message = self._get_delta_message_with_no_reasoning_bounds(
-                current_text, delta_text)
+                current_text, delta_text
+            )
         # We have a start of reasoning message, but have not yet finished
         # the start of response sequence.
         elif not content:
             delta_message = self._get_delta_message_with_no_response_bounds(
-                current_text, reasoning_content, delta_text)
+                current_text, reasoning_content, delta_text
+            )
         # We've finished both the start of reasoning and start of response seq.
         else:
             # This should never happen since we matched on the response
             assert resp_seq_len is not None
             delta_message = self._get_delta_message_with_both_bounds(
-                delta_text, reasoning_content, content, current_text,
-                resp_seq_len)
+                delta_text, reasoning_content, content, current_text, resp_seq_len
+            )
         if not delta_message.content and not delta_message.reasoning_content:
             return None
         return delta_message
@@ -139,26 +141,27 @@ class GraniteReasoningParser(ReasoningParser):

         Args:
             text (str): Text to check for leading substr.
-        
+
         Returns:
             bool: True if any of the possible reasoning start seqs match.
         """
         return any(
-            think_start.startswith(text)
-            for think_start in self.valid_think_starts)
+            think_start.startswith(text) for think_start in self.valid_think_starts
+        )

     def _is_response_start_substr(self, text: str) -> bool:
         """Check if a text matches one of the possible start response seqs.

         Args:
             text (str): Text to check for leading substr.
-        
+
         Returns:
             bool: True if any of the possible response start seqs match.
         """
         return any(
             response_start.startswith(text)
-            for response_start in self.valid_response_starts)
+            for response_start in self.valid_response_starts
+        )

     def _get_delta_message_with_no_reasoning_bounds(
         self,
@@ -177,8 +180,7 @@ class GraniteReasoningParser(ReasoningParser):
         """
         prev_longest_length = len(current_text) - len(delta_text)
         is_substr = self._is_reasoning_start_substr(current_text)
-        was_substr = self._is_reasoning_start_substr(
-            current_text[:prev_longest_length])
+        was_substr = self._is_reasoning_start_substr(current_text[:prev_longest_length])

         # Check if we just generated something NOT in the special token seq;
         # if so, add everything that we previously skipped with this delta
@@ -220,12 +222,13 @@ class GraniteReasoningParser(ReasoningParser):
         # content and fully parse it out; we should not pass the : back.
         ends_with_start_response_seq = any(
             current_text.endswith(response_start)
-            for response_start in self.valid_response_starts)
+            for response_start in self.valid_response_starts
+        )
         if reasoning_content is None or ends_with_start_response_seq:
             return DeltaMessage(reasoning_content=None, content=None)

         # Consider previous / current text only within context of the reasoning
-        previous_text = reasoning_content[:-len(delta_text)]
+        previous_text = reasoning_content[: -len(delta_text)]
         current_text = reasoning_content

         # We need to be careful about adding unfinished response sequences;
@@ -234,12 +237,21 @@ class GraniteReasoningParser(ReasoningParser):
         delta_idx = delta_text.rfind(self.seq_boundary_start)

         # Check the state of potential start of response substring matches.
-        prev_was_substr = self._is_response_start_substr(
-            previous_text[prev_idx:]) if prev_idx >= 0 else False
-        delta_continues_substr = self._is_response_start_substr(
-            current_text[prev_idx:]) if prev_idx >= 0 else False
-        delta_new_substr = self._is_response_start_substr(
-            delta_text[delta_idx:]) if delta_idx >= 0 else False
+        prev_was_substr = (
+            self._is_response_start_substr(previous_text[prev_idx:])
+            if prev_idx >= 0
+            else False
+        )
+        delta_continues_substr = (
+            self._is_response_start_substr(current_text[prev_idx:])
+            if prev_idx >= 0
+            else False
+        )
+        delta_new_substr = (
+            self._is_response_start_substr(delta_text[delta_idx:])
+            if delta_idx >= 0
+            else False
+        )

         # Delta only contains potential continued response sequence text.
         if delta_continues_substr:
@@ -248,18 +260,17 @@ class GraniteReasoningParser(ReasoningParser):
         if not prev_was_substr:
             # Delta may be starting a new response seq but has other text too.
             if delta_new_substr:
-                return DeltaMessage(reasoning_content=delta_text[:delta_idx],
-                                    content=None)
+                return DeltaMessage(
+                    reasoning_content=delta_text[:delta_idx], content=None
+                )
             # Normal case for most reasoning text (no potential special seqs).
             return DeltaMessage(reasoning_content=delta_text, content=None)
         # The substring that previously seemed to be a potential response
         # seq wasn't one; we need to add the content to the delta message,
         # and also slice off the potential response sequence
         elif delta_new_substr:
-            reasoning_content = previous_text[
-                prev_idx:] + delta_text[:delta_idx]
-            return DeltaMessage(reasoning_content=reasoning_content,
-                                content=None)
+            reasoning_content = previous_text[prev_idx:] + delta_text[:delta_idx]
+            return DeltaMessage(reasoning_content=reasoning_content, content=None)
         # No new substring yet, and we broke our old one; take the whole delta
         return DeltaMessage(
             reasoning_content=previous_text[prev_idx:] + delta_text,
@@ -288,23 +299,21 @@ class GraniteReasoningParser(ReasoningParser):
             DeltaMessage: Message containing the parsed content.
         """
         # Always have content; take length to the end
-        delta_content = delta_text[-len(response_content):]
-        reasoning_end_idx = len(delta_text) - (len(response_content) +
-                                               response_seq_len)
+        delta_content = delta_text[-len(response_content) :]
+        reasoning_end_idx = len(delta_text) - (len(response_content) + response_seq_len)

         if reasoning_end_idx < 0:
             delta_reasoning_content = None
         else:
             # Get the starting offset
-            start_reasoning_content_idx = len(
-                reasoning_content) + response_seq_len + len(
-                    response_content) - 1
+            start_reasoning_content_idx = (
+                len(reasoning_content) + response_seq_len + len(response_content) - 1
+            )
             delta_offset = len(current_text) - len(delta_text)
             start_offset = start_reasoning_content_idx - delta_offset
             if start_offset < 0:
                 start_offset = 0
-            delta_reasoning_content = delta_text[
-                start_offset:reasoning_end_idx]
+            delta_reasoning_content = delta_text[start_offset:reasoning_end_idx]

         return DeltaMessage(
             reasoning_content=delta_reasoning_content,
@@ -329,7 +338,8 @@ class GraniteReasoningParser(ReasoningParser):
         start_reasoning_content = None
         parsed_content = False
         delimiter_idxs = [
-            idx for idx, char in enumerate(current_text)
+            idx
+            for idx, char in enumerate(current_text)
             if char == self.seq_boundary_end
         ]

@@ -346,17 +356,15 @@ class GraniteReasoningParser(ReasoningParser):
             # Check to see if the start of response seq if complete
             elif not parsed_content:
                 for response_start in self.valid_response_starts:
-                    if current_chunk[-len(response_start) +
-                                     1:] == response_start[:-1]:
+                    if current_chunk[-len(response_start) + 1 :] == response_start[:-1]:
                         # Mark end of reasoning and start response content
                         # after the start of response sequence.
-                        end_reasoning_content = current_chunk_end - len(
-                            response_start)
+                        end_reasoning_content = current_chunk_end - len(response_start)
                         reasoning_content = current_text[
-                            start_reasoning_content:end_reasoning_content]
-                        response_content = current_text[current_chunk_end + 1:]
-                        return reasoning_content, len(
-                            response_start), response_content
+                            start_reasoning_content:end_reasoning_content
+                        ]
+                        response_content = current_text[current_chunk_end + 1 :]
+                        return reasoning_content, len(response_start), response_content

         if start_reasoning_content and not parsed_content:
             return current_text[start_reasoning_content:], None, None

@@ -7,8 +7,7 @@ from typing import Optional, Union
 import regex as re
 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              DeltaMessage)
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser, ReasoningParserManager

@@ -22,16 +21,16 @@ class HunyuanA13BReasoningParser(ReasoningParser):

     HunyuanReasoningParser

-    This class implements a reasoning parser specifically designed 
-    for the Hunyuan A13B Model. It is responsible for parsing and 
-    extracting structured reasoning and answer segments from model 
+    This class implements a reasoning parser specifically designed
+    for the Hunyuan A13B Model. It is responsible for parsing and
+    extracting structured reasoning and answer segments from model
     outputs that follow a specific pattern.

     Key Features:
     - For non-stream output , Recognizes and extracts reasoning ("think")
       and answer ("answer") sections from text using regular expressions.
     - For stream process, it requires a token id sequences to change the
-      reasoning state and other state so it maintains internal state to 
+      reasoning state and other state so it maintains internal state to
       manage parsing across multiple token.
-    
+

@@ -50,20 +49,19 @@ class HunyuanA13BReasoningParser(ReasoningParser):

         self.full_match_reasoning_regex = re.compile(
             rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?(.*?){self.response_end_expr}",
-            re.DOTALL)
+            re.DOTALL,
+        )

         self.half_match_reasoning_regex = re.compile(
-            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
-            re.DOTALL)
+            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL
+        )

         self.think_start_ids = [14023, 771, 397]
         self.think_start_ids_fast = [14023, 771, 1363]
         self.response_start_ids = [198, 524, 27963, 397, 27, 9399, 397]
         self.response_start_ids_fast = [524, 27963, 397, 27, 9399, 397]
         self.response_end_ids = [198, 524, 9399, 29]
-        self.fast_think_ids = [
-            14023, 771, 1363, 524, 27963, 397, 27, 9399, 397
-        ]
+        self.fast_think_ids = [14023, 771, 1363, 524, 27963, 397, 27, 9399, 397]

         # when state change, send out all the buffered text in last state
         self.buffered_text = []
@@ -91,7 +89,7 @@ class HunyuanA13BReasoningParser(ReasoningParser):
         return []

     def extract_reasoning_content(
-            self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: ChatCompletionRequest
     ) -> tuple[Optional[str], Optional[str]]:
         """Extract the reasoning content & content sections, respectively.
         If the sequence doesn't match what we expect, i.e., the model generates
@@ -121,8 +119,7 @@ class HunyuanA13BReasoningParser(ReasoningParser):
             reasoning_content, response_content = fallback_match[0]

             if response_content.endswith(self.response_end_expr):
-                response_content = response_content[:-len(self.
-                                                          response_end_expr)]
+                response_content = response_content[: -len(self.response_end_expr)]

             if len(reasoning_content) == 0:
                 reasoning_content = None
@@ -133,8 +130,9 @@ class HunyuanA13BReasoningParser(ReasoningParser):

         return None, model_output

-    def _is_strict_increasing_subsequence(self, subsequence: Sequence[int],
-                                          sequence: Sequence[int]) -> bool:
+    def _is_strict_increasing_subsequence(
+        self, subsequence: Sequence[int], sequence: Sequence[int]
+    ) -> bool:
         if not subsequence:
             return False

@@ -159,27 +157,27 @@ class HunyuanA13BReasoningParser(ReasoningParser):
         response_start_sequence = self.response_start_ids
         response_end_sequence = self.response_end_ids

-        assert (len(delta_token_ids) == 1)
+        assert len(delta_token_ids) == 1
         # Process each token in the delta
         token = delta_token_ids[0]

         def check_token_with_sequence(token):
             if self.current_state == "idle" or self.current_state == "think":
-                return (token == self.expected_sequence[self.sequence_index]
-                        or token == \
-                        self.expected_sequence_side[self.sequence_index])
+                return (
+                    token == self.expected_sequence[self.sequence_index]
+                    or token == self.expected_sequence_side[self.sequence_index]
+                )
             else:
                 return token == self.expected_sequence[self.sequence_index]

         def check_last_token(token):
             if self.current_state == "idle" or self.current_state == "think":
                 # only return true if it's judge using a side sequence.
-                if (self.sequence_index - 1 < len(self.expected_sequence_side)
-                        and token
-                        == self.expected_sequence_side[self.sequence_index -
-                                                       1]):
-                    return self.sequence_index == len(
-                        self.expected_sequence_side)
+                if (
+                    self.sequence_index - 1 < len(self.expected_sequence_side)
+                    and token == self.expected_sequence_side[self.sequence_index - 1]
+                ):
+                    return self.sequence_index == len(self.expected_sequence_side)
                 else:
                     return self.sequence_index == len(self.expected_sequence)
             else:
@@ -227,19 +225,19 @@ class HunyuanA13BReasoningParser(ReasoningParser):

                 # Return content based on current state
                 if self.current_state == "think":
-                    return DeltaMessage(reasoning_content=buffered_content,
-                                        content=None)
+                    return DeltaMessage(
+                        reasoning_content=buffered_content, content=None
+                    )
                 else:
-                    return DeltaMessage(reasoning_content=None,
-                                        content=buffered_content)
+                    return DeltaMessage(
+                        reasoning_content=None, content=buffered_content
+                    )
             else:
                 # No buffered content, send normally
                 if self.current_state == "think":
-                    return DeltaMessage(reasoning_content=delta_text,
-                                        content=None)
+                    return DeltaMessage(reasoning_content=delta_text, content=None)
                 else:
-                    return DeltaMessage(reasoning_content=None,
-                                        content=delta_text)
+                    return DeltaMessage(reasoning_content=None, content=delta_text)

         # If no content to send in this delta
         return None

@@ -5,8 +5,7 @@ from functools import cached_property

 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
-from vllm.reasoning.deepseek_r1_reasoning_parser import (
-    DeepSeekR1ReasoningParser)
+from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

 logger = init_logger(__name__)
@@ -23,34 +22,35 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):

     def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs):
         if not isinstance(tokenizer, MistralTokenizer):
-            raise ValueError(
-                "The tokenizer must be an instance of MistralTokenizer.")
+            raise ValueError("The tokenizer must be an instance of MistralTokenizer.")

         ReasoningParser.__init__(self, tokenizer, *args, **kwargs)

         if not self.model_tokenizer:
             raise ValueError(
                 "The model tokenizer must be passed to the ReasoningParser "
-                "constructor during construction.")
+                "constructor during construction."
+            )

-        self.start_token_id = tokenizer.tokenizer.get_control_token(
-            self.start_token)
-        self.end_token_id = tokenizer.tokenizer.get_control_token(
-            self.end_token)
+        self.start_token_id = tokenizer.tokenizer.get_control_token(self.start_token)
+        self.end_token_id = tokenizer.tokenizer.get_control_token(self.end_token)

         if self.start_token_id is None or self.end_token_id is None:
             raise RuntimeError(
                 "Mistral reasoning parser could not locate think start/end "
-                "tokens in the tokenizer!")
+                "tokens in the tokenizer!"
+            )

     @cached_property
     def start_token(self) -> str:
         """The token that starts reasoning content."""
         from mistral_common.tokens.tokenizers.base import SpecialTokens

         return SpecialTokens.begin_think

     @cached_property
     def end_token(self) -> str:
         """The token that ends reasoning content."""
         from mistral_common.tokens.tokenizers.base import SpecialTokens

         return SpecialTokens.end_think

@@ -11,8 +11,11 @@ import regex as re
 if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer import AnyTokenizer

-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              DeltaMessage, ResponsesRequest)
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    DeltaMessage,
+    ResponsesRequest,
+)
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser, ReasoningParserManager

@@ -33,8 +36,7 @@ class Indices:
         return self.end - self.start


-def string_overlap(a: str,
-                   b: str) -> tuple[Optional[Indices], Optional[Indices]]:
+def string_overlap(a: str, b: str) -> tuple[Optional[Indices], Optional[Indices]]:
     """
     Find the longest overlap where the end of string a matches the start
     of string b.
@@ -95,7 +97,7 @@ class Olmo3ReasoningBuffer:
                 self.state = Olmo3ReasoningState.REASONING
                 pretext, self.buffer = (
                     self.buffer[:start_think_idx],
-                    self.buffer[start_think_idx + len(self.think_start):],
+                    self.buffer[start_think_idx + len(self.think_start) :],
                 )
                 if start_think_idx > 0:
                     # this covers the case there's content before
@@ -108,7 +110,7 @@ class Olmo3ReasoningBuffer:
                 self.state = Olmo3ReasoningState.CONTENT
                 pretext, self.buffer = (
                     self.buffer[:end_think_idx],
-                    self.buffer[end_think_idx + len(self.think_end):],
+                    self.buffer[end_think_idx + len(self.think_end) :],
                 )
                 if end_think_idx > 0:
                     # this covers the case there's content before
@@ -153,12 +155,17 @@ class Olmo3ReasoningBuffer:
         _, overlap_think_end = string_overlap(delta_text, self.think_end)

         partial_overlap_start = overlap_think_start is not None and len(
-            overlap_think_start) < len(self.think_start)
+            overlap_think_start
+        ) < len(self.think_start)
         partial_overlap_end = overlap_think_end is not None and len(
-            overlap_think_end) < len(self.think_end)
+            overlap_think_end
+        ) < len(self.think_end)

-        if (partial_overlap_start and self.think_start in self.buffer
-                and not partial_overlap_end):
+        if (
+            partial_overlap_start
+            and self.think_start in self.buffer
+            and not partial_overlap_end
+        ):
             # we can only process the buffer if partial overlap
             # is the last part of think token (thus causing
             # text_buffer to contain the start of think token)
@@ -223,12 +230,15 @@ class Olmo3ReasoningParser(ReasoningParser):
         # notice that the first think is optional; this allows template to
         # work in cases when we hardcode a <think> at the beginning of the
         # reasoning template.
-        reasoning_expr = (rf"^(?:{self.think_start})?(?P<reasoning>.*?)" +
-                          rf"{self.think_end}(?P<content>.*)$")
+        reasoning_expr = (
+            rf"^(?:{self.think_start})?(?P<reasoning>.*?)"
+            + rf"{self.think_end}(?P<content>.*)$"
+        )
         self.reasoning_regex = re.compile(reasoning_expr, re.DOTALL)

-        self.buffer = Olmo3ReasoningBuffer(think_start=self.think_start,
-                                           think_end=self.think_end)
+        self.buffer = Olmo3ReasoningBuffer(
+            think_start=self.think_start, think_end=self.think_end
+        )

     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         text = self.model_tokenizer.decode(input_ids)
@@ -281,8 +291,7 @@ class Olmo3ReasoningParser(ReasoningParser):
         """Extract content using token ID sequence state machine"""

         delta_message = self.buffer.add_text(delta_text)
-        if (delta_message is None
-                and self.buffer.think_end in self.buffer.buffer):
+        if delta_message is None and self.buffer.think_end in self.buffer.buffer:
             # this is a bit hacky, but, because of how the buffer is
             # constructed, if the last delta_text contains characters that
             # marks the end of thinking tokens, then messages in the buffer

@@ -3,8 +3,7 @@

 from typing import Optional, Union

-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              ResponsesRequest)
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ResponsesRequest
 from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
 from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser

@@ -32,12 +31,11 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
         return "</think>"

     def extract_reasoning_content(
-        self, model_output: str, request: Union[ChatCompletionRequest,
-                                                ResponsesRequest]
+        self, model_output: str, request: Union[ChatCompletionRequest, ResponsesRequest]
     ) -> tuple[Optional[str], Optional[str]]:
         """
         Extract reasoning content from the model output.
-        
+
         Qwen3 has stricter requirements - it needs both start and end tokens
         to be present, unlike other models that work with just the end token.

@@ -50,15 +48,15 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
         """

         # Check if the model output contains both <think> and </think> tokens.
-        if (self.start_token not in model_output
-                or self.end_token not in model_output):
+        if self.start_token not in model_output or self.end_token not in model_output:
             return None, model_output

         # Check if the <think> is present in the model output, remove it
         # if it is present.
         model_output_parts = model_output.partition(self.start_token)
-        model_output = model_output_parts[2] if model_output_parts[
-            1] else model_output_parts[0]
+        model_output = (
+            model_output_parts[2] if model_output_parts[1] else model_output_parts[0]
+        )

         # Check if the model output contains the </think> tokens.
         # If the end token is not found, return the model output as is.

@@ -10,10 +10,10 @@ class SeedOSSReasoningParser(BaseThinkingReasoningParser):
     """
     Reasoning parser for SeedOSS model.

-    The SeedOSS model uses <seed:think>...</seed:think> tokens to 
-    denote reasoning content text. This parser extracts 
+    The SeedOSS model uses <seed:think>...</seed:think> tokens to
+    denote reasoning content text. This parser extracts
     the reasoning content from the model output.
-    Similar to DeepSeek R1, it supports cases 
+    Similar to DeepSeek R1, it supports cases
     where the model doesn't generate the start token.
     """

@@ -7,8 +7,7 @@ from typing import Optional, Union
 import regex as re
 from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
-                                              DeltaMessage)
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParser, ReasoningParserManager

@@ -20,7 +19,7 @@ class Step3ReasoningParser(ReasoningParser):
     """
     Reasoning parser for Step3 model.

-    The Step3 model uses </think> token to denote the end of reasoning 
+    The Step3 model uses </think> token to denote the end of reasoning
     text. This parser extracts all content before </think> as reasoning content.
     """

@@ -28,19 +27,20 @@ class Step3ReasoningParser(ReasoningParser):
         super().__init__(tokenizer, *args, **kwargs)
         self.think_end_token = "</think>"

-        self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}",
-                                          re.DOTALL)
+        self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}", re.DOTALL)

         if not self.model_tokenizer:
             raise ValueError(
                 "The model tokenizer must be passed to the ReasoningParser "
-                "constructor during construction.")
+                "constructor during construction."
+            )

         self.think_end_token_id = self.vocab.get(self.think_end_token)
         if self.think_end_token_id is None:
             raise RuntimeError(
                 "Step3 reasoning parser could not locate think end "
-                "token in the tokenizer!")
+                "token in the tokenizer!"
+            )

     def extract_reasoning_content_streaming(
         self,
@@ -60,17 +60,18 @@ class Step3ReasoningParser(ReasoningParser):
         - 'xyz' goes to content
         """
         # Skip single special token
-        if len(delta_token_ids
-               ) == 1 and delta_token_ids[0] == self.think_end_token_id:
+        if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
             return None

         if self.think_end_token_id in delta_token_ids:
             # </think> in delta, extract reasoning content and remaining content
             end_index = delta_text.find(self.think_end_token)
             reasoning_content = delta_text[:end_index]
-            content = delta_text[end_index + len(self.think_end_token):]
-            return DeltaMessage(reasoning_content=reasoning_content,
-                                content=content if content else None)
+            content = delta_text[end_index + len(self.think_end_token) :]
+            return DeltaMessage(
+                reasoning_content=reasoning_content,
+                content=content if content else None,
+            )
         elif self.think_end_token_id in previous_token_ids:
             # </think> already seen in previous text, everything is content
             return DeltaMessage(content=delta_text)
@@ -79,9 +80,8 @@ class Step3ReasoningParser(ReasoningParser):
             return DeltaMessage(reasoning_content=delta_text)

     def extract_reasoning_content(
-            self, model_output: str, request: ChatCompletionRequest
+        self, model_output: str, request: ChatCompletionRequest
     ) -> tuple[Optional[str], Optional[str]]:
-
         # Check if the model output contains the </think> token
         if self.think_end_token not in model_output:
             # If no </think> token, everything is reasoning content
@@ -92,7 +92,7 @@ class Step3ReasoningParser(ReasoningParser):
         reasoning_content = model_output[:end_index]

         # Content after </think> token
-        content = model_output[end_index + len(self.think_end_token):]
+        content = model_output[end_index + len(self.think_end_token) :]

         if len(content) == 0:
             content = None
@@ -106,4 +106,4 @@ class Step3ReasoningParser(ReasoningParser):
         if self.think_end_token_id not in input_ids[:-1]:
             return []
         else:
-            return input_ids[input_ids.index(self.think_end_token_id) + 1:]
+            return input_ids[input_ids.index(self.think_end_token_id) + 1 :]