Support encoder-only models without KV-Cache (#21270)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
committed by GitHub
parent f27fdfc3ed
commit 1cd6eaba54
@@ -130,6 +130,8 @@ class FlashAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
     max_num_splits: int = 0
 
+    causal: bool = True
+
 
 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
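Aside (not part of the commit): the new `causal` field is added with a default of `True`, so every existing call site that constructs the metadata keeps its current causal behavior unchanged, and only encoder-only code paths need to pass `causal=False`. A minimal standalone sketch of that pattern, using a hypothetical `ToyMetadata` dataclass rather than the real vLLM classes:

from dataclasses import dataclass

import torch


@dataclass
class ToyMetadata:
    # Stand-in for the fields the real metadata classes require.
    slot_mapping: torch.Tensor
    # Defaulted field: call sites that never mention `causal` still
    # construct valid objects and keep the old causal behavior.
    causal: bool = True


decoder_meta = ToyMetadata(slot_mapping=torch.zeros(4, dtype=torch.int64))
encoder_meta = ToyMetadata(slot_mapping=torch.zeros(4, dtype=torch.int64),
                           causal=False)
assert decoder_meta.causal and not encoder_meta.causal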
@@ -213,6 +215,7 @@ class FlashAttentionMetadataBuilder(
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
+        causal = common_attn_metadata.causal
 
         # the overhead of the aot schedule is not worth it for spec-decode
         aot_schedule = self.aot_schedule and not fast_build
@@ -288,7 +291,7 @@ class FlashAttentionMetadataBuilder(
                 max_query_len=max_query_len,
                 seqlens=seq_lens,
                 max_seq_len=max_seq_len,
-                causal=True)
+                causal=causal)
 
         if self.use_full_cuda_graph:
             assert scheduler_metadata is not None
@@ -326,7 +329,7 @@ class FlashAttentionMetadataBuilder(
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
             max_num_splits=max_num_splits,
-        )
+            causal=causal)
         return attn_metadata
 
     def can_run_in_cudagraph(
@@ -375,11 +378,14 @@ class FlashAttentionImpl(AttentionImpl):
 
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashAttentionImpl")
+        if attn_type not in [
+                AttentionType.DECODER, AttentionType.ENCODER_ONLY
+        ]:
+            raise NotImplementedError("Encoder/decoder cross-attention "
+                                      "is not implemented for "
+                                      "FlashAttentionImpl")
+
+        self.attn_type = attn_type
 
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
@@ -422,6 +428,8 @@ class FlashAttentionImpl(AttentionImpl):
             # Profiling run.
             return output
 
+        attn_type = self.attn_type
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -432,6 +440,18 @@ class FlashAttentionImpl(AttentionImpl):
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type in (AttentionType.ENCODER_ONLY, ):
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(query[:num_actual_tokens],
+                                                    key[:num_actual_tokens],
+                                                    value[:num_actual_tokens],
+                                                    output[:num_actual_tokens],
+                                                    attn_metadata, layer)
+
+        # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)
 
         if self.kv_sharing_target_layer_name is None:
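Aside (illustrative sketch, not from the commit): the `[:num_actual_tokens]` slices above matter because the buffers handed to `forward` can be larger than the number of real tokens, for example when they are padded for CUDA-graph capture, while `attn_metadata.num_actual_tokens` counts only the tokens that carry data. A tiny demonstration with made-up sizes:

import torch

num_padded_tokens, num_actual_tokens = 16, 10  # hypothetical sizes
num_heads, head_size = 8, 64

# Padded buffer: only the first `num_actual_tokens` rows hold real tokens.
query = torch.randn(num_padded_tokens, num_heads, head_size)
real_query = query[:num_actual_tokens]
assert real_query.shape == (num_actual_tokens, num_heads, head_size)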
@@ -483,7 +503,7 @@ class FlashAttentionImpl(AttentionImpl):
                 seqused_k=seqused_k,
                 max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
-                causal=True,
+                causal=attn_metadata.causal,
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
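Aside (a standalone reference sketch, not part of the commit): the `causal` flag that now flows in from the metadata toggles between a lower-triangular mask, where query position i only attends to key positions j <= i, and the unmasked bidirectional pattern that encoder self-attention needs. In plain PyTorch the two cases look like this:

import torch

seq_len, num_heads, head_size = 4, 2, 8
q = torch.randn(seq_len, num_heads, head_size)
k = torch.randn(seq_len, num_heads, head_size)
v = torch.randn(seq_len, num_heads, head_size)
scale = head_size ** -0.5

# Attention scores per head: [num_heads, seq_len, seq_len].
scores = torch.einsum("qhd,khd->hqk", q, k) * scale

# causal=True: mask out keys that come after the query position.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                         diagonal=1)
causal_out = torch.softmax(
    scores.masked_fill(causal_mask, float("-inf")), dim=-1) @ v.permute(1, 0, 2)

# causal=False: no mask, every token attends to every token (encoder case).
bidirectional_out = torch.softmax(scores, dim=-1) @ v.permute(1, 0, 2)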
@@ -524,6 +544,63 @@ class FlashAttentionImpl(AttentionImpl):
         )
         return output
 
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # For encoder attention, process FP8 quantization if needed
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError(
+                "quantization is not supported for encoder attention")
+
+        # Use encoder-specific metadata for sequence information
+        cu_seqlens_q = attn_metadata.query_start_loc
+        cu_seqlens_k = attn_metadata.query_start_loc
+        max_seqlen_q = attn_metadata.max_query_len
+        max_seqlen_k = attn_metadata.max_query_len
+
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
+            self.num_kv_heads)
+
+        # Call flash attention directly on Q, K, V tensors
+        flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=output,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=self.scale,
+            causal=False,  # Encoder attention is bidirectional
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
+        )
+
+        return output
+
 
 def use_cascade_attention(
         common_prefix_len: int,
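Aside (illustrative sketch, not part of the commit): in encoder self-attention the queries and the keys/values are the same packed stream of tokens, which is why `query_start_loc` can serve as both `cu_seqlens_q` and `cu_seqlens_k` above. The cumulative offsets simply mark where each sequence starts inside the packed token dimension, roughly like this:

import torch

# Hypothetical batch of three encoder sequences of different lengths.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Cumulative start offsets into the packed token dimension: [0, 5, 8, 15].
cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)

num_heads, head_size = 4, 32
packed_tokens = torch.randn(int(seq_lens.sum()), num_heads, head_size)

# Tokens of sequence i live at packed_tokens[cu_seqlens[i]:cu_seqlens[i + 1]].
for i in range(seq_lens.numel()):
    seq = packed_tokens[cu_seqlens[i]:cu_seqlens[i + 1]]
    assert seq.shape[0] == int(seq_lens[i])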
@@ -59,6 +59,8 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
 
+    causal: bool = True
+
 
 
 M = TypeVar("M")
@@ -395,6 +397,7 @@ def make_local_attention_virtual_batches(
         max_query_len=seqlens_q_local.max(),
         block_table_tensor=block_table_local,
         slot_mapping=common_attn_metadata.slot_mapping,
+        causal=True,
     )
 