Fix UnembedMetrics to correctly count FLOPs for the unembedding (LM head) layer.

The bug: UnembedMetrics used total_num_tokens(), which counts every token in the
batch, when computing the projection FLOPs. In the autoregressive use case the
vocab projection runs only on the last token of each prefill request (plus every
decode token), so the old count overstates the unembedding cost.

Co-authored-by: Omar Mohamed Khalil <omarkhalil@meta.com>
@@ -110,6 +110,14 @@ class ExecutionContext:
         """Total sum of (num_tokens * context_len) across all requests."""
         return self.prefill_token_context_product + self.decode_token_context_product
 
+    def num_logits_tokens(self) -> int:
+        """Number of tokens that require logits computation (unembedding).
+
+        For prefill, only the last token per request needs logits.
+        For decode, all tokens need logits.
+        """
+        return self.num_prefill_requests + self.decode_num_tokens
+
     @classmethod
     def from_single_request(
         cls, num_tokens: int, context_len: int, is_prefill: bool
@@ -906,7 +914,7 @@ class UnembedMetrics(ComponentMetrics):
     ) -> dict[str, int]:
         """Calculate flops breakdown for unembedding layer."""
         D, V = self.hidden_size, self.vocab_size
-        T = ctx.total_num_tokens()
+        T = ctx.num_logits_tokens()
 
         if per_gpu:
             V //= self.tp_size
@@ -920,7 +928,7 @@ class UnembedMetrics(ComponentMetrics):
     ) -> dict[str, int]:
         """Calculate read memory traffic for unembedding layer."""
         D, V = self.hidden_size, self.vocab_size
-        T = ctx.total_num_tokens()
+        T = ctx.num_logits_tokens()
 
         if per_gpu:
             V //= self.tp_size
@@ -935,7 +943,7 @@ class UnembedMetrics(ComponentMetrics):
     ) -> dict[str, int]:
         """Calculate write memory traffic for unembedding layer."""
         V = self.vocab_size
-        T = ctx.total_num_tokens()
+        T = ctx.num_logits_tokens()
 
         if per_gpu:
             V //= self.tp_size
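For context, a minimal sketch (not from this repo) of how much the old and new token counts can diverge for the unembedding matmul. It assumes unembed FLOPs of roughly 2 * T * D * V for a [T, D] x [D, V] projection, and a hypothetical prefill_num_tokens field; the other names mirror the diff above.

    # Sketch only: illustrates total_num_tokens() vs. num_logits_tokens().
    from dataclasses import dataclass

    @dataclass
    class Ctx:
        num_prefill_requests: int
        prefill_num_tokens: int   # assumed: total prompt tokens across requests
        decode_num_tokens: int

        def total_num_tokens(self) -> int:
            return self.prefill_num_tokens + self.decode_num_tokens

        def num_logits_tokens(self) -> int:
            # Only the last prefill token per request, plus every decode token,
            # actually goes through the LM head.
            return self.num_prefill_requests + self.decode_num_tokens

    D, V = 4096, 128_256   # example hidden size / vocab size
    ctx = Ctx(num_prefill_requests=4, prefill_num_tokens=4 * 2048, decode_num_tokens=4)

    old_flops = 2 * ctx.total_num_tokens() * D * V    # counts all 8196 tokens
    new_flops = 2 * ctx.num_logits_tokens() * D * V   # counts only 8 tokens
    print(old_flops / new_flops)                      # ~1024x overcount for this batch

With long prompts the overcount scales with the prompt length, which is why the fix matters most for prefill-heavy workloads.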