[Attention] Allow V1 flash_attn to support cross-attention (#23297)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
@@ -405,13 +405,6 @@ class FlashAttentionImpl(AttentionImpl):
 
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type not in [
-                AttentionType.DECODER, AttentionType.ENCODER_ONLY
-        ]:
-            raise NotImplementedError("Encoder/decoder cross-attention "
-                                      "is not implemented for "
-                                      "FlashAttentionImpl")
-
         self.attn_type = attn_type
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
@@ -477,7 +470,7 @@ class FlashAttentionImpl(AttentionImpl):
         num_actual_tokens = attn_metadata.num_actual_tokens
 
         # Handle encoder attention differently - no KV cache needed
-        if attn_type in (AttentionType.ENCODER_ONLY, ):
+        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
             # For encoder attention,
             # we use direct Q, K, V tensors without caching
             return self._forward_encoder_attention(query[:num_actual_tokens],
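
For orientation, here is a minimal standalone sketch (plain Python, not vLLM's actual classes) of the routing this hunk establishes: encoder self-attention now covers both ENCODER_ONLY and ENCODER and bypasses the KV cache, while decoder self-attention and cross-attention continue through it. The ENCODER_DECODER name for cross-attention is an assumed detail; the other names follow the diff.

    from enum import Enum


    class AttentionType(Enum):
        DECODER = "decoder"                  # decoder self-attention
        ENCODER = "encoder"                  # encoder self-attention in enc/dec models
        ENCODER_ONLY = "encoder_only"        # encoder-only models
        ENCODER_DECODER = "encoder_decoder"  # cross-attention (assumed name)


    def uses_kv_cache(attn_type: AttentionType) -> bool:
        # Encoder self-attention works directly on the current batch's
        # Q/K/V and never touches the paged KV cache; decoder
        # self-attention and cross-attention still go through it.
        return attn_type not in (AttentionType.ENCODER_ONLY,
                                 AttentionType.ENCODER)


    assert not uses_kv_cache(AttentionType.ENCODER)
    assert uses_kv_cache(AttentionType.ENCODER_DECODER)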
@@ -489,7 +482,11 @@ class FlashAttentionImpl(AttentionImpl):
         # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)
 
-        if self.kv_sharing_target_layer_name is None:
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (self.kv_sharing_target_layer_name is None and key is not None
+                and value is not None):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping is
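
The intent of the new guard can be shown with a small self-contained sketch; `should_write_kv` is a hypothetical helper that only mirrors the condition added above.

    from typing import Optional

    import torch


    def should_write_kv(key: Optional[torch.Tensor],
                        value: Optional[torch.Tensor],
                        kv_sharing_target_layer_name: Optional[str]) -> bool:
        # Write new K/V into the paged cache only when this layer owns its
        # own cache (no sharing target) and fresh key/value tensors were
        # provided. For cross-attention, key/value are None on later steps
        # because the encoder's K/V were cached once and are only read back.
        return (kv_sharing_target_layer_name is None
                and key is not None and value is not None)


    # Cross-attention step after the encoder output has been cached: read only.
    assert not should_write_kv(None, None, None)
    # Ordinary decoder self-attention step: cache the new tokens' K/V.
    assert should_write_kv(torch.zeros(2, 4, 8), torch.zeros(2, 4, 8), None)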
@@ -528,7 +525,7 @@ class FlashAttentionImpl(AttentionImpl):
             block_table = attn_metadata.block_table
             scheduler_metadata = attn_metadata.scheduler_metadata
 
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
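
A tiny illustration, with made-up sizes, of why the last hunk derives the KV-head count from the layer itself rather than from `key`, which may be None for cross-attention:

    import torch

    num_kv_heads = 8                         # known from the layer configuration
    cu_seqlens_q = torch.tensor([0, 3, 7])   # example cumulative query lengths
    key = None                               # cross-attention: no fresh K/V this step

    # The old expression needed key.shape[1] and would fail when key is None:
    #   descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
    # Taking the head count from the layer works in both cases:
    descale_shape = (cu_seqlens_q.shape[0] - 1, num_kv_heads)
    assert descale_shape == (2, 8)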