[Bug] Fix a corner case in _process_simple_streaming_events (#34754)

Signed-off-by: Shiyan Deng <dsy842974287@meta.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
Author: Shiyan Deng
Date: 2026-03-05 20:57:32 -08:00
Committed by: GitHub
Parent: 6dd302653f
Commit: 8e87cc57f1
2 changed files with 334 additions and 0 deletions


@@ -6,6 +6,13 @@ from unittest.mock import MagicMock
import pytest
import pytest_asyncio
from openai.types.responses import (
ResponseOutputItemDoneEvent,
ResponseReasoningItem,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseTextDeltaEvent,
)
from openai.types.responses.tool import (
CodeInterpreterContainerCodeInterpreterToolAuto,
LocalShell,
@@ -16,6 +23,7 @@ from openai.types.responses.tool import (
import vllm.envs as envs
from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
ErrorResponse,
RequestResponseMetadata,
)
@@ -554,3 +562,309 @@ class TestHarmonyPreambleStreaming:
type_names = [e.type for e in events]
assert "response.output_text.done" not in type_names
def _make_simple_context_with_output(text, token_ids):
"""Create a SimpleContext with a RequestOutput containing the given text."""
ctx = SimpleContext()
completion = CompletionOutput(
index=0,
text=text,
token_ids=token_ids,
cumulative_logprob=0.0,
logprobs=None,
finish_reason=None,
stop_reason=None,
)
req_output = RequestOutput(
request_id="req",
prompt="hi",
prompt_token_ids=[7, 8],
prompt_logprobs=None,
outputs=[completion],
finished=False,
num_cached_tokens=0,
)
ctx.append_output(req_output)
return ctx


def _make_serving_instance_with_reasoning():
"""Create an OpenAIServingResponses with a mocked reasoning parser."""
engine_client = MagicMock()
model_config = MagicMock()
model_config.max_model_len = 100
model_config.hf_config.model_type = "test"
model_config.hf_text_config = MagicMock()
model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock()
serving = OpenAIServingResponses(
engine_client=engine_client,
models=models,
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
reasoning_parser="qwen3",
)
return serving


def _identity_increment(event):
"""Simple identity callable for _increment_sequence_number_and_return."""
seq = getattr(_identity_increment, "_counter", 0)
if hasattr(event, "sequence_number"):
event.sequence_number = seq
_identity_increment._counter = seq + 1 # type: ignore
return event
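
# Usage sketch (assumption about the caller): the streaming loop is
# expected to stamp every event via
#     event = _increment_sequence_number_and_return(event)
# before yielding it, which is why each test below resets _counter to 0.
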
class TestStreamingReasoningToContentTransition:
"""Tests for _process_simple_streaming_events reasoning-to-content
transition, specifically the fix for mixed deltas that carry both
reasoning and content simultaneously."""
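
# A minimal sketch of the corner case under test (assumption; the real
# handling lives in _process_simple_streaming_events). A mixed delta such
# as DeltaMessage(reasoning=" end", content="hello") must be split:
#
#     if delta.reasoning:
#         emit(ResponseReasoningTextDeltaEvent(delta=delta.reasoning))
#     if delta.content:
#         # close the reasoning item first, then open the text item
#         emit(ResponseReasoningTextDoneEvent(text=accumulated_reasoning))
#         emit(ResponseTextDeltaEvent(delta=delta.content))
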
@pytest.mark.asyncio
async def test_mixed_delta_reasoning_and_content_emits_reasoning_delta(
self, monkeypatch
):
"""When the reasoning parser produces a delta with both reasoning
and content set (e.g. reasoning end and content start in the same
chunk), the trailing reasoning text must be emitted as a
ResponseReasoningTextDeltaEvent and included in the
ResponseReasoningTextDoneEvent text."""
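# Expected event order, as asserted below:
#   reasoning delta "thinking..." -> reasoning delta " end"
#   -> reasoning done "thinking... end"
#   -> text delta "hello" -> text delta " world"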
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
serving = _make_serving_instance_with_reasoning()
# Sequence of DeltaMessages the mock reasoning parser will return
delta_sequence = [
DeltaMessage(reasoning="thinking..."),
DeltaMessage(reasoning=" end", content="hello"), # mixed delta
DeltaMessage(content=" world"),
]
call_count = 0
def mock_extract_reasoning_streaming(**kwargs):
nonlocal call_count
result = delta_sequence[call_count]
call_count += 1
return result
# Mock the reasoning parser on the serving instance
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
# Create contexts for each streaming chunk
contexts = [
_make_simple_context_with_output("chunk1", [10]),
_make_simple_context_with_output("chunk2", [20]),
_make_simple_context_with_output("chunk3", [30]),
]
async def result_generator():
for ctx in contexts:
yield ctx
request = ResponsesRequest(input="hi", tools=[], stream=True)
sampling_params = SamplingParams(max_tokens=64)
metadata = RequestResponseMetadata(request_id="req")
_identity_increment._counter = 0 # type: ignore
events = []
async for event in serving._process_simple_streaming_events(
request=request,
sampling_params=sampling_params,
result_generator=result_generator(),
context=SimpleContext(),
model_name="test-model",
tokenizer=MagicMock(),
request_metadata=metadata,
created_time=0,
_increment_sequence_number_and_return=_identity_increment,
):
events.append(event)
# Both reasoning deltas should be emitted; the pure one comes first
reasoning_deltas = [
e for e in events if isinstance(e, ResponseReasoningTextDeltaEvent)
]
assert len(reasoning_deltas) == 2
assert reasoning_deltas[0].delta == "thinking..."
# The trailing reasoning from the mixed delta must also be emitted
assert reasoning_deltas[1].delta == " end"
# The done event must include both reasoning parts
reasoning_done = [
e for e in events if isinstance(e, ResponseReasoningTextDoneEvent)
]
assert len(reasoning_done) == 1
assert reasoning_done[0].text == "thinking... end"
# Content deltas should be emitted for both the mixed delta's
# content and the pure content delta
text_deltas = [e for e in events if isinstance(e, ResponseTextDeltaEvent)]
assert len(text_deltas) == 2
assert text_deltas[0].delta == "hello"
assert text_deltas[1].delta == " world"

@pytest.mark.asyncio
async def test_transition_without_mixed_delta_no_extra_reasoning_event(
self, monkeypatch
):
"""When the transition from reasoning to content is clean (no mixed
delta), no extra reasoning delta event should be emitted."""
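# Expected, as asserted below: exactly one reasoning delta, one reasoning
# done event, and one text delta; no extra reasoning delta is synthesized
# at the transition.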
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
serving = _make_serving_instance_with_reasoning()
delta_sequence = [
DeltaMessage(reasoning="thinking"),
DeltaMessage(content="answer"),
]
call_count = 0
def mock_extract_reasoning_streaming(**kwargs):
nonlocal call_count
result = delta_sequence[call_count]
call_count += 1
return result
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
contexts = [
_make_simple_context_with_output("chunk1", [10]),
_make_simple_context_with_output("chunk2", [20]),
]
async def result_generator():
for ctx in contexts:
yield ctx
request = ResponsesRequest(input="hi", tools=[], stream=True)
sampling_params = SamplingParams(max_tokens=64)
metadata = RequestResponseMetadata(request_id="req")
_identity_increment._counter = 0 # type: ignore
events = []
async for event in serving._process_simple_streaming_events(
request=request,
sampling_params=sampling_params,
result_generator=result_generator(),
context=SimpleContext(),
model_name="test-model",
tokenizer=MagicMock(),
request_metadata=metadata,
created_time=0,
_increment_sequence_number_and_return=_identity_increment,
):
events.append(event)
# Exactly one reasoning delta
reasoning_deltas = [
e for e in events if isinstance(e, ResponseReasoningTextDeltaEvent)
]
assert len(reasoning_deltas) == 1
assert reasoning_deltas[0].delta == "thinking"
# Done event has just "thinking"
reasoning_done = [
e for e in events if isinstance(e, ResponseReasoningTextDoneEvent)
]
assert len(reasoning_done) == 1
assert reasoning_done[0].text == "thinking"
# One content delta
text_deltas = [e for e in events if isinstance(e, ResponseTextDeltaEvent)]
assert len(text_deltas) == 1
assert text_deltas[0].delta == "answer"

@pytest.mark.asyncio
async def test_reasoning_only_stream_no_content(self, monkeypatch):
"""When the stream has only reasoning deltas and no content, the
reasoning done event should be emitted at finalization with the
full accumulated text, and no text delta events should appear."""
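# Expected, as asserted below: both reasoning deltas stream through, the
# done event fires at finalization with "step 1 step 2", no text deltas
# appear, and the single finished item is a ResponseReasoningItem.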
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
serving = _make_serving_instance_with_reasoning()
delta_sequence = [
DeltaMessage(reasoning="step 1"),
DeltaMessage(reasoning=" step 2"),
]
call_count = 0
def mock_extract_reasoning_streaming(**kwargs):
nonlocal call_count
result = delta_sequence[call_count]
call_count += 1
return result
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
contexts = [
_make_simple_context_with_output("chunk1", [10]),
_make_simple_context_with_output("chunk2", [20]),
]
async def result_generator():
for ctx in contexts:
yield ctx
request = ResponsesRequest(input="hi", tools=[], stream=True)
sampling_params = SamplingParams(max_tokens=64)
metadata = RequestResponseMetadata(request_id="req")
_identity_increment._counter = 0 # type: ignore
events = []
async for event in serving._process_simple_streaming_events(
request=request,
sampling_params=sampling_params,
result_generator=result_generator(),
context=SimpleContext(),
model_name="test-model",
tokenizer=MagicMock(),
request_metadata=metadata,
created_time=0,
_increment_sequence_number_and_return=_identity_increment,
):
events.append(event)
# Two reasoning deltas
reasoning_deltas = [
e for e in events if isinstance(e, ResponseReasoningTextDeltaEvent)
]
assert len(reasoning_deltas) == 2
assert reasoning_deltas[0].delta == "step 1"
assert reasoning_deltas[1].delta == " step 2"
# Done event at finalization with accumulated text
reasoning_done = [
e for e in events if isinstance(e, ResponseReasoningTextDoneEvent)
]
assert len(reasoning_done) == 1
assert reasoning_done[0].text == "step 1 step 2"
# No content text deltas
text_deltas = [e for e in events if isinstance(e, ResponseTextDeltaEvent)]
assert len(text_deltas) == 0
# Final item should be a reasoning item
item_done_events = [
e for e in events if isinstance(e, ResponseOutputItemDoneEvent)
]
assert len(item_done_events) == 1
assert isinstance(item_done_events[0].item, ResponseReasoningItem)