[Frontend][Responses API] Support reporting tool output tokens and fix reasoning token count (#24285)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
Ye (Charlotte) Qi
2025-09-06 13:27:15 -07:00
committed by GitHub
parent fb691ee4e7
commit a3645ed94d
4 changed files with 557 additions and 36 deletions

View File

@@ -3,7 +3,6 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union
@@ -21,6 +20,23 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class TurnTokens:
    """Tracks token counts for a single conversation turn.

    Holds two plain, mutable integer counters: the number of input
    (prompt) tokens and the number of output (generated) tokens.
    """

    def __init__(self, input_tokens: int = 0, output_tokens: int = 0) -> None:
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens

    def __repr__(self) -> str:
        # Readable form for log messages and debugger output.
        return (f"{type(self).__name__}("
                f"input_tokens={self.input_tokens}, "
                f"output_tokens={self.output_tokens})")

    def reset(self) -> None:
        """Reset both counters to zero for a new turn."""
        self.input_tokens = 0
        self.output_tokens = 0

    def copy(self) -> "TurnTokens":
        """Return a new, independent ``TurnTokens`` with the same counts."""
        return TurnTokens(self.input_tokens, self.output_tokens)
class ConversationContext(ABC):
@abstractmethod
@@ -92,52 +108,124 @@ class HarmonyContext(ConversationContext):
self.num_init_messages = len(messages)
self.num_prompt_tokens = 0
self.num_output_tokens = 0
# TODO(woosuk): Implement the following fields.
self.num_cached_tokens = 0
self.num_reasoning_tokens = 0
self.num_tool_output_tokens = 0
def _update_num_prompt_tokens(self, output: RequestOutput):
    """Add this output's prompt length to ``self.num_prompt_tokens``.

    Does nothing when ``output.prompt_token_ids`` is ``None`` or empty.
    """
    prompt_token_ids = output.prompt_token_ids
    # Truthiness already covers both the None and the empty case.
    if prompt_token_ids:
        # NOTE: with built-in tools, there might be multiple rounds in
        # the conversation, with the full conversation being resent
        # as new prompt each time. Hence the sum.
        self.num_prompt_tokens += len(prompt_token_ids)
# Turn tracking - replaces multiple individual tracking variables
self.current_turn = TurnTokens()
self.previous_turn = TurnTokens()
self.is_first_turn = True
self.first_tok_of_message = True # For streaming support
def _update_num_cached_tokens(self, output: RequestOutput):
    """Accumulate ``output.num_cached_tokens`` into the running total.

    The value is summed across calls, in the same way as the prompt
    token count. A ``None`` value leaves the total untouched.
    """
    cached = output.num_cached_tokens
    if cached is None:
        return
    self.num_cached_tokens += cached
def _update_num_output_tokens(self, token_ids: Sequence[int]):
self.num_output_tokens += len(token_ids)
def _update_num_reasoning_tokens(self, token_ids: Sequence[int]):
# Count tokens that are part of reasoning content (analysis channel
# or tool-directed messages like python/browser calls)
is_analysis = self.parser.current_channel == "analysis"
is_tool_call = (self.parser.current_recipient is not None and
(self.parser.current_recipient.startswith("python") or
self.parser.current_recipient.startswith("browser.")))
if is_analysis or is_tool_call:
self.num_reasoning_tokens += len(token_ids)
def _update_num_reasoning_tokens(self):
# Count all analysis and commentary channels as reasoning tokens
if self.parser.current_channel in {"analysis", "commentary"}:
self.num_reasoning_tokens += 1
# Append one output to the conversation: either a RequestOutput from the
# engine (parsed into messages) or tool output that is already a list of
# messages. In both cases the resulting messages extend self._messages.
# NOTE(review): this span calls both _update_num_reasoning_tokens([token_id])
# and _update_num_reasoning_tokens(), and both the _update_num_* helpers and
# the _update_*_token_usage helpers. It looks like removed and added lines of
# a diff shown together with indentation lost - confirm against the real file
# before relying on statement order or nesting.
def append_output(self, output) -> None:
if isinstance(output, RequestOutput):
self._update_num_prompt_tokens(output)
self._update_num_cached_tokens(output)
# Only the first completion output's tokens are fed to the parser.
output_token_ids = output.outputs[0].token_ids
self._update_num_output_tokens(output_token_ids)
# A fresh parser is created for every RequestOutput.
self.parser = get_streamable_parser_for_assistant()
for token_id in output_token_ids:
self.parser.process(token_id)
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens([token_id])
self._update_num_reasoning_tokens()
self._update_prefill_token_usage(output)
# Reset current turn output tokens for this turn
self.current_turn.output_tokens = 0
self._update_decode_token_usage(output)
# Move current turn to previous turn for next turn's calculations
self.previous_turn = self.current_turn.copy()
# Messages the parser assembled from the tokens processed above.
output_msgs = self.parser.messages
else:
# Tool output.
output_msgs = output
self._messages.extend(output_msgs)
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
    """Account for the prompt (prefill) side of one generation turn.

    Steps performed:
      1. Record this turn's prompt length in ``self.current_turn`` and
         add it to ``self.num_prompt_tokens``. A missing
         ``prompt_token_ids`` is logged as an error and counted as 0.
      2. On every turn after the first, estimate the tool output tokens
         as::

             current_prompt - previous_prompt - previous_output

         i.e. whatever was added to the prompt between turns (typically
         tool responses). A negative estimate is logged and clamped to 0.
      3. Add ``output.num_cached_tokens`` (when not ``None``) to
         ``self.num_cached_tokens``.

    Args:
        output: The RequestOutput containing prompt token information.
    """
    prompt_ids = output.prompt_token_ids
    if prompt_ids is None:
        prompt_len = 0
        logger.error(
            "RequestOutput appended contains no prompt_token_ids.")
    else:
        prompt_len = len(prompt_ids)

    self.current_turn.input_tokens = prompt_len
    self.num_prompt_tokens += prompt_len

    if self.is_first_turn:
        # Nothing to diff against yet; tool accounting starts next turn.
        self.is_first_turn = False
    else:
        current = self.current_turn
        previous = self.previous_turn
        tool_tokens = (current.input_tokens - previous.input_tokens -
                       previous.output_tokens)
        if tool_tokens < 0:
            # Should not happen in normal cases; report and clamp.
            logger.error(
                "Negative tool output tokens calculated: %d "
                "(current_input=%d, previous_input=%d, "
                "previous_output=%d). Setting to 0.", tool_tokens,
                current.input_tokens, previous.input_tokens,
                previous.output_tokens)
            tool_tokens = 0
        self.num_tool_output_tokens += tool_tokens

    cached = output.num_cached_tokens
    if cached is not None:
        self.num_cached_tokens += cached
def _update_decode_token_usage(self, output: RequestOutput) -> int:
    """Account for the generated (decode) tokens carried by ``output``.

    The lengths of ``token_ids`` across every entry of ``output.outputs``
    are summed; the sum is added to both ``self.num_output_tokens`` and
    ``self.current_turn.output_tokens``. An empty or missing ``outputs``
    contributes zero.

    Args:
        output: The RequestOutput containing generated token information.

    Returns:
        int: Number of output tokens processed in this call.
    """
    completions = output.outputs
    if completions:
        generated = sum(len(c.token_ids) for c in completions)
    else:
        generated = 0
    self.num_output_tokens += generated
    self.current_turn.output_tokens += generated
    return generated
@property
def messages(self) -> list:
    """The conversation messages accumulated so far (``self._messages``)."""
    accumulated = self._messages
    return accumulated
@@ -231,8 +319,8 @@ class StreamingHarmonyContext(HarmonyContext):
# append_output is called for each output token in streaming case,
# so we only want to add the prompt tokens once for each message.
if self.first_tok_of_message:
self._update_num_prompt_tokens(output)
self._update_num_cached_tokens(output)
self._update_prefill_token_usage(output)
self.current_turn.output_tokens = 0
# Reset self.first_tok_of_message if needed:
# if the current token is the last one of the current message
# (finished=True), then the next token processed will mark the
@@ -240,9 +328,13 @@ class StreamingHarmonyContext(HarmonyContext):
self.first_tok_of_message = output.finished
for tok in output.outputs[0].token_ids:
self.parser.process(tok)
self._update_num_output_tokens(output.outputs[0].token_ids)
self._update_decode_token_usage(output)
# For streaming, update previous turn when message is complete
if output.finished:
self.previous_turn = self.current_turn.copy()
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens(output.outputs[0].token_ids)
self._update_num_reasoning_tokens()
self.last_tok = tok
else:
# Handle the case of tool output in direct message format