diff --git a/tests/entrypoints/test_context.py b/tests/entrypoints/test_context.py
index f87683fc2..1ab2b5edb 100644
--- a/tests/entrypoints/test_context.py
+++ b/tests/entrypoints/test_context.py
@@ -8,6 +8,7 @@ from openai_harmony import Author, Message, Role, StreamState, TextContent
 
 from vllm.entrypoints.openai.responses.context import (
     HarmonyContext,
+    SimpleContext,
     StreamingHarmonyContext,
     TurnMetrics,
 )
@@ -597,3 +598,248 @@ def test_turn_metrics_copy_and_reset():
     assert copied_metrics.output_tokens == 20
     assert copied_metrics.cached_input_tokens == 5
     assert copied_metrics.tool_output_tokens == 3
+
+
+# ==================== SimpleContext Tests ====================
+
+
+def create_simple_context_output(
+    text="",
+    token_ids=None,
+    prompt="Test prompt",
+    prompt_token_ids=None,
+    num_cached_tokens=0,
+    logprobs=None,
+    finished=True,
+):
+    """Helper to create a RequestOutput with customizable text for
+    SimpleContext tests."""
+    if token_ids is None:
+        token_ids = []
+    return RequestOutput(
+        request_id="test-id",
+        prompt=prompt,
+        prompt_token_ids=prompt_token_ids,
+        prompt_logprobs=None,
+        outputs=[
+            CompletionOutput(
+                index=0,
+                text=text,
+                token_ids=token_ids,
+                cumulative_logprob=0.0,
+                logprobs=logprobs,
+                finish_reason=None,
+                stop_reason=None,
+            )
+        ],
+        finished=finished,
+        num_cached_tokens=num_cached_tokens,
+    )
+
+
+def test_simple_context_output_messages_empty():
+    """output_messages should be empty before any output is appended."""
+    context = SimpleContext()
+    assert context.output_messages == []
+
+
+def test_simple_context_output_messages_single_call():
+    """Non-streaming: single append_output produces a single output message."""
+    context = SimpleContext()
+    output = create_simple_context_output(
+        text="Hello world",
+        token_ids=[10, 20, 30],
+        prompt_token_ids=[1, 2, 3],
+    )
+    context.append_output(output)
+
+    messages = context.output_messages
+    assert len(messages) == 1
+    assert messages[0].message == "Hello world"
+    assert messages[0].tokens == [10, 20, 30]
+    assert messages[0].type == "raw_message_tokens"
+
+
+def test_simple_context_output_messages_streaming_consolidation():
+    """Streaming: multiple append_output calls consolidate into one message."""
+    context = SimpleContext()
+
+    # Simulate 3 streaming deltas
+    context.append_output(
+        create_simple_context_output(
+            text="Hello",
+            token_ids=[10],
+            prompt_token_ids=[1, 2, 3],
+        )
+    )
+    context.append_output(
+        create_simple_context_output(
+            text=" world",
+            token_ids=[20],
+            prompt_token_ids=[1, 2, 3],
+        )
+    )
+    context.append_output(
+        create_simple_context_output(
+            text="!",
+            token_ids=[30],
+            prompt_token_ids=[1, 2, 3],
+        )
+    )
+
+    messages = context.output_messages
+    assert len(messages) == 1
+    assert messages[0].message == "Hello world!"
+    assert messages[0].tokens == [10, 20, 30]
+
+
+def test_simple_context_output_messages_many_deltas():
+    """Streaming with many small deltas still produces a single message."""
+    context = SimpleContext()
+
+    words = ["The", " quick", " brown", " fox", " jumps"]
+    for i, word in enumerate(words):
+        context.append_output(
+            create_simple_context_output(
+                text=word,
+                token_ids=[100 + i],
+                prompt_token_ids=[1, 2],
+            )
+        )
+
+    messages = context.output_messages
+    assert len(messages) == 1
+    assert messages[0].message == "The quick brown fox jumps"
+    assert messages[0].tokens == [100, 101, 102, 103, 104]
+
+
+def test_simple_context_input_messages():
+    """input_messages is populated on the first append_output call."""
+    context = SimpleContext()
+    assert context.input_messages == []
+
+    context.append_output(
+        create_simple_context_output(
+            text="Hi",
+            token_ids=[10],
+            prompt="My prompt text",
+            prompt_token_ids=[1, 2, 3],
+        )
+    )
+
+    assert len(context.input_messages) == 1
+    assert context.input_messages[0].message == "My prompt text"
+    assert context.input_messages[0].tokens == [1, 2, 3]
+
+    # Second call should not add another input message
+    context.append_output(
+        create_simple_context_output(
+            text=" there",
+            token_ids=[20],
+            prompt="My prompt text",
+            prompt_token_ids=[1, 2, 3],
+        )
+    )
+
+    assert len(context.input_messages) == 1
+
+
+def test_simple_context_token_counting():
+    """Token counting accumulates across streaming deltas."""
+    context = SimpleContext()
+
+    context.append_output(
+        create_simple_context_output(
+            text="a",
+            token_ids=[10, 11],
+            prompt_token_ids=[1, 2, 3, 4, 5],
+            num_cached_tokens=2,
+        )
+    )
+    context.append_output(
+        create_simple_context_output(
+            text="b",
+            token_ids=[12],
+            prompt_token_ids=[1, 2, 3, 4, 5],
+            num_cached_tokens=2,
+        )
+    )
+
+    assert context.num_prompt_tokens == 5
+    assert context.num_output_tokens == 3  # 2 + 1
+    assert context.num_cached_tokens == 2
+
+
+def test_simple_context_final_output():
+    """final_output reconstructs accumulated text and token_ids."""
+    context = SimpleContext()
+
+    context.append_output(
+        create_simple_context_output(
+            text="foo",
+            token_ids=[1, 2],
+            prompt_token_ids=[10],
+        )
+    )
+    context.append_output(
+        create_simple_context_output(
+            text="bar",
+            token_ids=[3],
+            prompt_token_ids=[10],
+        )
+    )
+
+    final = context.final_output
+    assert final is not None
+    assert final.outputs[0].text == "foobar"
+    assert final.outputs[0].token_ids == (1, 2, 3)
+
+
+def test_simple_context_output_messages_empty_text_with_tokens():
+    """output_messages should be returned when tokens exist even if text is
+    empty (e.g. special tokens)."""
+    context = SimpleContext()
+    context.append_output(
+        create_simple_context_output(
+            text="",
+            token_ids=[99],
+            prompt_token_ids=[1],
+        )
+    )
+
+    messages = context.output_messages
+    assert len(messages) == 1
+    assert messages[0].message == ""
+    assert messages[0].tokens == [99]
+
+
+def test_simple_context_output_messages_no_mutation():
+    """Each call to output_messages returns a fresh list; callers can't
+    corrupt internal state."""
+    context = SimpleContext()
+    context.append_output(
+        create_simple_context_output(
+            text="hello",
+            token_ids=[1],
+            prompt_token_ids=[10],
+        )
+    )
+
+    msgs1 = context.output_messages
+    msgs2 = context.output_messages
+    assert msgs1 is not msgs2
+    assert msgs1[0].message == msgs2[0].message
+
+    # Appending more output updates the property
+    context.append_output(
+        create_simple_context_output(
+            text=" world",
+            token_ids=[2],
+            prompt_token_ids=[10],
+        )
+    )
+
+    msgs3 = context.output_messages
+    assert len(msgs3) == 1
+    assert msgs3[0].message == "hello world"
+    assert msgs3[0].tokens == [1, 2]
diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index 1fbf19add..a91bc694b 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -1379,6 +1379,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         action="store_true",
         help="Disable shuffling of dataset samples for deterministic ordering.",
     )
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Trust remote code from HuggingFace.",
+    )
 
     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
