[frontend][gptoss] Add per turn stats into Harmony Context (#25061)
Signed-off-by: lacora <hyelacora@gmail.com> Co-authored-by: Ye Hu <yehu@fb.com>
This commit is contained in:
@@ -45,21 +45,36 @@ def _map_tool_name_to_tool_type(tool_name: str) -> str:
|
||||
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
|
||||
|
||||
|
||||
class TurnTokens:
|
||||
"""Tracks token counts for a single conversation turn."""
|
||||
class TurnMetrics:
|
||||
"""Tracks token and toolcall details for a single conversation turn."""
|
||||
|
||||
def __init__(self, input_tokens=0, output_tokens=0):
|
||||
def __init__(
|
||||
self,
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
cached_input_tokens=0,
|
||||
tool_output_tokens=0,
|
||||
):
|
||||
self.input_tokens = input_tokens
|
||||
self.output_tokens = output_tokens
|
||||
self.cached_input_tokens = cached_input_tokens
|
||||
self.tool_output_tokens = tool_output_tokens
|
||||
|
||||
def reset(self):
|
||||
"""Reset counters for a new turn."""
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
self.cached_input_tokens = 0
|
||||
self.tool_output_tokens = 0
|
||||
|
||||
def copy(self):
|
||||
"""Create a copy of this turn's token counts."""
|
||||
return TurnTokens(self.input_tokens, self.output_tokens)
|
||||
return TurnMetrics(
|
||||
self.input_tokens,
|
||||
self.output_tokens,
|
||||
self.cached_input_tokens,
|
||||
self.tool_output_tokens,
|
||||
)
|
||||
|
||||
|
||||
class ConversationContext(ABC):
|
||||
@@ -102,6 +117,8 @@ class SimpleContext(ConversationContext):
|
||||
self.num_cached_tokens = 0
|
||||
# todo num_reasoning_tokens is not implemented yet.
|
||||
self.num_reasoning_tokens = 0
|
||||
# not implemented yet for SimpleContext
|
||||
self.all_turn_metrics = []
|
||||
|
||||
def append_output(self, output) -> None:
|
||||
self.last_output = output
|
||||
@@ -154,8 +171,9 @@ class HarmonyContext(ConversationContext):
|
||||
self.num_tool_output_tokens = 0
|
||||
|
||||
# Turn tracking - replaces multiple individual tracking variables
|
||||
self.current_turn = TurnTokens()
|
||||
self.previous_turn = TurnTokens()
|
||||
self.current_turn_metrics = TurnMetrics()
|
||||
# Track metrics for all turns
|
||||
self.all_turn_metrics: list[TurnMetrics] = []
|
||||
self.is_first_turn = True
|
||||
self.first_tok_of_message = True # For streaming support
|
||||
|
||||
@@ -173,11 +191,10 @@ class HarmonyContext(ConversationContext):
|
||||
# Check if the current token is part of reasoning content
|
||||
self._update_num_reasoning_tokens()
|
||||
self._update_prefill_token_usage(output)
|
||||
# Reset current turn output tokens for this turn
|
||||
self.current_turn.output_tokens = 0
|
||||
self._update_decode_token_usage(output)
|
||||
# Move current turn to previous turn for next turn's calculations
|
||||
self.previous_turn = self.current_turn.copy()
|
||||
# Append current turn to all turn list for next turn's calculations
|
||||
self.all_turn_metrics.append(self.current_turn_metrics.copy())
|
||||
self.current_turn_metrics.reset()
|
||||
# append_output is called only once before tool calling
|
||||
# in non-streaming case
|
||||
# so we can append all the parser messages to _messages
|
||||
@@ -213,20 +230,21 @@ class HarmonyContext(ConversationContext):
|
||||
logger.error("RequestOutput appended contains no prompt_token_ids.")
|
||||
|
||||
# Update current turn input tokens
|
||||
self.current_turn.input_tokens = this_turn_input_tokens
|
||||
self.current_turn_metrics.input_tokens = this_turn_input_tokens
|
||||
self.num_prompt_tokens += this_turn_input_tokens
|
||||
|
||||
# Calculate tool tokens (except on first turn)
|
||||
if self.is_first_turn:
|
||||
self.is_first_turn = False
|
||||
else:
|
||||
previous_turn = self.all_turn_metrics[-1]
|
||||
# start counting tool after first turn
|
||||
# tool tokens = this turn prefill - last turn prefill -
|
||||
# last turn decode
|
||||
this_turn_tool_tokens = (
|
||||
self.current_turn.input_tokens
|
||||
- self.previous_turn.input_tokens
|
||||
- self.previous_turn.output_tokens
|
||||
self.current_turn_metrics.input_tokens
|
||||
- previous_turn.input_tokens
|
||||
- previous_turn.output_tokens
|
||||
)
|
||||
|
||||
# Handle negative tool token counts (shouldn't happen in normal
|
||||
@@ -237,17 +255,20 @@ class HarmonyContext(ConversationContext):
|
||||
"(current_input=%d, previous_input=%d, "
|
||||
"previous_output=%d). Setting to 0.",
|
||||
this_turn_tool_tokens,
|
||||
self.current_turn.input_tokens,
|
||||
self.previous_turn.input_tokens,
|
||||
self.previous_turn.output_tokens,
|
||||
self.current_turn_metrics.input_tokens,
|
||||
previous_turn.input_tokens,
|
||||
previous_turn.output_tokens,
|
||||
)
|
||||
this_turn_tool_tokens = 0
|
||||
|
||||
self.num_tool_output_tokens += this_turn_tool_tokens
|
||||
self.current_turn_metrics.tool_output_tokens = this_turn_tool_tokens
|
||||
|
||||
# Update cached tokens
|
||||
if output.num_cached_tokens is not None:
|
||||
self.num_cached_tokens += output.num_cached_tokens
|
||||
num_cached_token = output.num_cached_tokens
|
||||
if num_cached_token is not None:
|
||||
self.num_cached_tokens += num_cached_token
|
||||
self.current_turn_metrics.cached_input_tokens = num_cached_token
|
||||
|
||||
def _update_decode_token_usage(self, output: RequestOutput) -> int:
|
||||
"""Update token usage statistics for the decode phase of generation.
|
||||
@@ -272,7 +293,7 @@ class HarmonyContext(ConversationContext):
|
||||
# only keep last round
|
||||
updated_output_token_count += len(completion_output.token_ids)
|
||||
self.num_output_tokens += updated_output_token_count
|
||||
self.current_turn.output_tokens += updated_output_token_count
|
||||
self.current_turn_metrics.output_tokens += updated_output_token_count
|
||||
return updated_output_token_count
|
||||
|
||||
@property
|
||||
@@ -452,7 +473,6 @@ class StreamingHarmonyContext(HarmonyContext):
|
||||
# so we only want to add the prompt tokens once for each message.
|
||||
if self.first_tok_of_message:
|
||||
self._update_prefill_token_usage(output)
|
||||
self.current_turn.output_tokens = 0
|
||||
# Reset self.first_tok_of_message if needed:
|
||||
# if the current token is the last one of the current message
|
||||
# (finished=True), then the next token processed will mark the
|
||||
@@ -464,7 +484,8 @@ class StreamingHarmonyContext(HarmonyContext):
|
||||
|
||||
# For streaming, update previous turn when message is complete
|
||||
if output.finished:
|
||||
self.previous_turn = self.current_turn.copy()
|
||||
self.all_turn_metrics.append(self.current_turn_metrics.copy())
|
||||
self.current_turn_metrics.reset()
|
||||
# Check if the current token is part of reasoning content
|
||||
self._update_num_reasoning_tokens()
|
||||
self.last_tok = tok
|
||||
|
||||
Reference in New Issue
Block a user