[torch.compile] add a flag to track batchsize statistics (#11059)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-12-10 12:40:52 -08:00
committed by GitHub
parent e739194926
commit 75f89dc44c
4 changed files with 37 additions and 1 deletions

View File

@@ -56,6 +56,7 @@ class FlashAttentionMetadata:
seq_start_loc: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
num_input_tokens: int = 0 # Number of tokens including padding.
class FlashAttentionImpl(AttentionImpl):