[Parser] Migrate response api streaming to unified parser (#38755)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Signed-off-by: Andrew Xia <axia@meta.com>
This commit is contained in:
Flora Feng
2026-04-07 22:09:00 -04:00
committed by GitHub
parent 9ea7d670d8
commit 927975ead8
3 changed files with 153 additions and 119 deletions

View File

@@ -628,6 +628,31 @@ def _identity_increment(event):
return event
def _mock_parser_with_reasoning(serving, delta_sequence: list[DeltaMessage]):
    """Install a fake parser factory on ``serving`` for streaming tests.

    After this call ``serving.parser`` is a callable mock that always returns
    the same fake parser instance. That instance has a truthy
    ``reasoning_parser`` and ``tool_parser`` set to ``None`` (the shape of a
    reasoning-only parser), an ``is_reasoning_end`` that always returns
    ``False``, and a ``parse_delta`` that hands out the entries of
    ``delta_sequence`` one per call, in order, returning ``None`` once the
    sequence is used up.
    """
    cursor = 0

    def mock_parse_delta(**kwargs):
        # Keyword arguments are accepted and ignored; only the call count
        # decides which canned delta is returned.
        nonlocal cursor
        if cursor < len(delta_sequence):
            delta = delta_sequence[cursor]
            cursor += 1
            return delta
        return None

    parser_instance = MagicMock(
        reasoning_parser=MagicMock(),  # truthy
        tool_parser=None,
        is_reasoning_end=MagicMock(return_value=False),
    )
    parser_instance.parse_delta = mock_parse_delta

    # serving.parser is called like a class/factory, so the instance is
    # exposed through the mock's return value.
    serving.parser = MagicMock(return_value=parser_instance)
class TestStreamingReasoningToContentTransition:
"""Tests for _process_simple_streaming_events reasoning-to-content
transition, specifically the fix for mixed deltas that carry both
@@ -646,27 +671,13 @@ class TestStreamingReasoningToContentTransition:
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
serving = _make_serving_instance_with_reasoning()
# Sequence of DeltaMessages the mock reasoning parser will return
# Sequence of DeltaMessages the mock orchestrator will return
delta_sequence = [
DeltaMessage(reasoning="thinking..."),
DeltaMessage(reasoning=" end", content="hello"), # mixed delta
DeltaMessage(content=" world"),
]
call_count = 0
def mock_extract_reasoning_streaming(**kwargs):
nonlocal call_count
result = delta_sequence[call_count]
call_count += 1
return result
# Mock the reasoning parser on the serving instance
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
_mock_parser_with_reasoning(serving, delta_sequence)
# Create contexts for each streaming chunk
contexts = [
_make_simple_context_with_output("chunk1", [10]),
@@ -734,20 +745,7 @@ class TestStreamingReasoningToContentTransition:
DeltaMessage(reasoning="thinking"),
DeltaMessage(content="answer"),
]
call_count = 0
def mock_extract_reasoning_streaming(**kwargs):
nonlocal call_count
result = delta_sequence[call_count]
call_count += 1
return result
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
_mock_parser_with_reasoning(serving, delta_sequence)
contexts = [
_make_simple_context_with_output("chunk1", [10]),
@@ -809,20 +807,7 @@ class TestStreamingReasoningToContentTransition:
DeltaMessage(reasoning="step 1"),
DeltaMessage(reasoning=" step 2"),
]
call_count = 0
def mock_extract_reasoning_streaming(**kwargs):
nonlocal call_count
result = delta_sequence[call_count]
call_count += 1
return result
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
_mock_parser_with_reasoning(serving, delta_sequence)
contexts = [
_make_simple_context_with_output("chunk1", [10]),

View File

@@ -1339,101 +1339,31 @@ class OpenAIServingResponses(OpenAIServing):
current_content_index = 0
current_output_index = 0
current_item_id = ""
reasoning_parser = None
if self.parser and self.parser.reasoning_parser_cls:
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
tool_parser = None
if self.parser and self.parser.tool_parser_cls:
tool_parser = self.parser.tool_parser_cls(tokenizer, request.tools)
reasoning_ended = False
tool_call_text_started = False
previous_text = ""
previous_token_ids: list[int] = []
prompt_is_reasoning_end = None
parser = self.parser(tokenizer, request.tools) if self.parser else None
first_delta_sent = False
previous_delta_messages: list[DeltaMessage] = []
async for ctx in result_generator:
assert isinstance(ctx, SimpleContext)
if ctx.last_output is None:
continue
if reasoning_parser and prompt_is_reasoning_end is None:
prompt_is_reasoning_end = reasoning_parser.is_reasoning_end(
ctx.last_output.prompt_token_ids
)
if ctx.last_output.outputs:
output = ctx.last_output.outputs[0]
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request.request_id)
delta_text = output.text
delta_token_ids = as_list(output.token_ids)
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + delta_token_ids
if reasoning_parser and tool_parser:
if prompt_is_reasoning_end:
reasoning_ended = True
if not reasoning_ended:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
)
if reasoning_parser.is_reasoning_end(delta_token_ids):
reasoning_ended = True
current_token_ids = reasoning_parser.extract_content_ids(
delta_token_ids
)
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
if reasoning_ended:
if not tool_call_text_started:
tool_call_text_started = True
previous_text = ""
previous_token_ids = []
delta_text = current_text
delta_token_ids = current_token_ids
delta_message = tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request, # type: ignore[arg-type]
)
elif reasoning_parser:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text,
current_text=current_text,
if parser:
delta_message = parser.parse_delta(
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
)
elif tool_parser:
delta_message = tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request, # type: ignore[arg-type]
request=request,
prompt_token_ids=ctx.last_output.prompt_token_ids,
)
else:
delta_message = DeltaMessage(
content=output.text,
)
previous_text = current_text
previous_token_ids = current_token_ids
if not delta_message:
continue
if not first_delta_sent:

View File

@@ -5,6 +5,7 @@ import contextlib
import json
from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass, field
from functools import cached_property
from openai.types.responses import (
@@ -43,6 +44,17 @@ from vllm.utils import random_uuid
logger = init_logger(__name__)
@dataclass
class StreamState:
    """Per-stream bookkeeping read and updated by ``Parser.parse_delta()``.

    Each stream needs its own instance; the fields carry what has already
    been seen between successive deltas.
    """

    # Set once reasoning is considered finished (either found in the prompt
    # or detected in a streamed delta).
    reasoning_ended: bool = False
    # Set the first time tool-call extraction runs after reasoning ended.
    tool_call_text_started: bool = False
    # Set after the prompt token ids have been inspected for a reasoning end.
    prompt_reasoning_checked: bool = False
    # Accumulated text handed to the underlying parsers as "previous" text.
    previous_text: str = ""
    # Accumulated token ids matching ``previous_text``; a fresh list per
    # instance so streams never share state.
    previous_token_ids: list[int] = field(default_factory=list)
class Parser:
"""
Abstract Parser class that unifies ReasoningParser and ToolParser into
@@ -80,6 +92,7 @@ class Parser:
self.model_tokenizer = tokenizer
self._reasoning_parser: ReasoningParser | None = None
self._tool_parser: ToolParser | None = None
self._stream_state = StreamState()
@cached_property
def vocab(self) -> dict[str, int]:
@@ -291,6 +304,18 @@ class Parser:
A DeltaMessage with tool_calls field, or None.
"""
@abstractmethod
def parse_delta(
    self,
    delta_text: str,
    delta_token_ids: list[int],
    request: ChatCompletionRequest | ResponsesRequest,
    prompt_token_ids: list[int] | None = None,
) -> DeltaMessage | None:
    """Parse a single streaming delta, orchestrating reasoning then
    tool call extraction via internal stream state.

    Args:
        delta_text: Text of the newly generated chunk only (not the
            accumulated text).
        delta_token_ids: Token ids of the newly generated chunk only.
        request: The request being served.
        prompt_token_ids: Optional token ids of the prompt.

    Returns:
        A ``DeltaMessage`` for this chunk, or ``None`` when there is
        nothing to emit.
    """
class DelegatingParser(Parser):
"""
@@ -524,6 +549,100 @@ class DelegatingParser(Parser):
request,
)
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Report whether ``input_ids`` mark the end of reasoning.

    Delegates to the wrapped reasoning parser; when no reasoning parser
    is configured the answer is always ``False``.
    """
    reasoning_parser = self._reasoning_parser
    return (
        reasoning_parser.is_reasoning_end(input_ids)
        if reasoning_parser is not None
        else False
    )
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """Return the content token ids contained in ``input_ids``.

    Delegates to the wrapped reasoning parser; when no reasoning parser
    is configured, ``input_ids`` is handed back unchanged.
    """
    reasoning_parser = self._reasoning_parser
    if reasoning_parser is not None:
        return reasoning_parser.extract_content_ids(input_ids)
    return input_ids
def parse_delta(
    self,
    delta_text: str,
    delta_token_ids: list[int],
    request: ChatCompletionRequest | ResponsesRequest,
    prompt_token_ids: list[int] | None = None,
) -> DeltaMessage | None:
    """Parse one streaming delta using the per-stream state on ``self``.

    Depending on which sub-parsers are configured, the delta is routed to
    reasoning extraction, tool-call extraction, both in sequence, or is
    passed through as plain content. The accumulated text and token ids
    are kept in ``self._stream_state`` so callers only supply the new
    chunk.

    Args:
        delta_text: Text of the new chunk only.
        delta_token_ids: Token ids of the new chunk only.
        request: The request being served; forwarded to tool-call
            extraction.
        prompt_token_ids: Optional prompt token ids. The first time a
            non-``None`` value is supplied it is checked once for a
            reasoning end.

    Returns:
        The ``DeltaMessage`` produced by the selected extraction step, or
        ``None`` if that step produced nothing.
    """
    state = self._stream_state

    # One-time check: if the prompt itself already ends reasoning, start
    # the stream with reasoning marked as finished.
    if not state.prompt_reasoning_checked and prompt_token_ids is not None:
        state.prompt_reasoning_checked = True
        if self.is_reasoning_end(prompt_token_ids):
            state.reasoning_ended = True

    # Rebuild the cumulative view the underlying parsers expect.
    current_text = state.previous_text + delta_text
    current_token_ids = state.previous_token_ids + delta_token_ids
    delta_message: DeltaMessage | None = None

    if self._reasoning_parser and self._tool_parser:
        if not state.reasoning_ended:
            delta_message = self.extract_reasoning_streaming(
                previous_text=state.previous_text,
                current_text=current_text,
                delta_text=delta_text,
                previous_token_ids=state.previous_token_ids,
                current_token_ids=current_token_ids,
                delta_token_ids=delta_token_ids,
            )
            # Reasoning finished inside this chunk: keep only the content
            # part of the chunk as the starting point for tool parsing.
            if self.is_reasoning_end(delta_token_ids):
                state.reasoning_ended = True
                current_token_ids = self.extract_content_ids(delta_token_ids)
                if delta_message and delta_message.content:
                    # Move the content out of the reasoning delta so it is
                    # fed to the tool parser below instead.
                    current_text = delta_message.content
                    delta_message.content = None
                else:
                    current_text = ""

        if state.reasoning_ended:
            if not state.tool_call_text_started:
                # First tool-parsing call: restart the history so the tool
                # parser sees the post-reasoning text as a fresh stream.
                state.tool_call_text_started = True
                state.previous_text = ""
                state.previous_token_ids = []
                delta_text = current_text
                delta_token_ids = current_token_ids

            # NOTE(review): this assignment replaces any delta_message
            # produced by the reasoning step above in the same call, so
            # reasoning text from the transition chunk is not returned —
            # confirm this is intended.
            delta_message = self.extract_tool_calls_streaming(
                previous_text=state.previous_text,
                current_text=current_text,
                delta_text=delta_text,
                previous_token_ids=state.previous_token_ids,
                current_token_ids=current_token_ids,
                delta_token_ids=delta_token_ids,
                request=request,  # type: ignore[arg-type]
            )
    elif self._reasoning_parser:
        # Reasoning-only: every chunk goes through reasoning extraction.
        delta_message = self.extract_reasoning_streaming(
            previous_text=state.previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=state.previous_token_ids,
            current_token_ids=current_token_ids,
            delta_token_ids=delta_token_ids,
        )
    elif self._tool_parser:
        # Tool-only: every chunk goes through tool-call extraction.
        delta_message = self.extract_tool_calls_streaming(
            previous_text=state.previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=state.previous_token_ids,
            current_token_ids=current_token_ids,
            delta_token_ids=delta_token_ids,
            request=request,  # type: ignore[arg-type]
        )
    else:
        # No sub-parsers configured: pass the text through as content.
        delta_message = DeltaMessage(content=delta_text)

    # Remember the cumulative view for the next call.
    state.previous_text = current_text
    state.previous_token_ids = current_token_ids
    return delta_message
class _WrappedParser(DelegatingParser):
"""