diff --git a/tests/entrypoints/openai/chat_completion/test_enable_force_include_usage.py b/tests/entrypoints/openai/chat_completion/test_enable_force_include_usage.py
index 0d53b545d..1bc545e86 100644
--- a/tests/entrypoints/openai/chat_completion/test_enable_force_include_usage.py
+++ b/tests/entrypoints/openai/chat_completion/test_enable_force_include_usage.py
@@ -54,21 +54,19 @@ async def test_chat_with_enable_force_include_usage(
     )
     last_completion_tokens = 0
     async for chunk in stream:
-        if not len(chunk.choices):
-            assert chunk.usage.prompt_tokens >= 0
-            assert (
-                last_completion_tokens == 0
-                or chunk.usage.completion_tokens > last_completion_tokens
-                or (
-                    not chunk.choices
-                    and chunk.usage.completion_tokens == last_completion_tokens
-                )
+        assert chunk.usage.prompt_tokens >= 0
+        assert (
+            last_completion_tokens == 0
+            or chunk.usage.completion_tokens > last_completion_tokens
+            or (
+                not chunk.choices
+                and chunk.usage.completion_tokens == last_completion_tokens
             )
-            assert chunk.usage.total_tokens == (
-                chunk.usage.prompt_tokens + chunk.usage.completion_tokens
-            )
-        else:
-            assert chunk.usage is None
+        )
+        assert chunk.usage.total_tokens == (
+            chunk.usage.prompt_tokens + chunk.usage.completion_tokens
+        )
+        last_completion_tokens = chunk.usage.completion_tokens
 
 
 @pytest.fixture(scope="module")
diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py
index 725938339..ff65066ff 100644
--- a/tests/entrypoints/test_utils.py
+++ b/tests/entrypoints/test_utils.py
@@ -3,7 +3,12 @@
 
 import pytest
 
-from vllm.entrypoints.utils import get_max_tokens, sanitize_message
+from vllm.entrypoints.openai.engine.protocol import StreamOptions
+from vllm.entrypoints.utils import (
+    get_max_tokens,
+    sanitize_message,
+    should_include_usage,
+)
 
 
 def test_sanitize_message():
@@ -13,6 +18,25 @@ def test_sanitize_message():
     )
 
 
+@pytest.mark.parametrize(
+    ("stream_options", "expected"),
+    [
+        (None, (True, True)),
+        (StreamOptions(include_usage=False), (True, True)),
+        (
+            StreamOptions(include_usage=False, continuous_usage_stats=False),
+            (True, True),
+        ),
+        (
+            StreamOptions(include_usage=True, continuous_usage_stats=False),
+            (True, True),
+        ),
+    ],
+)
+def test_should_include_usage_force_enables_continuous_usage(stream_options, expected):
+    assert should_include_usage(stream_options, True) == expected
+
+
 class TestGetMaxTokens:
     """Tests for get_max_tokens() to ensure generation_config's max_tokens
     acts as a default when from model author, and as a ceiling when
diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py
index d5ecb7599..1c5abecda 100644
--- a/vllm/entrypoints/utils.py
+++ b/vllm/entrypoints/utils.py
@@ -236,13 +236,15 @@ def log_non_default_args(args: Namespace | EngineArgs):
 def should_include_usage(
     stream_options: "StreamOptions | None", enable_force_include_usage: bool
 ) -> tuple[bool, bool]:
+    if enable_force_include_usage:
+        return True, True
     if stream_options:
-        include_usage = stream_options.include_usage or enable_force_include_usage
+        include_usage = bool(stream_options.include_usage)
         include_continuous_usage = include_usage and bool(
             stream_options.continuous_usage_stats
         )
     else:
-        include_usage, include_continuous_usage = enable_force_include_usage, False
+        include_usage, include_continuous_usage = False, False
     return include_usage, include_continuous_usage
 
 