feat: Add --enable-log-outputs flag for logging model generations (#20707)
Signed-off-by: Adrian Garcia <adrian.garcia@inceptionai.ai>
commit 8e8e0b6af1
parent 82216dc21f
@@ -73,6 +73,7 @@ class OpenAIServingChat(OpenAIServing):
    tool_parser: Optional[str] = None,
    enable_prompt_tokens_details: bool = False,
    enable_force_include_usage: bool = False,
    enable_log_outputs: bool = False,
) -> None:
    super().__init__(engine_client=engine_client,
                     model_config=model_config,
@@ -84,6 +85,7 @@ class OpenAIServingChat(OpenAIServing):
self.response_role = response_role
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.enable_log_outputs = enable_log_outputs

# set up tool use
self.enable_auto_tools: bool = enable_auto_tools
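The constructor change above simply stores the new flag next to the existing serving options. Below is a minimal sketch (not part of the diff, with an illustrative class name) of the pattern every new logging site in this commit relies on: logging is a no-op unless the flag is set and a request logger is configured.

# Hedged sketch only; ServingSketch is illustrative, not vLLM code. The real
# class also takes engine_client, model_config, etc., which are omitted here.
class ServingSketch:
    def __init__(self, request_logger=None, enable_log_outputs: bool = False):
        self.request_logger = request_logger
        self.enable_log_outputs = enable_log_outputs  # new option, off by default

    def _should_log_outputs(self) -> bool:
        # Same guard used at every new logging call site in this diff.
        return bool(self.enable_log_outputs and self.request_logger)

sketch = ServingSketch(request_logger=None, enable_log_outputs=True)
assert sketch._should_log_outputs() is False  # no logger configured, so no logging

Presumably the --enable-log-outputs CLI flag named in the commit title maps onto this constructor argument, so the default behaviour (no generation logging) is unchanged.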
@@ -489,20 +491,21 @@ class OpenAIServingChat(OpenAIServing):
all_previous_token_ids: Optional[list[list[int]]]
function_name_returned = [False] * num_choices

# Always track previous_texts for comprehensive output logging
previous_texts = [""] * num_choices

# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
if tool_choice_auto or self.reasoning_parser:
    # These are only required in "auto" tool choice case
    previous_texts = [""] * num_choices
    all_previous_token_ids = [[]] * num_choices
    # For reasoning parser and tool call all enabled
    added_content_delta_arr = [False] * num_choices
    reasoning_end_arr = [False] * num_choices
elif request.tool_choice == "required":
    previous_texts = [""] * num_choices
    all_previous_token_ids = None
else:
    previous_texts, all_previous_token_ids = None, None
    all_previous_token_ids = None

try:
    if self.reasoning_parser:
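The hunk above changes the setup so that previous_texts is always initialized, not only in the tool-choice and reasoning paths; the else branch appears to be narrowed so that it now clears only all_previous_token_ids. That invariant is what lets the stream handler accumulate the full generated text for logging at the end of the stream. A minimal, self-contained illustration of the accumulation (the variable values are made up for the example):

# Sketch of the accumulation that output logging depends on; delta_texts is a
# hypothetical stand-in for the per-iteration delta_text values in the diff.
num_choices = 2
previous_texts = [""] * num_choices  # always initialized, as in the new code

delta_texts = [["Hel", "lo"], ["Wor", "ld"]]  # made-up streamed deltas
for i in range(num_choices):
    for delta_text in delta_texts[i]:
        previous_texts[i] += delta_text  # same update the diff applies per delta

assert previous_texts == ["Hello", "World"]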
@@ -844,6 +847,7 @@ class OpenAIServingChat(OpenAIServing):
        current_token_ids=current_token_ids,
        delta_token_ids=output.token_ids,
        request=request))

# when only reasoning
elif self.reasoning_parser:
    delta_message = (reasoning_parser.
@@ -865,6 +869,10 @@ class OpenAIServingChat(OpenAIServing):
    assert all_previous_token_ids is not None
    previous_texts[i] = current_text
    all_previous_token_ids[i] = current_token_ids
else:
    # Update for comprehensive logging even in simple case
    assert previous_texts is not None
    previous_texts[i] += delta_text

# set the previous values for the next iteration
previous_num_tokens[i] += len(output.token_ids)
@@ -876,6 +884,27 @@ class OpenAIServingChat(OpenAIServing):
if delta_message is None:
    continue

# Log streaming delta if output logging is enabled
if self.enable_log_outputs and self.request_logger:
    delta_content = ""
    if delta_message.content:
        delta_content = delta_message.content
    elif delta_message.tool_calls:
        delta_content = "".join(
            tc.function.arguments
            for tc in delta_message.tool_calls
            if tc.function and tc.function.arguments)

    if delta_content:
        self.request_logger.log_outputs(
            request_id=request_id,
            outputs=delta_content,
            output_token_ids=list(output.token_ids),
            finish_reason=output.finish_reason,
            is_streaming=True,
            delta=True,
        )

if output.finish_reason is None:
    # Send token-by-token response for each request.n
    choice_data = ChatCompletionResponseStreamChoice(
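This is the first of three new log_outputs call sites in the commit (per-delta, end-of-stream, and non-streaming); all of them pass the same keyword arguments. The sketch below is a hypothetical stand-in logger inferred only from those call sites, not vLLM's actual RequestLogger implementation, shown to make the implied interface concrete.

from typing import Optional

class StandInRequestLogger:
    """Hypothetical stand-in matching the keywords used at the new call sites."""

    def log_outputs(
        self,
        request_id: str,
        outputs: str,
        output_token_ids: Optional[list[int]],
        finish_reason: Optional[str],
        is_streaming: bool,
        delta: bool,
    ) -> None:
        kind = "delta" if delta else "full"
        print(f"[{request_id}] {kind} (streaming={is_streaming}, "
              f"finish_reason={finish_reason}): {outputs!r}")

# Example call mirroring the streaming-delta branch above (made-up values):
logger = StandInRequestLogger()
logger.log_outputs(
    request_id="chatcmpl-123",
    outputs="Hello",
    output_token_ids=[9906],
    finish_reason=None,
    is_streaming=True,
    delta=True,
)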
@@ -994,7 +1023,27 @@ class OpenAIServingChat(OpenAIServing):
request_metadata.final_usage_info = UsageInfo(
    prompt_tokens=num_prompt_tokens,
    completion_tokens=num_completion_tokens,
    total_tokens=num_prompt_tokens + num_completion_tokens)
    total_tokens=num_prompt_tokens + num_completion_tokens,
)

# Log complete streaming response if output logging is enabled
if self.enable_log_outputs and self.request_logger:
    # Log the complete response for each choice
    for i in range(num_choices):
        full_text = (
            previous_texts[i]
            if previous_texts and i < len(previous_texts) else
            f"<streaming_complete: {previous_num_tokens[i]} tokens>"
        )
        self.request_logger.log_outputs(
            request_id=request_id,
            outputs=full_text,
            output_token_ids=None,  # Consider also logging all token IDs
            finish_reason="streaming_complete",
            is_streaming=True,
            delta=False,
        )

except Exception as e:
    # TODO: Use a vllm-specific Validation Error
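When the stream finishes, the block above logs the accumulated text for each choice under the synthetic finish reason "streaming_complete", falling back to a token-count placeholder when the text was not tracked. The same selection, pulled out into a small function for illustration (the function name is an assumption, not part of the diff):

from typing import Optional

def final_logged_text(previous_texts: Optional[list[str]], i: int,
                      num_tokens: int) -> str:
    # Mirrors the conditional expression in the diff: prefer the accumulated
    # text, otherwise log a placeholder carrying only the token count.
    if previous_texts and i < len(previous_texts):
        return previous_texts[i]
    return f"<streaming_complete: {num_tokens} tokens>"

assert final_logged_text(["Hello world"], 0, 3) == "Hello world"
assert final_logged_text(None, 0, 42) == "<streaming_complete: 42 tokens>"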
@@ -1121,8 +1170,10 @@ class OpenAIServingChat(OpenAIServing):
tool_calls=[
    tool_call_class(function=FunctionCall(
        name=request.tool_choice.function.name,
        arguments=content))
])
        arguments=content,
    ))
],
)

elif request.tool_choice and request.tool_choice == "required":
    tool_call_class = MistralToolCall if isinstance(
@@ -1209,12 +1260,13 @@ class OpenAIServingChat(OpenAIServing):
finish_reason="tool_calls" if auto_tools_called else
output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason)

choices.append(choice_data)

if request.echo:
    last_msg_content: Union[str, list[dict[str, str]]] = ""
    if conversation and "content" in conversation[-1] and conversation[
            -1].get("role") == role:
    if (conversation and "content" in conversation[-1]
            and conversation[-1].get("role") == role):
        last_msg_content = conversation[-1]["content"] or ""
        if isinstance(last_msg_content, list):
            last_msg_content = "\n".join(msg['text']
@@ -1251,6 +1303,40 @@ class OpenAIServingChat(OpenAIServing):
    kv_transfer_params=final_res.kv_transfer_params,
)

# Log complete response if output logging is enabled
if self.enable_log_outputs and self.request_logger:
    for choice in choices:
        output_text = ""
        if choice.message.content:
            output_text = choice.message.content
        elif choice.message.tool_calls:
            # For tool calls, log the function name and arguments
            tool_call_descriptions = []
            for tool_call in choice.message.tool_calls:
                if hasattr(tool_call.function, "name") and hasattr(
                        tool_call.function, "arguments"):
                    tool_call_descriptions.append(
                        f"{tool_call.function.name}({tool_call.function.arguments})"
                    )
            tool_calls_str = ", ".join(tool_call_descriptions)
            output_text = f"[tool_calls: {tool_calls_str}]"

        if output_text:
            # Get the corresponding output token IDs
            output_token_ids = None
            if choice.index < len(final_res.outputs):
                output_token_ids = final_res.outputs[
                    choice.index].token_ids

            self.request_logger.log_outputs(
                request_id=request_id,
                outputs=output_text,
                output_token_ids=output_token_ids,
                finish_reason=choice.finish_reason,
                is_streaming=False,
                delta=False,
            )

return response

def _get_top_logprobs(
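For non-streaming responses, the block above logs either the message content or, for tool-call responses, a "[tool_calls: name(arguments)]" summary so that tool-call-only completions still produce a log entry. A standalone rendering of that formatting, with hypothetical stand-in tool-call objects for illustration:

from dataclasses import dataclass

@dataclass
class _Function:  # hypothetical stand-in for a tool call's function payload
    name: str
    arguments: str  # JSON-encoded arguments, as in OpenAI-style tool calls

@dataclass
class _ToolCall:  # hypothetical stand-in for a chat completion tool call
    function: _Function

def render_tool_calls(tool_calls: list[_ToolCall]) -> str:
    # Same formatting as the diff: "name(arguments)" pairs joined by ", ",
    # wrapped in a "[tool_calls: ...]" marker.
    descriptions = [
        f"{tc.function.name}({tc.function.arguments})"
        for tc in tool_calls
        if hasattr(tc.function, "name") and hasattr(tc.function, "arguments")
    ]
    return f"[tool_calls: {', '.join(descriptions)}]"

call = _ToolCall(function=_Function(name="get_weather",
                                    arguments='{"city": "Paris"}'))
print(render_tool_calls([call]))  # [tool_calls: get_weather({"city": "Paris"})]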
@@ -1258,15 +1344,16 @@ class OpenAIServingChat(OpenAIServing):
    tokenizer: AnyTokenizer,
    should_return_as_token_id: bool) -> list[ChatCompletionLogProb]:
return [
    ChatCompletionLogProb(token=(token := self._get_decoded_token(
        p[1],
        p[0],
        tokenizer,
        return_as_token_id=should_return_as_token_id)),
        logprob=max(p[1].logprob, -9999.0),
        bytes=list(
            token.encode("utf-8", errors="replace")))
    for i, p in enumerate(logprobs.items())
    ChatCompletionLogProb(
        token=(token := self._get_decoded_token(
            p[1],
            p[0],
            tokenizer,
            return_as_token_id=should_return_as_token_id,
        )),
        logprob=max(p[1].logprob, -9999.0),
        bytes=list(token.encode("utf-8", errors="replace")),
    ) for i, p in enumerate(logprobs.items())
    if top_logprobs and i < top_logprobs
]