[Metrics][MFU] Fix UnembedMetrics FLOP overcounting for prefill (#33045)

Fix UnembedMetrics to correctly count FLOPs for the unembedding (LM head) layer.

The bug: UnembedMetrics used total_num_tokens(), which counts all tokens in the
batch, when computing projection FLOPs. But in the autoregressive use case, the
vocab projection is run only on the last token of each prefill request (every
decode token still needs logits), so prefill FLOPs were overcounted.

Co-authored-by: Omar Mohamed Khalil <omarkhalil@meta.com>
Author: omkhalil
Date: 2026-01-27 10:16:49 -05:00
Committed by: GitHub
Parent: 492a7983dd
Commit: 5ec44056f7

@@ -110,6 +110,14 @@ class ExecutionContext:
         """Total sum of (num_tokens * context_len) across all requests."""
         return self.prefill_token_context_product + self.decode_token_context_product
 
+    def num_logits_tokens(self) -> int:
+        """Number of tokens that require logits computation (unembedding).
+
+        For prefill, only the last token per request needs logits.
+        For decode, all tokens need logits.
+        """
+        return self.num_prefill_requests + self.decode_num_tokens
+
     @classmethod
     def from_single_request(
         cls, num_tokens: int, context_len: int, is_prefill: bool
@@ -906,7 +914,7 @@ class UnembedMetrics(ComponentMetrics):
     ) -> dict[str, int]:
         """Calculate flops breakdown for unembedding layer."""
         D, V = self.hidden_size, self.vocab_size
-        T = ctx.total_num_tokens()
+        T = ctx.num_logits_tokens()
         if per_gpu:
             V //= self.tp_size
 
@@ -920,7 +928,7 @@ class UnembedMetrics(ComponentMetrics):
     ) -> dict[str, int]:
         """Calculate read memory traffic for unembedding layer."""
         D, V = self.hidden_size, self.vocab_size
-        T = ctx.total_num_tokens()
+        T = ctx.num_logits_tokens()
         if per_gpu:
             V //= self.tp_size
 
@@ -935,7 +943,7 @@ class UnembedMetrics(ComponentMetrics):
     ) -> dict[str, int]:
         """Calculate write memory traffic for unembedding layer."""
         V = self.vocab_size
-        T = ctx.total_num_tokens()
+        T = ctx.num_logits_tokens()
         if per_gpu:
             V //= self.tp_size
 
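
For reference, a minimal self-contained sketch of the corrected accounting.
Assumptions: a stripped-down stand-in for ExecutionContext carrying only the
fields used here (prefill_num_tokens and unembed_flops are hypothetical names
invented for this sketch), and the standard 2*T*D*V cost for the dense
[T, D] x [D, V] unembed matmul.

from dataclasses import dataclass

@dataclass
class Ctx:
    # Stand-in for ExecutionContext with only the fields the unembed metrics
    # touch; num_prefill_requests and decode_num_tokens follow the diff above,
    # prefill_num_tokens is assumed for the sketch.
    num_prefill_requests: int
    prefill_num_tokens: int
    decode_num_tokens: int

    def total_num_tokens(self) -> int:
        return self.prefill_num_tokens + self.decode_num_tokens

    def num_logits_tokens(self) -> int:
        # Last token of each prefill request plus every decode token.
        return self.num_prefill_requests + self.decode_num_tokens

def unembed_flops(T: int, D: int, V: int) -> int:
    # Dense hidden->vocab projection: [T, D] @ [D, V] costs 2*T*D*V FLOPs.
    return 2 * T * D * V

# One prefill request of 2048 tokens; hidden size 8192, vocab size 128256.
ctx = Ctx(num_prefill_requests=1, prefill_num_tokens=2048, decode_num_tokens=0)
D, V = 8192, 128_256
print(unembed_flops(ctx.total_num_tokens(), D, V))   # old count: ~4.3e12
print(unembed_flops(ctx.num_logits_tokens(), D, V))  # fixed count: ~2.1e9

With the old count, unembed FLOPs (and hence MFU attributed to the LM head)
for this prefill would be overstated by the full sequence length, 2048x here.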