[torch.compile] add a flag to track batchsize statistics (#11059)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -56,6 +56,7 @@ class FlashAttentionMetadata:
|
||||
seq_start_loc: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
|
||||
class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
Reference in New Issue
Block a user