feat: Add --enable-log-outputs flag for logging model generations (#20707)
Signed-off-by: Adrian Garcia <adrian.garcia@inceptionai.ai>
This commit is contained in:
committed by
GitHub
parent
82216dc21f
commit
8e8e0b6af1
@@ -10,11 +10,12 @@ from dataclasses import dataclass
|
|||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger,
|
from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger,
|
||||||
enable_trace_function_call, init_logger)
|
enable_trace_function_call, init_logger)
|
||||||
from vllm.logging_utils import NewLineFormatter
|
from vllm.logging_utils import NewLineFormatter
|
||||||
@@ -228,9 +229,10 @@ def test_prepare_object_to_dump():
|
|||||||
list_obj = [1, 2, 3]
|
list_obj = [1, 2, 3]
|
||||||
assert prepare_object_to_dump(list_obj) == '[1, 2, 3]'
|
assert prepare_object_to_dump(list_obj) == '[1, 2, 3]'
|
||||||
|
|
||||||
dict_obj = {'a': 1, 'b': 'b'}
|
dict_obj = {"a": 1, "b": "b"}
|
||||||
assert prepare_object_to_dump(dict_obj) in [
|
assert prepare_object_to_dump(dict_obj) in [
|
||||||
"{a: 1, b: 'b'}", "{b: 'b', a: 1}"
|
"{a: 1, b: 'b'}",
|
||||||
|
"{b: 'b', a: 1}",
|
||||||
]
|
]
|
||||||
|
|
||||||
set_obj = {1, 2, 3}
|
set_obj = {1, 2, 3}
|
||||||
@@ -252,4 +254,246 @@ def test_prepare_object_to_dump():
|
|||||||
b: str
|
b: str
|
||||||
|
|
||||||
assert (prepare_object_to_dump(CustomClass(
|
assert (prepare_object_to_dump(CustomClass(
|
||||||
1, 'b')) == "CustomClass(a=1, b='b')")
|
1, "b")) == "CustomClass(a=1, b='b')")
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs():
    """A plain (non-streaming) response is passed through to ``logger.info``.

    The module-level logger is swapped for a mock so the positional
    arguments handed to ``info`` can be inspected directly.
    """
    fake_logger = MagicMock()
    with patch("vllm.entrypoints.logger.logger", fake_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-123",
            outputs="Hello, world!",
            output_token_ids=[1, 2, 3, 4],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    # Exactly one log record, carrying the values in their expected slots.
    fake_logger.info.assert_called_once()
    logged = fake_logger.info.call_args.args
    assert "Generated response %s%s" in logged[0]
    assert logged[1] == "test-123"
    assert logged[3] == "Hello, world!"
    assert logged[4] == [1, 2, 3, 4]
    assert logged[5] == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_streaming_delta():
    """A streaming delta is logged with the " (streaming delta)" marker."""
    fake_logger = MagicMock()
    delta_call = dict(
        request_id="test-456",
        outputs="Hello",
        output_token_ids=[1],
        finish_reason=None,
        is_streaming=True,
        delta=True,
    )

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        RequestLogger(max_log_len=None).log_outputs(**delta_call)

    fake_logger.info.assert_called_once()
    logged = fake_logger.info.call_args.args
    assert "Generated response %s%s" in logged[0]
    assert logged[1] == "test-456"
    assert logged[2] == " (streaming delta)"
    assert logged[3] == "Hello"
    assert logged[4] == [1]
    # No finish reason yet: the stream is still in progress.
    assert logged[5] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_streaming_complete():
    """A finished stream is logged with the " (streaming complete)" marker."""
    fake_logger = MagicMock()
    complete_call = dict(
        request_id="test-789",
        outputs="Complete response",
        output_token_ids=[1, 2, 3],
        finish_reason="length",
        is_streaming=True,
        delta=False,
    )

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        RequestLogger(max_log_len=None).log_outputs(**complete_call)

    fake_logger.info.assert_called_once()
    logged = fake_logger.info.call_args.args
    assert "Generated response %s%s" in logged[0]
    assert logged[1] == "test-789"
    assert logged[2] == " (streaming complete)"
    assert logged[3] == "Complete response"
    assert logged[4] == [1, 2, 3]
    assert logged[5] == "length"
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_with_truncation():
    """Both the text and the token IDs are cut down to ``max_log_len``."""
    fake_logger = MagicMock()
    with patch("vllm.entrypoints.logger.logger", fake_logger):
        # A limit of 10 is shorter than both inputs below.
        RequestLogger(max_log_len=10).log_outputs(
            request_id="test-truncate",
            outputs="This is a very long output that should be truncated",
            output_token_ids=list(range(20)),  # 20 tokens
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    fake_logger.info.assert_called_once()
    positional = fake_logger.info.call_args.args

    # Only the first 10 characters of the text survive.
    text = positional[3]
    assert text == "This is a "
    assert len(text) == 10

    # Only the first 10 token IDs survive.
    token_ids = positional[4]
    assert token_ids == list(range(10))
    assert len(token_ids) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_none_values():
    """``output_token_ids=None`` is accepted and logged as ``None``."""
    fake_logger = MagicMock()
    with patch("vllm.entrypoints.logger.logger", fake_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-none",
            outputs="Test output",
            output_token_ids=None,
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    fake_logger.info.assert_called_once()
    logged = fake_logger.info.call_args.args
    assert "Generated response %s%s" in logged[0]
    assert logged[1] == "test-none"
    assert logged[3] == "Test output"
    # The missing token IDs must come through untouched.
    assert logged[4] is None
    assert logged[5] == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_empty_output():
    """Empty text and an empty token list are logged as-is.

    A ``max_log_len`` is set so the truncation path runs on the empty
    inputs as well.
    """
    fake_logger = MagicMock()
    with patch("vllm.entrypoints.logger.logger", fake_logger):
        RequestLogger(max_log_len=5).log_outputs(
            request_id="test-empty",
            outputs="",
            output_token_ids=[],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    fake_logger.info.assert_called_once()
    logged = fake_logger.info.call_args.args
    assert "Generated response %s%s" in logged[0]
    assert logged[1] == "test-empty"
    assert logged[3] == ""
    assert logged[4] == []
    assert logged[5] == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_integration():
    """``log_inputs`` and ``log_outputs`` each emit their own log record."""
    shared_id = "test-integration"
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        request_logger = RequestLogger(max_log_len=None)

        # Log the request first, then its response, on the same instance.
        request_logger.log_inputs(
            request_id=shared_id,
            prompt="Test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_embeds=None,
            params=None,
            lora_request=None,
        )
        request_logger.log_outputs(
            request_id=shared_id,
            outputs="Test output",
            output_token_ids=[4, 5, 6],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    # One record for the inputs, one for the outputs.
    assert fake_logger.info.call_count == 2
    recorded = fake_logger.info.call_args_list
    input_args = recorded[0].args
    output_args = recorded[1].args

    assert "Received request %s" in input_args[0]
    assert input_args[1] == "test-integration"

    assert "Generated response %s%s" in output_args[0]
    assert output_args[1] == "test-integration"
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_complete_logs_full_text_content():
    """The streaming-complete record carries the accumulated text itself.

    The logged output must be the full response string rather than a
    token-count placeholder.
    """
    expected_text = "This is a complete response from streaming"
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-streaming-full-text",
            outputs=expected_text,
            output_token_ids=None,
            finish_reason="streaming_complete",
            is_streaming=True,
            delta=False,
        )

    fake_logger.info.assert_called_once()
    logged = fake_logger.info.call_args.args

    # The text slot holds the real content, not a placeholder summary.
    text = logged[3]
    assert text == expected_text
    assert "tokens>" not in text
    assert "streaming_complete" not in text

    # The remaining slots identify the request and its streaming state.
    assert logged[1] == "test-streaming-full-text"
    assert logged[2] == " (streaming complete)"
    assert logged[5] == "streaming_complete"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -16,8 +17,6 @@ logger = init_logger(__name__)
|
|||||||
class RequestLogger:
|
class RequestLogger:
|
||||||
|
|
||||||
def __init__(self, *, max_log_len: Optional[int]) -> None:
|
def __init__(self, *, max_log_len: Optional[int]) -> None:
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.max_log_len = max_log_len
|
self.max_log_len = max_log_len
|
||||||
|
|
||||||
def log_inputs(
|
def log_inputs(
|
||||||
@@ -45,3 +44,36 @@ class RequestLogger:
|
|||||||
"lora_request: %s.", request_id, prompt, params, prompt_token_ids,
|
"lora_request: %s.", request_id, prompt, params, prompt_token_ids,
|
||||||
prompt_embeds.shape if prompt_embeds is not None else None,
|
prompt_embeds.shape if prompt_embeds is not None else None,
|
||||||
lora_request)
|
lora_request)
|
||||||
|
|
||||||
|
def log_outputs(
    self,
    request_id: str,
    outputs: str,
    output_token_ids: Optional[Sequence[int]],
    finish_reason: Optional[str] = None,
    is_streaming: bool = False,
    delta: bool = False,
) -> None:
    """Emit one INFO record describing a generated response.

    Args:
        request_id: Identifier of the request the output belongs to.
        outputs: Generated text to log.
        output_token_ids: Token IDs of the generated text, or ``None``.
        finish_reason: Why generation stopped, if known.
        is_streaming: Whether the output comes from a streaming response.
        delta: For streaming output, whether this is an incremental
            delta rather than the completed response.

    When ``self.max_log_len`` is set, the text and the token IDs are
    each sliced to at most that many items before being logged.
    """
    limit = self.max_log_len
    if limit is not None:
        if outputs is not None:
            outputs = outputs[:limit]
        if output_token_ids is not None:
            # Materialise as a list first so any sequence can be sliced.
            output_token_ids = list(output_token_ids)[:limit]

    # Suffix appended right after the request id in the message.
    if not is_streaming:
        stream_info = ""
    elif delta:
        stream_info = " (streaming delta)"
    else:
        stream_info = " (streaming complete)"

    logger.info(
        "Generated response %s%s: output: %r, "
        "output_token_ids: %s, finish_reason: %s",
        request_id,
        stream_info,
        outputs,
        output_token_ids,
        finish_reason,
    )
|
||||||
|
|||||||
@@ -44,10 +44,10 @@ class LoRAParserAction(argparse.Action):
|
|||||||
|
|
||||||
lora_list: list[LoRAModulePath] = []
|
lora_list: list[LoRAModulePath] = []
|
||||||
for item in values:
|
for item in values:
|
||||||
if item in [None, '']: # Skip if item is None or empty string
|
if item in [None, ""]: # Skip if item is None or empty string
|
||||||
continue
|
continue
|
||||||
if '=' in item and ',' not in item: # Old format: name=path
|
if "=" in item and "," not in item: # Old format: name=path
|
||||||
name, path = item.split('=')
|
name, path = item.split("=")
|
||||||
lora_list.append(LoRAModulePath(name, path))
|
lora_list.append(LoRAModulePath(name, path))
|
||||||
else: # Assume JSON format
|
else: # Assume JSON format
|
||||||
try:
|
try:
|
||||||
@@ -167,6 +167,9 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
|||||||
enable_tokenizer_info_endpoint: bool = False
|
enable_tokenizer_info_endpoint: bool = False
|
||||||
"""Enable the /get_tokenizer_info endpoint. May expose chat
|
"""Enable the /get_tokenizer_info endpoint. May expose chat
|
||||||
templates and other tokenizer configuration."""
|
templates and other tokenizer configuration."""
|
||||||
|
enable_log_outputs: bool = False
|
||||||
|
"""If set to True, enable logging of model outputs (generations)
|
||||||
|
in addition to the input logging that is enabled by default."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
tool_parser: Optional[str] = None,
|
tool_parser: Optional[str] = None,
|
||||||
enable_prompt_tokens_details: bool = False,
|
enable_prompt_tokens_details: bool = False,
|
||||||
enable_force_include_usage: bool = False,
|
enable_force_include_usage: bool = False,
|
||||||
|
enable_log_outputs: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(engine_client=engine_client,
|
super().__init__(engine_client=engine_client,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@@ -84,6 +85,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
self.response_role = response_role
|
self.response_role = response_role
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.chat_template_content_format: Final = chat_template_content_format
|
self.chat_template_content_format: Final = chat_template_content_format
|
||||||
|
self.enable_log_outputs = enable_log_outputs
|
||||||
|
|
||||||
# set up tool use
|
# set up tool use
|
||||||
self.enable_auto_tools: bool = enable_auto_tools
|
self.enable_auto_tools: bool = enable_auto_tools
|
||||||
@@ -489,20 +491,21 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
all_previous_token_ids: Optional[list[list[int]]]
|
all_previous_token_ids: Optional[list[list[int]]]
|
||||||
function_name_returned = [False] * num_choices
|
function_name_returned = [False] * num_choices
|
||||||
|
|
||||||
|
# Always track previous_texts for comprehensive output logging
|
||||||
|
previous_texts = [""] * num_choices
|
||||||
|
|
||||||
# Only one of these will be used, thus previous_texts and
|
# Only one of these will be used, thus previous_texts and
|
||||||
# all_previous_token_ids will not be used twice in the same iteration.
|
# all_previous_token_ids will not be used twice in the same iteration.
|
||||||
if tool_choice_auto or self.reasoning_parser:
|
if tool_choice_auto or self.reasoning_parser:
|
||||||
# These are only required in "auto" tool choice case
|
# These are only required in "auto" tool choice case
|
||||||
previous_texts = [""] * num_choices
|
|
||||||
all_previous_token_ids = [[]] * num_choices
|
all_previous_token_ids = [[]] * num_choices
|
||||||
# For reasoning parser and tool call all enabled
|
# For reasoning parser and tool call all enabled
|
||||||
added_content_delta_arr = [False] * num_choices
|
added_content_delta_arr = [False] * num_choices
|
||||||
reasoning_end_arr = [False] * num_choices
|
reasoning_end_arr = [False] * num_choices
|
||||||
elif request.tool_choice == "required":
|
elif request.tool_choice == "required":
|
||||||
previous_texts = [""] * num_choices
|
|
||||||
all_previous_token_ids = None
|
all_previous_token_ids = None
|
||||||
else:
|
else:
|
||||||
previous_texts, all_previous_token_ids = None, None
|
all_previous_token_ids = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.reasoning_parser:
|
if self.reasoning_parser:
|
||||||
@@ -844,6 +847,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
current_token_ids=current_token_ids,
|
current_token_ids=current_token_ids,
|
||||||
delta_token_ids=output.token_ids,
|
delta_token_ids=output.token_ids,
|
||||||
request=request))
|
request=request))
|
||||||
|
|
||||||
# when only reasoning
|
# when only reasoning
|
||||||
elif self.reasoning_parser:
|
elif self.reasoning_parser:
|
||||||
delta_message = (reasoning_parser.
|
delta_message = (reasoning_parser.
|
||||||
@@ -865,6 +869,10 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
assert all_previous_token_ids is not None
|
assert all_previous_token_ids is not None
|
||||||
previous_texts[i] = current_text
|
previous_texts[i] = current_text
|
||||||
all_previous_token_ids[i] = current_token_ids
|
all_previous_token_ids[i] = current_token_ids
|
||||||
|
else:
|
||||||
|
# Update for comprehensive logging even in simple case
|
||||||
|
assert previous_texts is not None
|
||||||
|
previous_texts[i] += delta_text
|
||||||
|
|
||||||
# set the previous values for the next iteration
|
# set the previous values for the next iteration
|
||||||
previous_num_tokens[i] += len(output.token_ids)
|
previous_num_tokens[i] += len(output.token_ids)
|
||||||
@@ -876,6 +884,27 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
if delta_message is None:
|
if delta_message is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Log streaming delta if output logging is enabled
|
||||||
|
if self.enable_log_outputs and self.request_logger:
|
||||||
|
delta_content = ""
|
||||||
|
if delta_message.content:
|
||||||
|
delta_content = delta_message.content
|
||||||
|
elif delta_message.tool_calls:
|
||||||
|
delta_content = "".join(
|
||||||
|
tc.function.arguments
|
||||||
|
for tc in delta_message.tool_calls
|
||||||
|
if tc.function and tc.function.arguments)
|
||||||
|
|
||||||
|
if delta_content:
|
||||||
|
self.request_logger.log_outputs(
|
||||||
|
request_id=request_id,
|
||||||
|
outputs=delta_content,
|
||||||
|
output_token_ids=list(output.token_ids),
|
||||||
|
finish_reason=output.finish_reason,
|
||||||
|
is_streaming=True,
|
||||||
|
delta=True,
|
||||||
|
)
|
||||||
|
|
||||||
if output.finish_reason is None:
|
if output.finish_reason is None:
|
||||||
# Send token-by-token response for each request.n
|
# Send token-by-token response for each request.n
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
@@ -994,7 +1023,27 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
request_metadata.final_usage_info = UsageInfo(
|
request_metadata.final_usage_info = UsageInfo(
|
||||||
prompt_tokens=num_prompt_tokens,
|
prompt_tokens=num_prompt_tokens,
|
||||||
completion_tokens=num_completion_tokens,
|
completion_tokens=num_completion_tokens,
|
||||||
total_tokens=num_prompt_tokens + num_completion_tokens)
|
total_tokens=num_prompt_tokens + num_completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log complete streaming response if output logging is enabled
|
||||||
|
if self.enable_log_outputs and self.request_logger:
|
||||||
|
# Log the complete response for each choice
|
||||||
|
for i in range(num_choices):
|
||||||
|
full_text = (
|
||||||
|
previous_texts[i]
|
||||||
|
if previous_texts and i < len(previous_texts) else
|
||||||
|
f"<streaming_complete: {previous_num_tokens[i]} tokens>"
|
||||||
|
)
|
||||||
|
self.request_logger.log_outputs(
|
||||||
|
request_id=request_id,
|
||||||
|
outputs=full_text,
|
||||||
|
output_token_ids=
|
||||||
|
None, # Consider also logging all token IDs
|
||||||
|
finish_reason="streaming_complete",
|
||||||
|
is_streaming=True,
|
||||||
|
delta=False,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO: Use a vllm-specific Validation Error
|
# TODO: Use a vllm-specific Validation Error
|
||||||
@@ -1121,8 +1170,10 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
tool_calls=[
|
tool_calls=[
|
||||||
tool_call_class(function=FunctionCall(
|
tool_call_class(function=FunctionCall(
|
||||||
name=request.tool_choice.function.name,
|
name=request.tool_choice.function.name,
|
||||||
arguments=content))
|
arguments=content,
|
||||||
])
|
))
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
elif request.tool_choice and request.tool_choice == "required":
|
elif request.tool_choice and request.tool_choice == "required":
|
||||||
tool_call_class = MistralToolCall if isinstance(
|
tool_call_class = MistralToolCall if isinstance(
|
||||||
@@ -1209,12 +1260,13 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
finish_reason="tool_calls" if auto_tools_called else
|
finish_reason="tool_calls" if auto_tools_called else
|
||||||
output.finish_reason if output.finish_reason else "stop",
|
output.finish_reason if output.finish_reason else "stop",
|
||||||
stop_reason=output.stop_reason)
|
stop_reason=output.stop_reason)
|
||||||
|
|
||||||
choices.append(choice_data)
|
choices.append(choice_data)
|
||||||
|
|
||||||
if request.echo:
|
if request.echo:
|
||||||
last_msg_content: Union[str, list[dict[str, str]]] = ""
|
last_msg_content: Union[str, list[dict[str, str]]] = ""
|
||||||
if conversation and "content" in conversation[-1] and conversation[
|
if (conversation and "content" in conversation[-1]
|
||||||
-1].get("role") == role:
|
and conversation[-1].get("role") == role):
|
||||||
last_msg_content = conversation[-1]["content"] or ""
|
last_msg_content = conversation[-1]["content"] or ""
|
||||||
if isinstance(last_msg_content, list):
|
if isinstance(last_msg_content, list):
|
||||||
last_msg_content = "\n".join(msg['text']
|
last_msg_content = "\n".join(msg['text']
|
||||||
@@ -1251,6 +1303,40 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
kv_transfer_params=final_res.kv_transfer_params,
|
kv_transfer_params=final_res.kv_transfer_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Log complete response if output logging is enabled
|
||||||
|
if self.enable_log_outputs and self.request_logger:
|
||||||
|
for choice in choices:
|
||||||
|
output_text = ""
|
||||||
|
if choice.message.content:
|
||||||
|
output_text = choice.message.content
|
||||||
|
elif choice.message.tool_calls:
|
||||||
|
# For tool calls, log the function name and arguments
|
||||||
|
tool_call_descriptions = []
|
||||||
|
for tool_call in choice.message.tool_calls:
|
||||||
|
if hasattr(tool_call.function, "name") and hasattr(
|
||||||
|
tool_call.function, "arguments"):
|
||||||
|
tool_call_descriptions.append(
|
||||||
|
f"{tool_call.function.name}({tool_call.function.arguments})"
|
||||||
|
)
|
||||||
|
tool_calls_str = ", ".join(tool_call_descriptions)
|
||||||
|
output_text = f"[tool_calls: {tool_calls_str}]"
|
||||||
|
|
||||||
|
if output_text:
|
||||||
|
# Get the corresponding output token IDs
|
||||||
|
output_token_ids = None
|
||||||
|
if choice.index < len(final_res.outputs):
|
||||||
|
output_token_ids = final_res.outputs[
|
||||||
|
choice.index].token_ids
|
||||||
|
|
||||||
|
self.request_logger.log_outputs(
|
||||||
|
request_id=request_id,
|
||||||
|
outputs=output_text,
|
||||||
|
output_token_ids=output_token_ids,
|
||||||
|
finish_reason=choice.finish_reason,
|
||||||
|
is_streaming=False,
|
||||||
|
delta=False,
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _get_top_logprobs(
|
def _get_top_logprobs(
|
||||||
@@ -1258,15 +1344,16 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
should_return_as_token_id: bool) -> list[ChatCompletionLogProb]:
|
should_return_as_token_id: bool) -> list[ChatCompletionLogProb]:
|
||||||
return [
|
return [
|
||||||
ChatCompletionLogProb(token=(token := self._get_decoded_token(
|
ChatCompletionLogProb(
|
||||||
p[1],
|
token=(token := self._get_decoded_token(
|
||||||
p[0],
|
p[1],
|
||||||
tokenizer,
|
p[0],
|
||||||
return_as_token_id=should_return_as_token_id)),
|
tokenizer,
|
||||||
logprob=max(p[1].logprob, -9999.0),
|
return_as_token_id=should_return_as_token_id,
|
||||||
bytes=list(
|
)),
|
||||||
token.encode("utf-8", errors="replace")))
|
logprob=max(p[1].logprob, -9999.0),
|
||||||
for i, p in enumerate(logprobs.items())
|
bytes=list(token.encode("utf-8", errors="replace")),
|
||||||
|
) for i, p in enumerate(logprobs.items())
|
||||||
if top_logprobs and i < top_logprobs
|
if top_logprobs and i < top_logprobs
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
tool_server: Optional[ToolServer] = None,
|
tool_server: Optional[ToolServer] = None,
|
||||||
enable_prompt_tokens_details: bool = False,
|
enable_prompt_tokens_details: bool = False,
|
||||||
enable_force_include_usage: bool = False,
|
enable_force_include_usage: bool = False,
|
||||||
|
enable_log_outputs: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
engine_client=engine_client,
|
engine_client=engine_client,
|
||||||
@@ -77,6 +78,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
|
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.chat_template_content_format: Final = chat_template_content_format
|
self.chat_template_content_format: Final = chat_template_content_format
|
||||||
|
self.enable_log_outputs = enable_log_outputs
|
||||||
|
|
||||||
self.reasoning_parser: Optional[Callable[[AnyTokenizer],
|
self.reasoning_parser: Optional[Callable[[AnyTokenizer],
|
||||||
ReasoningParser]] = None
|
ReasoningParser]] = None
|
||||||
@@ -428,6 +430,24 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
usage=usage,
|
usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Log complete response if output logging is enabled
|
||||||
|
if self.enable_log_outputs and self.request_logger:
|
||||||
|
output_text = ""
|
||||||
|
if content:
|
||||||
|
output_text = content
|
||||||
|
elif reasoning_content:
|
||||||
|
output_text = f"[reasoning: {reasoning_content}]"
|
||||||
|
|
||||||
|
if output_text:
|
||||||
|
self.request_logger.log_outputs(
|
||||||
|
request_id=request.request_id,
|
||||||
|
outputs=output_text,
|
||||||
|
output_token_ids=final_output.token_ids,
|
||||||
|
finish_reason=final_output.finish_reason,
|
||||||
|
is_streaming=False,
|
||||||
|
delta=False,
|
||||||
|
)
|
||||||
|
|
||||||
if request.store:
|
if request.store:
|
||||||
async with self.response_store_lock:
|
async with self.response_store_lock:
|
||||||
stored_response = self.response_store.get(response.id)
|
stored_response = self.response_store.get(response.id)
|
||||||
|
|||||||
Reference in New Issue
Block a user