[V1] Further reduce CPU overheads in flash-attn (#10989)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -138,14 +138,25 @@ class FlashAttentionImpl(AttentionImpl):
             # Profiling run.
             return output
 
-        num_actual_tokens = attn_metadata.num_actual_tokens
-
+        # IMPORTANT!
+        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
+        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
+        # in this method. For example, `view` and `slice` (or `[:n]`) operations
+        # are surprisingly slow even in the case they do not invoke any GPU ops.
+        # Minimize the PyTorch ops in this method as much as possible.
+        # Whenever making a change in this method, please benchmark the
+        # performance to make sure it does not introduce any overhead.
+
+        num_actual_tokens = attn_metadata.num_actual_tokens
         # Reshape the input keys and values and store them in the cache.
-        key_cache = kv_cache[0]
-        value_cache = kv_cache[1]
+        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+        # not padded. However, we don't need to do key[:num_actual_tokens] and
+        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
+        # the slot_mapping's shape to determine the number of actual tokens.
+        key_cache, value_cache = kv_cache.unbind(0)
         torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key[:num_actual_tokens],
-            value[:num_actual_tokens],
+            key,
+            value,
             key_cache,
             value_cache,
             attn_metadata.slot_mapping,
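
The IMPORTANT note is straightforward to check empirically. Below is a minimal standalone timing sketch comparing the per-call CPU cost of the old two-index pattern against the single unbind, plus the cost of the `[:n]` slice the patch removes; the tensor shapes and iteration count are invented for illustration and are not part of the commit:

import time

import torch

# Hypothetical shapes, chosen only for illustration; the overhead being
# measured is per-op CPU dispatch cost, which does not depend on tensor size.
kv_cache = torch.zeros(2, 128, 16, 8, 128)  # (K/V, num_blocks, block_size, heads, head_dim)
key = torch.zeros(4096, 8, 128)             # padded to 4096 tokens
num_actual_tokens = 1000
NUM_ITERS = 10_000

def bench(fn, iters=NUM_ITERS):
    # Wall-clock Python-side time; none of these ops launch any GPU work.
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e6  # microseconds per call

t_index = bench(lambda: (kv_cache[0], kv_cache[1]))   # old: two indexing ops
t_unbind = bench(lambda: kv_cache.unbind(0))          # new: one unbind op
t_slice = bench(lambda: key[:num_actual_tokens])      # slice the patch removes

print(f"kv_cache[0], kv_cache[1]: {t_index:.2f} us")
print(f"kv_cache.unbind(0):       {t_unbind:.2f} us")
print(f"key[:n]:                  {t_slice:.2f} us")

Each eager op pays a fixed Python/dispatcher cost on the order of a microsecond even when, as here, it only creates views. With piece-wise CUDA graphs this method runs in eager mode on every step, so every op removed from it is a direct per-step saving.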
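Why it is safe to pass the padded key and value follows from the NOTE: the op sizes its work by slot_mapping, not by its inputs. Below is a pure-PyTorch sketch of that contract; the real reshape_and_cache_flash is a custom CUDA kernel, and the paged-cache layout and the _ref helper name here are assumptions made for illustration:

import torch

def reshape_and_cache_flash_ref(key, value, key_cache, value_cache, slot_mapping):
    # Mock of the real CUDA op's contract (cache layout assumed for illustration).
    # Only the first slot_mapping.numel() rows of key/value are ever read,
    # so callers may pass padded tensors without slicing them first.
    n = slot_mapping.numel()
    block_size = key_cache.shape[1]
    block_idx = slot_mapping // block_size   # which cache block each token goes to
    block_off = slot_mapping % block_size    # position within that block
    key_cache[block_idx, block_off] = key[:n]
    value_cache[block_idx, block_off] = value[:n]

# Padded inputs (10 rows) but only 6 actual tokens, per slot_mapping's length.
num_blocks, block_size, num_heads, head_size = 8, 16, 4, 32
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
value_cache = torch.zeros_like(key_cache)
key = torch.randn(10, num_heads, head_size)
value = torch.randn(10, num_heads, head_size)
slot_mapping = torch.arange(6)
reshape_and_cache_flash_ref(key, value, key_cache, value_cache, slot_mapping)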