[Parser] Migrate response api streaming to unified parser (#38755)
Signed-off-by: sfeng33 <4florafeng@gmail.com> Signed-off-by: Andrew Xia <axia@meta.com>
This commit is contained in:
@@ -628,6 +628,31 @@ def _identity_increment(event):
|
||||
return event
|
||||
|
||||
|
||||
def _mock_parser_with_reasoning(serving, delta_sequence: list[DeltaMessage]):
|
||||
"""Set up serving.parser so that it returns a mock parser instance
|
||||
with a reasoning parser that returns the given delta_sequence.
|
||||
|
||||
The mock has reasoning_parser set (truthy) but tool_parser as None,
|
||||
so the parser's parse_delta enters the reasoning-only branch.
|
||||
"""
|
||||
call_count = 0
|
||||
|
||||
def mock_parse_delta(**kwargs):
|
||||
nonlocal call_count
|
||||
if call_count >= len(delta_sequence):
|
||||
return None
|
||||
result = delta_sequence[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
mock_parser_instance = MagicMock()
|
||||
mock_parser_instance.reasoning_parser = MagicMock() # truthy
|
||||
mock_parser_instance.tool_parser = None
|
||||
mock_parser_instance.parse_delta = mock_parse_delta
|
||||
mock_parser_instance.is_reasoning_end = MagicMock(return_value=False)
|
||||
serving.parser = MagicMock(return_value=mock_parser_instance)
|
||||
|
||||
|
||||
class TestStreamingReasoningToContentTransition:
|
||||
"""Tests for _process_simple_streaming_events reasoning-to-content
|
||||
transition, specifically the fix for mixed deltas that carry both
|
||||
@@ -646,27 +671,13 @@ class TestStreamingReasoningToContentTransition:
|
||||
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
|
||||
serving = _make_serving_instance_with_reasoning()
|
||||
|
||||
# Sequence of DeltaMessages the mock reasoning parser will return
|
||||
# Sequence of DeltaMessages the mock orchestrator will return
|
||||
delta_sequence = [
|
||||
DeltaMessage(reasoning="thinking..."),
|
||||
DeltaMessage(reasoning=" end", content="hello"), # mixed delta
|
||||
DeltaMessage(content=" world"),
|
||||
]
|
||||
call_count = 0
|
||||
|
||||
def mock_extract_reasoning_streaming(**kwargs):
|
||||
nonlocal call_count
|
||||
result = delta_sequence[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
# Mock the reasoning parser on the serving instance
|
||||
mock_parser = MagicMock()
|
||||
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
|
||||
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
|
||||
serving.parser = MagicMock()
|
||||
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
|
||||
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
|
||||
_mock_parser_with_reasoning(serving, delta_sequence)
|
||||
# Create contexts for each streaming chunk
|
||||
contexts = [
|
||||
_make_simple_context_with_output("chunk1", [10]),
|
||||
@@ -734,20 +745,7 @@ class TestStreamingReasoningToContentTransition:
|
||||
DeltaMessage(reasoning="thinking"),
|
||||
DeltaMessage(content="answer"),
|
||||
]
|
||||
call_count = 0
|
||||
|
||||
def mock_extract_reasoning_streaming(**kwargs):
|
||||
nonlocal call_count
|
||||
result = delta_sequence[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
mock_parser = MagicMock()
|
||||
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
|
||||
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
|
||||
serving.parser = MagicMock()
|
||||
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
|
||||
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
|
||||
_mock_parser_with_reasoning(serving, delta_sequence)
|
||||
|
||||
contexts = [
|
||||
_make_simple_context_with_output("chunk1", [10]),
|
||||
@@ -809,20 +807,7 @@ class TestStreamingReasoningToContentTransition:
|
||||
DeltaMessage(reasoning="step 1"),
|
||||
DeltaMessage(reasoning=" step 2"),
|
||||
]
|
||||
call_count = 0
|
||||
|
||||
def mock_extract_reasoning_streaming(**kwargs):
|
||||
nonlocal call_count
|
||||
result = delta_sequence[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
mock_parser = MagicMock()
|
||||
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
|
||||
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
|
||||
serving.parser = MagicMock()
|
||||
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
|
||||
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
|
||||
_mock_parser_with_reasoning(serving, delta_sequence)
|
||||
|
||||
contexts = [
|
||||
_make_simple_context_with_output("chunk1", [10]),
|
||||
|
||||
@@ -1339,101 +1339,31 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
current_content_index = 0
|
||||
current_output_index = 0
|
||||
current_item_id = ""
|
||||
reasoning_parser = None
|
||||
if self.parser and self.parser.reasoning_parser_cls:
|
||||
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
|
||||
tool_parser = None
|
||||
if self.parser and self.parser.tool_parser_cls:
|
||||
tool_parser = self.parser.tool_parser_cls(tokenizer, request.tools)
|
||||
reasoning_ended = False
|
||||
tool_call_text_started = False
|
||||
previous_text = ""
|
||||
previous_token_ids: list[int] = []
|
||||
prompt_is_reasoning_end = None
|
||||
parser = self.parser(tokenizer, request.tools) if self.parser else None
|
||||
first_delta_sent = False
|
||||
previous_delta_messages: list[DeltaMessage] = []
|
||||
async for ctx in result_generator:
|
||||
assert isinstance(ctx, SimpleContext)
|
||||
if ctx.last_output is None:
|
||||
continue
|
||||
if reasoning_parser and prompt_is_reasoning_end is None:
|
||||
prompt_is_reasoning_end = reasoning_parser.is_reasoning_end(
|
||||
ctx.last_output.prompt_token_ids
|
||||
)
|
||||
if ctx.last_output.outputs:
|
||||
output = ctx.last_output.outputs[0]
|
||||
# finish_reason='error' indicates a retryable error
|
||||
self._raise_if_error(output.finish_reason, request.request_id)
|
||||
delta_text = output.text
|
||||
delta_token_ids = as_list(output.token_ids)
|
||||
current_text = previous_text + delta_text
|
||||
current_token_ids = previous_token_ids + delta_token_ids
|
||||
|
||||
if reasoning_parser and tool_parser:
|
||||
if prompt_is_reasoning_end:
|
||||
reasoning_ended = True
|
||||
if not reasoning_ended:
|
||||
delta_message = reasoning_parser.extract_reasoning_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
)
|
||||
if reasoning_parser.is_reasoning_end(delta_token_ids):
|
||||
reasoning_ended = True
|
||||
current_token_ids = reasoning_parser.extract_content_ids(
|
||||
delta_token_ids
|
||||
)
|
||||
if delta_message and delta_message.content:
|
||||
current_text = delta_message.content
|
||||
delta_message.content = None
|
||||
else:
|
||||
current_text = ""
|
||||
|
||||
if reasoning_ended:
|
||||
if not tool_call_text_started:
|
||||
tool_call_text_started = True
|
||||
previous_text = ""
|
||||
previous_token_ids = []
|
||||
delta_text = current_text
|
||||
delta_token_ids = current_token_ids
|
||||
|
||||
delta_message = tool_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
request=request, # type: ignore[arg-type]
|
||||
)
|
||||
elif reasoning_parser:
|
||||
delta_message = reasoning_parser.extract_reasoning_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
if parser:
|
||||
delta_message = parser.parse_delta(
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
)
|
||||
elif tool_parser:
|
||||
delta_message = tool_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
request=request, # type: ignore[arg-type]
|
||||
request=request,
|
||||
prompt_token_ids=ctx.last_output.prompt_token_ids,
|
||||
)
|
||||
else:
|
||||
delta_message = DeltaMessage(
|
||||
content=output.text,
|
||||
)
|
||||
previous_text = current_text
|
||||
previous_token_ids = current_token_ids
|
||||
if not delta_message:
|
||||
continue
|
||||
if not first_delta_sent:
|
||||
|
||||
@@ -5,6 +5,7 @@ import contextlib
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
|
||||
from openai.types.responses import (
|
||||
@@ -43,6 +44,17 @@ from vllm.utils import random_uuid
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class StreamState:
    """Mutable state for ``Parser.parse_delta()``. One per stream."""

    # True once the reasoning section is considered finished, either because
    # the prompt token ids or a delta's token ids satisfied is_reasoning_end.
    reasoning_ended: bool = False
    # True once the first post-reasoning delta has been handed to the tool
    # parser; at that point the previous text/token history is reset.
    tool_call_text_started: bool = False
    # True once the prompt token ids have been inspected for reasoning end,
    # so that check happens only once per stream.
    prompt_reasoning_checked: bool = False
    # Text accumulated from earlier deltas (the "previous_text" handed to
    # the reasoning/tool extractors on the next call).
    previous_text: str = ""
    # Token ids accumulated from earlier deltas; default_factory gives every
    # instance its own list.
    previous_token_ids: list[int] = field(default_factory=list)
|
||||
|
||||
|
||||
class Parser:
|
||||
"""
|
||||
Abstract Parser class that unifies ReasoningParser and ToolParser into
|
||||
@@ -80,6 +92,7 @@ class Parser:
|
||||
self.model_tokenizer = tokenizer
|
||||
self._reasoning_parser: ReasoningParser | None = None
|
||||
self._tool_parser: ToolParser | None = None
|
||||
self._stream_state = StreamState()
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> dict[str, int]:
|
||||
@@ -291,6 +304,18 @@ class Parser:
|
||||
A DeltaMessage with tool_calls field, or None.
|
||||
"""
|
||||
|
||||
@abstractmethod
def parse_delta(
    self,
    delta_text: str,
    delta_token_ids: list[int],
    request: ChatCompletionRequest | ResponsesRequest,
    prompt_token_ids: list[int] | None = None,
) -> DeltaMessage | None:
    """Parse a single streaming delta, orchestrating reasoning then
    tool call extraction via internal stream state.

    Args:
        delta_text: Text of the newly generated chunk only (not the
            accumulated output).
        delta_token_ids: Token ids of the newly generated chunk only.
        request: The request being served; implementations may forward it
            to tool-call extraction.
        prompt_token_ids: Optional prompt token ids for the stream.

    Returns:
        A ``DeltaMessage`` for this chunk, or ``None`` when there is
        nothing to emit.
    """
|
||||
|
||||
|
||||
class DelegatingParser(Parser):
|
||||
"""
|
||||
@@ -524,6 +549,100 @@ class DelegatingParser(Parser):
|
||||
request,
|
||||
)
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    """Ask the wrapped reasoning parser whether reasoning has ended.

    Args:
        input_ids: Token ids to inspect.

    Returns:
        ``False`` when no reasoning parser is configured; otherwise the
        result of the reasoning parser's ``is_reasoning_end(input_ids)``.
    """
    if self._reasoning_parser is None:
        return False
    return self._reasoning_parser.is_reasoning_end(input_ids)
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """Return the content token ids as determined by the reasoning parser.

    Args:
        input_ids: Token ids to filter.

    Returns:
        ``input_ids`` itself (the same list object) when no reasoning
        parser is configured; otherwise the result of the reasoning
        parser's ``extract_content_ids(input_ids)``.
    """
    if self._reasoning_parser is None:
        return input_ids
    return self._reasoning_parser.extract_content_ids(input_ids)
|
||||
|
||||
def parse_delta(
    self,
    delta_text: str,
    delta_token_ids: list[int],
    request: ChatCompletionRequest | ResponsesRequest,
    prompt_token_ids: list[int] | None = None,
) -> DeltaMessage | None:
    """Parse one streaming delta using the per-stream ``StreamState``.

    Depending on which sub-parsers are configured:

    * reasoning + tool parser: deltas go to the reasoning extractor until
      reasoning ends, then to the tool-call extractor;
    * reasoning parser only: deltas go to the reasoning extractor;
    * tool parser only: deltas go to the tool-call extractor;
    * neither: the delta text is returned as plain content.

    Args:
        delta_text: Text of the new chunk only.
        delta_token_ids: Token ids of the new chunk only.
        request: Request forwarded to tool-call extraction.
        prompt_token_ids: Prompt token ids. The first time a non-``None``
            value is seen, it is checked once for an already-finished
            reasoning section.

    Returns:
        Whatever the selected extractor returned (possibly ``None``), or a
        content-only ``DeltaMessage`` when no sub-parser is configured.
    """
    state = self._stream_state

    # One-time check: the prompt itself may already close the reasoning
    # section, in which case reasoning extraction is skipped entirely.
    if not state.prompt_reasoning_checked and prompt_token_ids is not None:
        state.prompt_reasoning_checked = True
        if self.is_reasoning_end(prompt_token_ids):
            state.reasoning_ended = True

    current_text = state.previous_text + delta_text
    current_token_ids = state.previous_token_ids + delta_token_ids

    delta_message: DeltaMessage | None = None

    if self._reasoning_parser and self._tool_parser:
        if not state.reasoning_ended:
            delta_message = self.extract_reasoning_streaming(
                previous_text=state.previous_text,
                current_text=current_text,
                delta_text=delta_text,
                previous_token_ids=state.previous_token_ids,
                current_token_ids=current_token_ids,
                delta_token_ids=delta_token_ids,
            )
            if self.is_reasoning_end(delta_token_ids):
                state.reasoning_ended = True
                # Keep only what follows the reasoning section so the tool
                # parser starts from a clean post-reasoning view.
                current_token_ids = self.extract_content_ids(delta_token_ids)
                if delta_message and delta_message.content:
                    # Content carried by the transition delta becomes the
                    # tool parser's initial text instead of being emitted
                    # directly.
                    current_text = delta_message.content
                    delta_message.content = None
                else:
                    current_text = ""

        if state.reasoning_ended:
            if not state.tool_call_text_started:
                # First handoff to the tool parser: drop the reasoning
                # history and present the remaining text as the delta.
                state.tool_call_text_started = True
                state.previous_text = ""
                state.previous_token_ids = []
                delta_text = current_text
                delta_token_ids = current_token_ids

            # NOTE: this replaces any message produced by the reasoning
            # extractor above for the same delta.
            delta_message = self.extract_tool_calls_streaming(
                previous_text=state.previous_text,
                current_text=current_text,
                delta_text=delta_text,
                previous_token_ids=state.previous_token_ids,
                current_token_ids=current_token_ids,
                delta_token_ids=delta_token_ids,
                request=request,  # type: ignore[arg-type]
            )

    elif self._reasoning_parser:
        delta_message = self.extract_reasoning_streaming(
            previous_text=state.previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=state.previous_token_ids,
            current_token_ids=current_token_ids,
            delta_token_ids=delta_token_ids,
        )

    elif self._tool_parser:
        delta_message = self.extract_tool_calls_streaming(
            previous_text=state.previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=state.previous_token_ids,
            current_token_ids=current_token_ids,
            delta_token_ids=delta_token_ids,
            request=request,  # type: ignore[arg-type]
        )

    else:
        delta_message = DeltaMessage(content=delta_text)

    # Carry the (possibly rewritten) accumulated view into the next call.
    state.previous_text = current_text
    state.previous_token_ids = current_token_ids
    return delta_message
|
||||
|
||||
|
||||
class _WrappedParser(DelegatingParser):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user