diff --git a/vllm/entrypoints/openai/responses/context.py b/vllm/entrypoints/openai/responses/context.py
index a10567e40..b327c1e1b 100644
--- a/vllm/entrypoints/openai/responses/context.py
+++ b/vllm/entrypoints/openai/responses/context.py
@@ -182,7 +182,6 @@ class SimpleContext(ConversationContext):
         self.all_turn_metrics = []
 
         self.input_messages: list[ResponseRawMessageAndToken] = []
-        self.output_messages: list[ResponseRawMessageAndToken] = []
 
     def append_output(self, output) -> None:
         self.last_output = output
@@ -208,12 +207,22 @@ class SimpleContext(ConversationContext):
                     tokens=output_prompt_token_ids,
                 )
             )
-        self.output_messages.append(
+
+    @property
+    def output_messages(self) -> list[ResponseRawMessageAndToken]:
+        """Return consolidated output as a single message.
+
+        In streaming mode, text and tokens are accumulated across many deltas.
+        This property returns them as a single entry rather than one per delta.
+        """
+        if not self._accumulated_text and not self._accumulated_token_ids:
+            return []
+        return [
             ResponseRawMessageAndToken(
-                message=delta_output.text,
-                tokens=delta_output.token_ids,
+                message=self._accumulated_text,
+                tokens=list(self._accumulated_token_ids),
             )
-        )
+        ]
 
     @property
     def final_output(self) -> RequestOutput | None: