diff --git a/tests/entrypoints/openai/responses/test_serving_responses.py b/tests/entrypoints/openai/responses/test_serving_responses.py index 157f7f12f..39429cb9b 100644 --- a/tests/entrypoints/openai/responses/test_serving_responses.py +++ b/tests/entrypoints/openai/responses/test_serving_responses.py @@ -628,6 +628,31 @@ def _identity_increment(event): return event +def _mock_parser_with_reasoning(serving, delta_sequence: list[DeltaMessage]): + """Set up serving.parser so that it returns a mock parser instance + with a reasoning parser that returns the given delta_sequence. + + The mock has reasoning_parser set (truthy) but tool_parser as None, + so the parser's parse_delta enters the reasoning-only branch. + """ + call_count = 0 + + def mock_parse_delta(**kwargs): + nonlocal call_count + if call_count >= len(delta_sequence): + return None + result = delta_sequence[call_count] + call_count += 1 + return result + + mock_parser_instance = MagicMock() + mock_parser_instance.reasoning_parser = MagicMock() # truthy + mock_parser_instance.tool_parser = None + mock_parser_instance.parse_delta = mock_parse_delta + mock_parser_instance.is_reasoning_end = MagicMock(return_value=False) + serving.parser = MagicMock(return_value=mock_parser_instance) + + class TestStreamingReasoningToContentTransition: """Tests for _process_simple_streaming_events reasoning-to-content transition, specifically the fix for mixed deltas that carry both @@ -646,27 +671,13 @@ class TestStreamingReasoningToContentTransition: monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False) serving = _make_serving_instance_with_reasoning() - # Sequence of DeltaMessages the mock reasoning parser will return + # Sequence of DeltaMessages the mock orchestrator will return delta_sequence = [ DeltaMessage(reasoning="thinking..."), DeltaMessage(reasoning=" end", content="hello"), # mixed delta DeltaMessage(content=" world"), ] - call_count = 0 - - def mock_extract_reasoning_streaming(**kwargs): - nonlocal call_count - result = delta_sequence[call_count] - call_count += 1 - return result - - # Mock the reasoning parser on the serving instance - mock_parser = MagicMock() - mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming - mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming - serving.parser = MagicMock() - serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser) - serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser) + _mock_parser_with_reasoning(serving, delta_sequence) # Create contexts for each streaming chunk contexts = [ _make_simple_context_with_output("chunk1", [10]), @@ -734,20 +745,7 @@ class TestStreamingReasoningToContentTransition: DeltaMessage(reasoning="thinking"), DeltaMessage(content="answer"), ] - call_count = 0 - - def mock_extract_reasoning_streaming(**kwargs): - nonlocal call_count - result = delta_sequence[call_count] - call_count += 1 - return result - - mock_parser = MagicMock() - mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming - mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming - serving.parser = MagicMock() - serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser) - serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser) + _mock_parser_with_reasoning(serving, delta_sequence) contexts = [ _make_simple_context_with_output("chunk1", [10]), @@ -809,20 +807,7 @@ class TestStreamingReasoningToContentTransition: DeltaMessage(reasoning="step 1"), DeltaMessage(reasoning=" step 2"), ] - call_count = 0 - - def mock_extract_reasoning_streaming(**kwargs): - nonlocal call_count - result = delta_sequence[call_count] - call_count += 1 - return result - - mock_parser = MagicMock() - mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming - mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming - serving.parser = MagicMock() - serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser) - serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser) + _mock_parser_with_reasoning(serving, delta_sequence) contexts = [ _make_simple_context_with_output("chunk1", [10]), diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index d11a78124..086101c28 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -1339,101 +1339,31 @@ class OpenAIServingResponses(OpenAIServing): current_content_index = 0 current_output_index = 0 current_item_id = "" - reasoning_parser = None - if self.parser and self.parser.reasoning_parser_cls: - reasoning_parser = self.parser.reasoning_parser_cls(tokenizer) - tool_parser = None - if self.parser and self.parser.tool_parser_cls: - tool_parser = self.parser.tool_parser_cls(tokenizer, request.tools) - reasoning_ended = False - tool_call_text_started = False - previous_text = "" - previous_token_ids: list[int] = [] - prompt_is_reasoning_end = None + parser = self.parser(tokenizer, request.tools) if self.parser else None first_delta_sent = False previous_delta_messages: list[DeltaMessage] = [] async for ctx in result_generator: assert isinstance(ctx, SimpleContext) if ctx.last_output is None: continue - if reasoning_parser and prompt_is_reasoning_end is None: - prompt_is_reasoning_end = reasoning_parser.is_reasoning_end( - ctx.last_output.prompt_token_ids - ) if ctx.last_output.outputs: output = ctx.last_output.outputs[0] # finish_reason='error' indicates a retryable error self._raise_if_error(output.finish_reason, request.request_id) delta_text = output.text delta_token_ids = as_list(output.token_ids) - current_text = previous_text + delta_text - current_token_ids = previous_token_ids + delta_token_ids - if reasoning_parser and tool_parser: - if prompt_is_reasoning_end: - reasoning_ended = True - if not reasoning_ended: - delta_message = reasoning_parser.extract_reasoning_streaming( - previous_text=previous_text, - current_text=current_text, - delta_text=delta_text, - previous_token_ids=previous_token_ids, - current_token_ids=current_token_ids, - delta_token_ids=delta_token_ids, - ) - if reasoning_parser.is_reasoning_end(delta_token_ids): - reasoning_ended = True - current_token_ids = reasoning_parser.extract_content_ids( - delta_token_ids - ) - if delta_message and delta_message.content: - current_text = delta_message.content - delta_message.content = None - else: - current_text = "" - - if reasoning_ended: - if not tool_call_text_started: - tool_call_text_started = True - previous_text = "" - previous_token_ids = [] - delta_text = current_text - delta_token_ids = current_token_ids - - delta_message = tool_parser.extract_tool_calls_streaming( - previous_text=previous_text, - current_text=current_text, - delta_text=delta_text, - previous_token_ids=previous_token_ids, - current_token_ids=current_token_ids, - delta_token_ids=delta_token_ids, - request=request, # type: ignore[arg-type] - ) - elif reasoning_parser: - delta_message = reasoning_parser.extract_reasoning_streaming( - previous_text=previous_text, - current_text=current_text, + if parser: + delta_message = parser.parse_delta( delta_text=delta_text, - previous_token_ids=previous_token_ids, - current_token_ids=current_token_ids, delta_token_ids=delta_token_ids, - ) - elif tool_parser: - delta_message = tool_parser.extract_tool_calls_streaming( - previous_text=previous_text, - current_text=current_text, - delta_text=delta_text, - previous_token_ids=previous_token_ids, - current_token_ids=current_token_ids, - delta_token_ids=delta_token_ids, - request=request, # type: ignore[arg-type] + request=request, + prompt_token_ids=ctx.last_output.prompt_token_ids, ) else: delta_message = DeltaMessage( content=output.text, ) - previous_text = current_text - previous_token_ids = current_token_ids if not delta_message: continue if not first_delta_sent: diff --git a/vllm/parser/abstract_parser.py b/vllm/parser/abstract_parser.py index 7e8c236aa..30b4c4ebe 100644 --- a/vllm/parser/abstract_parser.py +++ b/vllm/parser/abstract_parser.py @@ -5,6 +5,7 @@ import contextlib import json from abc import abstractmethod from collections.abc import Sequence +from dataclasses import dataclass, field from functools import cached_property from openai.types.responses import ( @@ -43,6 +44,17 @@ from vllm.utils import random_uuid logger = init_logger(__name__) +@dataclass +class StreamState: + """Mutable state for ``Parser.parse_delta()``. One per stream.""" + + reasoning_ended: bool = False + tool_call_text_started: bool = False + prompt_reasoning_checked: bool = False + previous_text: str = "" + previous_token_ids: list[int] = field(default_factory=list) + + class Parser: """ Abstract Parser class that unifies ReasoningParser and ToolParser into @@ -80,6 +92,7 @@ class Parser: self.model_tokenizer = tokenizer self._reasoning_parser: ReasoningParser | None = None self._tool_parser: ToolParser | None = None + self._stream_state = StreamState() @cached_property def vocab(self) -> dict[str, int]: @@ -291,6 +304,18 @@ class Parser: A DeltaMessage with tool_calls field, or None. """ + @abstractmethod + def parse_delta( + self, + delta_text: str, + delta_token_ids: list[int], + request: ChatCompletionRequest | ResponsesRequest, + prompt_token_ids: list[int] | None = None, + ) -> DeltaMessage | None: + """Parse a single streaming delta, orchestrating reasoning then + tool call extraction via internal stream state. + """ + class DelegatingParser(Parser): """ @@ -524,6 +549,100 @@ class DelegatingParser(Parser): request, ) + def is_reasoning_end(self, input_ids: list[int]) -> bool: + if self._reasoning_parser is None: + return False + return self._reasoning_parser.is_reasoning_end(input_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + if self._reasoning_parser is None: + return input_ids + return self._reasoning_parser.extract_content_ids(input_ids) + + def parse_delta( + self, + delta_text: str, + delta_token_ids: list[int], + request: ChatCompletionRequest | ResponsesRequest, + prompt_token_ids: list[int] | None = None, + ) -> DeltaMessage | None: + state = self._stream_state + + if not state.prompt_reasoning_checked and prompt_token_ids is not None: + state.prompt_reasoning_checked = True + if self.is_reasoning_end(prompt_token_ids): + state.reasoning_ended = True + + current_text = state.previous_text + delta_text + current_token_ids = state.previous_token_ids + delta_token_ids + + delta_message: DeltaMessage | None = None + + if self._reasoning_parser and self._tool_parser: + if not state.reasoning_ended: + delta_message = self.extract_reasoning_streaming( + previous_text=state.previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=state.previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + ) + if self.is_reasoning_end(delta_token_ids): + state.reasoning_ended = True + current_token_ids = self.extract_content_ids(delta_token_ids) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + if state.reasoning_ended: + if not state.tool_call_text_started: + state.tool_call_text_started = True + state.previous_text = "" + state.previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = self.extract_tool_calls_streaming( + previous_text=state.previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=state.previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, # type: ignore[arg-type] + ) + + elif self._reasoning_parser: + delta_message = self.extract_reasoning_streaming( + previous_text=state.previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=state.previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + ) + + elif self._tool_parser: + delta_message = self.extract_tool_calls_streaming( + previous_text=state.previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=state.previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, # type: ignore[arg-type] + ) + + else: + delta_message = DeltaMessage(content=delta_text) + + state.previous_text = current_text + state.previous_token_ids = current_token_ids + return delta_message + class _WrappedParser(DelegatingParser): """