[Frontend][Responses API] Support reporting tool output tokens and fix reasoning token count (#24285)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
committed by
GitHub
parent
fb691ee4e7
commit
a3645ed94d
@@ -3,7 +3,6 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
@@ -21,6 +20,23 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TurnTokens:
    """Tracks token counts for a single conversation turn.

    Attributes are plain mutable integers:
        input_tokens: tokens counted as input for the turn.
        output_tokens: tokens counted as output for the turn.
    """

    def __init__(self, input_tokens: int = 0, output_tokens: int = 0) -> None:
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens

    def __repr__(self) -> str:
        # Readable form for log messages and debugger output.
        return (f"{type(self).__name__}(input_tokens={self.input_tokens!r}, "
                f"output_tokens={self.output_tokens!r})")

    def reset(self) -> None:
        """Reset counters for a new turn."""
        self.input_tokens = 0
        self.output_tokens = 0

    def copy(self) -> "TurnTokens":
        """Create an independent copy of this turn's token counts."""
        return TurnTokens(self.input_tokens, self.output_tokens)
|
||||
|
||||
|
||||
class ConversationContext(ABC):
|
||||
|
||||
@abstractmethod
|
||||
@@ -92,52 +108,124 @@ class HarmonyContext(ConversationContext):
|
||||
self.num_init_messages = len(messages)
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
# TODO(woosuk): Implement the following fields.
|
||||
self.num_cached_tokens = 0
|
||||
self.num_reasoning_tokens = 0
|
||||
self.num_tool_output_tokens = 0
|
||||
|
||||
def _update_num_prompt_tokens(self, output: RequestOutput):
    """Add the prompt length of ``output`` to ``self.num_prompt_tokens``.

    Args:
        output: Object exposing ``prompt_token_ids``; a ``None`` or empty
            value leaves the counter unchanged.
    """
    # Truthiness covers both the None case and the empty-sequence case.
    if output.prompt_token_ids:
        # NOTE: with built-in tools, there might be multiple rounds in
        # the conversation, with the full conversation being resent
        # as new prompt each time. Hence the sum.
        self.num_prompt_tokens += len(output.prompt_token_ids)
|
||||
# Turn tracking - replaces multiple individual tracking variables
|
||||
self.current_turn = TurnTokens()
|
||||
self.previous_turn = TurnTokens()
|
||||
self.is_first_turn = True
|
||||
self.first_tok_of_message = True # For streaming support
|
||||
|
||||
def _update_num_cached_tokens(self, output: RequestOutput):
    """Add ``output.num_cached_tokens`` to ``self.num_cached_tokens``.

    Args:
        output: Object exposing ``num_cached_tokens``; ``None`` is ignored.
    """
    if output.num_cached_tokens is not None:
        # Similar to num_prompt_tokens: the value is summed across calls.
        self.num_cached_tokens += output.num_cached_tokens
|
||||
|
||||
def _update_num_output_tokens(self, token_ids: Sequence[int]):
    """Add ``len(token_ids)`` to ``self.num_output_tokens``.

    Args:
        token_ids: Token ids to count; an empty sequence adds nothing.
    """
    self.num_output_tokens += len(token_ids)
|
||||
|
||||
def _update_num_reasoning_tokens(self, token_ids: Sequence[int]):
    """Add ``len(token_ids)`` to the reasoning count when applicable.

    Tokens count as reasoning when the parser's current channel is
    "analysis", or when the current recipient starts with "python" or
    "browser." (tool-directed messages).

    NOTE(review): this listing also contains a later zero-argument
    definition with the same name; if both live in one class the later
    one wins. Confirm which variant is intended.

    Args:
        token_ids: Token ids just fed to ``self.parser``.
    """
    recipient = self.parser.current_recipient
    is_analysis = self.parser.current_channel == "analysis"
    # str.startswith accepts a tuple, so one call covers both prefixes.
    is_tool_call = (recipient is not None
                    and recipient.startswith(("python", "browser.")))
    if is_analysis or is_tool_call:
        self.num_reasoning_tokens += len(token_ids)
|
||||
def _update_num_reasoning_tokens(self):
    """Count one reasoning token based on the parser's current channel.

    Increments ``self.num_reasoning_tokens`` by exactly one when
    ``self.parser.current_channel`` is "analysis" or "commentary", so it
    has to be called once per processed token.
    """
    # Count all analysis and commentary channels as reasoning tokens
    if self.parser.current_channel in {"analysis", "commentary"}:
        self.num_reasoning_tokens += 1
|
||||
|
||||
def append_output(self, output) -> None:
    """Append a model generation or a tool output to the conversation.

    For a ``RequestOutput`` the first completion's tokens are replayed
    through a fresh streamable parser, token usage is updated, and the
    parsed messages are appended. Any other value is treated as tool
    output and appended as-is.

    Args:
        output: A ``RequestOutput`` from the engine, or an iterable of
            messages produced by a tool.
    """
    if isinstance(output, RequestOutput):
        output_token_ids = output.outputs[0].token_ids
        self.parser = get_streamable_parser_for_assistant()
        for token_id in output_token_ids:
            self.parser.process(token_id)
            # Check if the current token is part of reasoning content
            self._update_num_reasoning_tokens()

        # Token accounting goes through the prefill/decode helpers only;
        # also calling the per-counter updaters here would count the same
        # prompt, cached and output tokens twice.
        self._update_prefill_token_usage(output)
        # Reset current turn output tokens for this turn
        self.current_turn.output_tokens = 0
        self._update_decode_token_usage(output)
        # Move current turn to previous turn for next turn's calculations
        self.previous_turn = self.current_turn.copy()
        output_msgs = self.parser.messages
    else:
        # Tool output.
        output_msgs = output
    self._messages.extend(output_msgs)
|
||||
|
||||
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
    """Update token usage statistics for the prefill phase of generation.

    The prefill phase processes the input prompt tokens. This method:
    1. Counts the prompt tokens for this turn
    2. Calculates tool output tokens for multi-turn conversations
    3. Updates cached token counts
    4. Tracks state for next turn calculations

    Tool output tokens are calculated as:
        current_prompt_tokens - last_turn_prompt_tokens -
        last_turn_output_tokens
    This represents tokens added between turns (typically tool responses).

    Args:
        output: The RequestOutput containing prompt token information
    """
    if output.prompt_token_ids is not None:
        this_turn_input_tokens = len(output.prompt_token_ids)
    else:
        # A missing prompt is logged and counted as zero input tokens.
        this_turn_input_tokens = 0
        logger.error(
            "RequestOutput appended contains no prompt_token_ids.")

    # Update current turn input tokens
    self.current_turn.input_tokens = this_turn_input_tokens
    self.num_prompt_tokens += this_turn_input_tokens

    # Calculate tool tokens (except on first turn)
    if self.is_first_turn:
        self.is_first_turn = False
    else:
        # start counting tool after first turn
        # tool tokens = this turn prefill - last turn prefill -
        #               last turn decode
        this_turn_tool_tokens = (self.current_turn.input_tokens -
                                 self.previous_turn.input_tokens -
                                 self.previous_turn.output_tokens)

        # Handle negative tool token counts (shouldn't happen in normal
        # cases)
        if this_turn_tool_tokens < 0:
            logger.error(
                "Negative tool output tokens calculated: %d "
                "(current_input=%d, previous_input=%d, "
                "previous_output=%d). Setting to 0.",
                this_turn_tool_tokens, self.current_turn.input_tokens,
                self.previous_turn.input_tokens,
                self.previous_turn.output_tokens)
            this_turn_tool_tokens = 0

        self.num_tool_output_tokens += this_turn_tool_tokens

    # Update cached tokens
    if output.num_cached_tokens is not None:
        self.num_cached_tokens += output.num_cached_tokens
|
||||
|
||||
def _update_decode_token_usage(self, output: RequestOutput) -> int:
    """Update token usage statistics for the decode phase of generation.

    The decode phase processes the generated output tokens. This method:
    1. Counts output tokens from all completion outputs
    2. Updates the total output token count
    3. Tracks tokens generated in the current turn

    In streaming mode, this is called for each token generated.
    In non-streaming mode, this is called once with all output tokens.

    Args:
        output: The RequestOutput containing generated token information

    Returns:
        int: Number of output tokens processed in this call
    """
    updated_output_token_count = 0
    if output.outputs:
        # Sum the token counts of every completion output in this call.
        updated_output_token_count = sum(
            len(completion_output.token_ids)
            for completion_output in output.outputs)
        self.num_output_tokens += updated_output_token_count
        self.current_turn.output_tokens += updated_output_token_count
    return updated_output_token_count
|
||||
|
||||
@property
def messages(self) -> list:
    """Return the underlying message list (``self._messages``), not a copy."""
    return self._messages
|
||||
@@ -231,8 +319,8 @@ class StreamingHarmonyContext(HarmonyContext):
|
||||
# append_output is called for each output token in streaming case,
|
||||
# so we only want to add the prompt tokens once for each message.
|
||||
if self.first_tok_of_message:
|
||||
self._update_num_prompt_tokens(output)
|
||||
self._update_num_cached_tokens(output)
|
||||
self._update_prefill_token_usage(output)
|
||||
self.current_turn.output_tokens = 0
|
||||
# Reset self.first_tok_of_message if needed:
|
||||
# if the current token is the last one of the current message
|
||||
# (finished=True), then the next token processed will mark the
|
||||
@@ -240,9 +328,13 @@ class StreamingHarmonyContext(HarmonyContext):
|
||||
self.first_tok_of_message = output.finished
|
||||
for tok in output.outputs[0].token_ids:
|
||||
self.parser.process(tok)
|
||||
self._update_num_output_tokens(output.outputs[0].token_ids)
|
||||
self._update_decode_token_usage(output)
|
||||
|
||||
# For streaming, update previous turn when message is complete
|
||||
if output.finished:
|
||||
self.previous_turn = self.current_turn.copy()
|
||||
# Check if the current token is part of reasoning content
|
||||
self._update_num_reasoning_tokens(output.outputs[0].token_ids)
|
||||
self._update_num_reasoning_tokens()
|
||||
self.last_tok = tok
|
||||
else:
|
||||
# Handle the case of tool output in direct message format
|
||||
|
||||
Reference in New Issue
Block a user