[Bug] Fix a corner case in _process_simple_streaming_events (#34754)
Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,13 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from openai.types.responses import (
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseReasoningItem,
|
||||
ResponseReasoningTextDeltaEvent,
|
||||
ResponseReasoningTextDoneEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
)
|
||||
from openai.types.responses.tool import (
|
||||
CodeInterpreterContainerCodeInterpreterToolAuto,
|
||||
LocalShell,
|
||||
@@ -16,6 +23,7 @@ from openai.types.responses.tool import (
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.mcp.tool_server import ToolServer
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
)
|
||||
@@ -554,3 +562,309 @@ class TestHarmonyPreambleStreaming:
|
||||
|
||||
type_names = [e.type for e in events]
|
||||
assert "response.output_text.done" not in type_names
|
||||
|
||||
|
||||
def _make_simple_context_with_output(text, token_ids):
|
||||
"""Create a SimpleContext with a RequestOutput containing the given text."""
|
||||
ctx = SimpleContext()
|
||||
completion = CompletionOutput(
|
||||
index=0,
|
||||
text=text,
|
||||
token_ids=token_ids,
|
||||
cumulative_logprob=0.0,
|
||||
logprobs=None,
|
||||
finish_reason=None,
|
||||
stop_reason=None,
|
||||
)
|
||||
req_output = RequestOutput(
|
||||
request_id="req",
|
||||
prompt="hi",
|
||||
prompt_token_ids=[7, 8],
|
||||
prompt_logprobs=None,
|
||||
outputs=[completion],
|
||||
finished=False,
|
||||
num_cached_tokens=0,
|
||||
)
|
||||
ctx.append_output(req_output)
|
||||
return ctx
|
||||
|
||||
|
||||
def _make_serving_instance_with_reasoning():
|
||||
"""Create an OpenAIServingResponses with a mocked reasoning parser."""
|
||||
engine_client = MagicMock()
|
||||
model_config = MagicMock()
|
||||
model_config.max_model_len = 100
|
||||
model_config.hf_config.model_type = "test"
|
||||
model_config.hf_text_config = MagicMock()
|
||||
model_config.get_diff_sampling_param.return_value = {}
|
||||
engine_client.model_config = model_config
|
||||
engine_client.input_processor = MagicMock()
|
||||
engine_client.io_processor = MagicMock()
|
||||
engine_client.renderer = MagicMock()
|
||||
|
||||
models = MagicMock()
|
||||
|
||||
serving = OpenAIServingResponses(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=None,
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
reasoning_parser="qwen3",
|
||||
)
|
||||
return serving
|
||||
|
||||
|
||||
def _identity_increment(event):
|
||||
"""Simple identity callable for _increment_sequence_number_and_return."""
|
||||
seq = getattr(_identity_increment, "_counter", 0)
|
||||
if hasattr(event, "sequence_number"):
|
||||
event.sequence_number = seq
|
||||
_identity_increment._counter = seq + 1 # type: ignore
|
||||
return event
|
||||
|
||||
|
||||
class TestStreamingReasoningToContentTransition:
|
||||
"""Tests for _process_simple_streaming_events reasoning-to-content
|
||||
transition, specifically the fix for mixed deltas that carry both
|
||||
reasoning and content simultaneously."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_delta_reasoning_and_content_emits_reasoning_delta(
|
||||
self, monkeypatch
|
||||
):
|
||||
"""When the reasoning parser produces a delta with both reasoning
|
||||
and content set (e.g. reasoning end and content start in the same
|
||||
chunk), the trailing reasoning text must be emitted as a
|
||||
ResponseReasoningTextDeltaEvent and included in the
|
||||
ResponseReasoningTextDoneEvent text."""
|
||||
|
||||
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
|
||||
serving = _make_serving_instance_with_reasoning()
|
||||
|
||||
# Sequence of DeltaMessages the mock reasoning parser will return
|
||||
delta_sequence = [
|
||||
DeltaMessage(reasoning="thinking..."),
|
||||
DeltaMessage(reasoning=" end", content="hello"), # mixed delta
|
||||
DeltaMessage(content=" world"),
|
||||
]
|
||||
call_count = 0
|
||||
|
||||
def mock_extract_reasoning_streaming(**kwargs):
|
||||
nonlocal call_count
|
||||
result = delta_sequence[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
# Mock the reasoning parser on the serving instance
|
||||
mock_parser = MagicMock()
|
||||
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
|
||||
serving.parser = MagicMock()
|
||||
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
|
||||
|
||||
# Create contexts for each streaming chunk
|
||||
contexts = [
|
||||
_make_simple_context_with_output("chunk1", [10]),
|
||||
_make_simple_context_with_output("chunk2", [20]),
|
||||
_make_simple_context_with_output("chunk3", [30]),
|
||||
]
|
||||
|
||||
async def result_generator():
|
||||
for ctx in contexts:
|
||||
yield ctx
|
||||
|
||||
request = ResponsesRequest(input="hi", tools=[], stream=True)
|
||||
sampling_params = SamplingParams(max_tokens=64)
|
||||
metadata = RequestResponseMetadata(request_id="req")
|
||||
_identity_increment._counter = 0 # type: ignore
|
||||
|
||||
events = []
|
||||
async for event in serving._process_simple_streaming_events(
|
||||
request=request,
|
||||
sampling_params=sampling_params,
|
||||
result_generator=result_generator(),
|
||||
context=SimpleContext(),
|
||||
model_name="test-model",
|
||||
tokenizer=MagicMock(),
|
||||
request_metadata=metadata,
|
||||
created_time=0,
|
||||
_increment_sequence_number_and_return=_identity_increment,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# The first reasoning delta should be emitted
|
||||
reasoning_deltas = [
|
||||
e for e in events if isinstance(e, ResponseReasoningTextDeltaEvent)
|
||||
]
|
||||
assert len(reasoning_deltas) == 2
|
||||
assert reasoning_deltas[0].delta == "thinking..."
|
||||
# The trailing reasoning from the mixed delta must also be emitted
|
||||
assert reasoning_deltas[1].delta == " end"
|
||||
|
||||
# The done event must include both reasoning parts
|
||||
reasoning_done = [
|
||||
e for e in events if isinstance(e, ResponseReasoningTextDoneEvent)
|
||||
]
|
||||
assert len(reasoning_done) == 1
|
||||
assert reasoning_done[0].text == "thinking... end"
|
||||
|
||||
# Content deltas should be emitted for both the mixed delta's
|
||||
# content and the pure content delta
|
||||
text_deltas = [e for e in events if isinstance(e, ResponseTextDeltaEvent)]
|
||||
assert len(text_deltas) == 2
|
||||
assert text_deltas[0].delta == "hello"
|
||||
assert text_deltas[1].delta == " world"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transition_without_mixed_delta_no_extra_reasoning_event(
|
||||
self, monkeypatch
|
||||
):
|
||||
"""When the transition from reasoning to content is clean (no mixed
|
||||
delta), no extra reasoning delta event should be emitted."""
|
||||
|
||||
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
|
||||
serving = _make_serving_instance_with_reasoning()
|
||||
|
||||
delta_sequence = [
|
||||
DeltaMessage(reasoning="thinking"),
|
||||
DeltaMessage(content="answer"),
|
||||
]
|
||||
call_count = 0
|
||||
|
||||
def mock_extract_reasoning_streaming(**kwargs):
|
||||
nonlocal call_count
|
||||
result = delta_sequence[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
mock_parser = MagicMock()
|
||||
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
|
||||
serving.parser = MagicMock()
|
||||
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
|
||||
|
||||
contexts = [
|
||||
_make_simple_context_with_output("chunk1", [10]),
|
||||
_make_simple_context_with_output("chunk2", [20]),
|
||||
]
|
||||
|
||||
async def result_generator():
|
||||
for ctx in contexts:
|
||||
yield ctx
|
||||
|
||||
request = ResponsesRequest(input="hi", tools=[], stream=True)
|
||||
sampling_params = SamplingParams(max_tokens=64)
|
||||
metadata = RequestResponseMetadata(request_id="req")
|
||||
_identity_increment._counter = 0 # type: ignore
|
||||
|
||||
events = []
|
||||
async for event in serving._process_simple_streaming_events(
|
||||
request=request,
|
||||
sampling_params=sampling_params,
|
||||
result_generator=result_generator(),
|
||||
context=SimpleContext(),
|
||||
model_name="test-model",
|
||||
tokenizer=MagicMock(),
|
||||
request_metadata=metadata,
|
||||
created_time=0,
|
||||
_increment_sequence_number_and_return=_identity_increment,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Exactly one reasoning delta
|
||||
reasoning_deltas = [
|
||||
e for e in events if isinstance(e, ResponseReasoningTextDeltaEvent)
|
||||
]
|
||||
assert len(reasoning_deltas) == 1
|
||||
assert reasoning_deltas[0].delta == "thinking"
|
||||
|
||||
# Done event has just "thinking"
|
||||
reasoning_done = [
|
||||
e for e in events if isinstance(e, ResponseReasoningTextDoneEvent)
|
||||
]
|
||||
assert len(reasoning_done) == 1
|
||||
assert reasoning_done[0].text == "thinking"
|
||||
|
||||
# One content delta
|
||||
text_deltas = [e for e in events if isinstance(e, ResponseTextDeltaEvent)]
|
||||
assert len(text_deltas) == 1
|
||||
assert text_deltas[0].delta == "answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_only_stream_no_content(self, monkeypatch):
|
||||
"""When the stream has only reasoning deltas and no content, the
|
||||
reasoning done event should be emitted at finalization with the
|
||||
full accumulated text, and no text delta events should appear."""
|
||||
|
||||
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
|
||||
serving = _make_serving_instance_with_reasoning()
|
||||
|
||||
delta_sequence = [
|
||||
DeltaMessage(reasoning="step 1"),
|
||||
DeltaMessage(reasoning=" step 2"),
|
||||
]
|
||||
call_count = 0
|
||||
|
||||
def mock_extract_reasoning_streaming(**kwargs):
|
||||
nonlocal call_count
|
||||
result = delta_sequence[call_count]
|
||||
call_count += 1
|
||||
return result
|
||||
|
||||
mock_parser = MagicMock()
|
||||
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
|
||||
serving.parser = MagicMock()
|
||||
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
|
||||
|
||||
contexts = [
|
||||
_make_simple_context_with_output("chunk1", [10]),
|
||||
_make_simple_context_with_output("chunk2", [20]),
|
||||
]
|
||||
|
||||
async def result_generator():
|
||||
for ctx in contexts:
|
||||
yield ctx
|
||||
|
||||
request = ResponsesRequest(input="hi", tools=[], stream=True)
|
||||
sampling_params = SamplingParams(max_tokens=64)
|
||||
metadata = RequestResponseMetadata(request_id="req")
|
||||
_identity_increment._counter = 0 # type: ignore
|
||||
|
||||
events = []
|
||||
async for event in serving._process_simple_streaming_events(
|
||||
request=request,
|
||||
sampling_params=sampling_params,
|
||||
result_generator=result_generator(),
|
||||
context=SimpleContext(),
|
||||
model_name="test-model",
|
||||
tokenizer=MagicMock(),
|
||||
request_metadata=metadata,
|
||||
created_time=0,
|
||||
_increment_sequence_number_and_return=_identity_increment,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Two reasoning deltas
|
||||
reasoning_deltas = [
|
||||
e for e in events if isinstance(e, ResponseReasoningTextDeltaEvent)
|
||||
]
|
||||
assert len(reasoning_deltas) == 2
|
||||
assert reasoning_deltas[0].delta == "step 1"
|
||||
assert reasoning_deltas[1].delta == " step 2"
|
||||
|
||||
# Done event at finalization with accumulated text
|
||||
reasoning_done = [
|
||||
e for e in events if isinstance(e, ResponseReasoningTextDoneEvent)
|
||||
]
|
||||
assert len(reasoning_done) == 1
|
||||
assert reasoning_done[0].text == "step 1 step 2"
|
||||
|
||||
# No content text deltas
|
||||
text_deltas = [e for e in events if isinstance(e, ResponseTextDeltaEvent)]
|
||||
assert len(text_deltas) == 0
|
||||
|
||||
# Final item should be a reasoning item
|
||||
item_done_events = [
|
||||
e for e in events if isinstance(e, ResponseOutputItemDoneEvent)
|
||||
]
|
||||
assert len(item_done_events) == 1
|
||||
assert isinstance(item_done_events[0].item, ResponseReasoningItem)
|
||||
|
||||
Reference in New Issue
Block a user