[V1] Further reduce CPU overheads in flash-attn (#10989)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2024-12-09 12:38:46 -08:00
committed by GitHub
parent edc4fa3188
commit 3b61cb450d
2 changed files with 28 additions and 7 deletions

View File

@@ -138,14 +138,25 @@ class FlashAttentionImpl(AttentionImpl):
# Profiling run.
return output
num_actual_tokens = attn_metadata.num_actual_tokens
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
# Reshape the input keys and values and store them in the cache.
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens] and
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses
# the slot_mapping's shape to determine the number of actual tokens.
key_cache, value_cache = kv_cache.unbind(0)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key[:num_actual_tokens],
value[:num_actual_tokens],
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,