Support encoder-only models without KV-Cache (#21270)

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Author: Maximilien de Bayser
Committed: 2025-07-26 10:09:52 -03:00 (via GitHub)
Commit: 1cd6eaba54 (parent: f27fdfc3ed)
17 changed files with 352 additions and 99 deletions
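
For context (an editor's note, not part of the commit): encoder-only models such as BERT attend bidirectionally and have no autoregressive decode step, so there is nothing to put in a KV cache. The diff below therefore threads a `causal` flag through the attention metadata and adds a cache-free encoder path to the FlashAttention backend. A minimal PyTorch illustration of the causal/bidirectional distinction (toy shapes, standard `torch.nn.functional.scaled_dot_product_attention`):

```python
import torch
import torch.nn.functional as F

# Toy batch: 1 sequence, 2 heads, 4 tokens, head_size 8.
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Decoder-style (causal): token i attends only to tokens <= i, so past
# K/V can be cached and reused at each generation step.
causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Encoder-style (bidirectional): every token attends to every token in
# a single pass; there is no incremental step and nothing to cache.
bidir_out = F.scaled_dot_product_attention(q, k, v, is_causal=False)

print(torch.allclose(causal_out, bidir_out))  # False in general
```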


@@ -130,6 +130,8 @@ class FlashAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
     max_num_splits: int = 0
 
+    causal: bool = True
+
 
 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
@@ -213,6 +215,7 @@ class FlashAttentionMetadataBuilder(
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
+        causal = common_attn_metadata.causal
 
         # the overhead of the aot schedule is not worth it for spec-decode
         aot_schedule = self.aot_schedule and not fast_build
@@ -288,7 +291,7 @@ class FlashAttentionMetadataBuilder(
                 max_query_len=max_query_len,
                 seqlens=seq_lens,
                 max_seq_len=max_seq_len,
-                causal=True)
+                causal=causal)
 
         if self.use_full_cuda_graph:
             assert scheduler_metadata is not None
@@ -326,7 +329,7 @@ class FlashAttentionMetadataBuilder(
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
             max_num_splits=max_num_splits,
-        )
+            causal=causal)
 
         return attn_metadata
 
     def can_run_in_cudagraph(
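
Editor's note: the four hunks above serve a single purpose: `causal` originates on `CommonAttentionMetadata` (see the second file below), is read by the builder, stored on `FlashAttentionMetadata`, and replaces the previously hard-coded `causal=True`. A stripped-down sketch of that flow, with stand-in classes rather than vLLM's real ones:

```python
from dataclasses import dataclass

# Stand-ins for the classes touched above; only the `causal` field name
# mirrors the diff, everything else is illustrative.
@dataclass
class CommonMeta:
    causal: bool = True  # default keeps decoder behavior unchanged

@dataclass
class KernelMeta:
    causal: bool = True

def build(common: CommonMeta) -> KernelMeta:
    # The builder simply forwards the flag; the attention kernel later
    # receives causal=meta.causal instead of a literal True.
    return KernelMeta(causal=common.causal)

encoder_meta = build(CommonMeta(causal=False))  # encoder-only batch
assert encoder_meta.causal is False
```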
@@ -375,11 +378,14 @@ class FlashAttentionImpl(AttentionImpl):
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
+        if attn_type not in [
+                AttentionType.DECODER, AttentionType.ENCODER_ONLY
+        ]:
+            raise NotImplementedError("Encoder/decoder cross-attention "
+                                      "is not implemented for "
                                       "FlashAttentionImpl")
+        self.attn_type = attn_type
 
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
                 and not flash_attn_supports_fp8():
@@ -422,6 +428,8 @@
             # Profiling run.
             return output
 
+        attn_type = self.attn_type
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -432,6 +440,18 @@
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type in (AttentionType.ENCODER_ONLY, ):
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(query[:num_actual_tokens],
+                                                   key[:num_actual_tokens],
+                                                   value[:num_actual_tokens],
+                                                   output[:num_actual_tokens],
+                                                   attn_metadata, layer)
+
+        # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)
         if self.kv_sharing_target_layer_name is None:
@@ -483,7 +503,7 @@
                 seqused_k=seqused_k,
                 max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
-                causal=True,
+                causal=attn_metadata.causal,
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
@@ -524,6 +544,63 @@
             )
         return output
 
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # FP8 KV-cache quantization is not supported for encoder attention
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError(
+                "quantization is not supported for encoder attention")
+
+        # Use encoder-specific metadata for sequence information
+        cu_seqlens_q = attn_metadata.query_start_loc
+        cu_seqlens_k = attn_metadata.query_start_loc
+        max_seqlen_q = attn_metadata.max_query_len
+        max_seqlen_k = attn_metadata.max_query_len
+
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
+            self.num_kv_heads)
+
+        # Call flash attention directly on Q, K, V tensors
+        flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=output,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=self.scale,
+            causal=False,  # Encoder attention is bidirectional
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
+        )
+
+        return output
+
 
 def use_cascade_attention(
     common_prefix_len: int,

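Editor's note: for intuition, here is a plain-PyTorch sketch of what `_forward_encoder_attention` computes. Q, K and V are packed variable-length sequences sharing one `cu_seqlens` tensor for both sides (self-attention), and each sequence attends bidirectionally. The sketch assumes `num_heads == num_kv_heads` and ignores ALiBi, sliding windows and FP8 descales; `encoder_attention_reference` is a hypothetical helper, not vLLM code:

```python
import torch
import torch.nn.functional as F

def encoder_attention_reference(query, key, value, cu_seqlens, scale):
    """Bidirectional varlen self-attention over packed
    [total_tokens, num_heads, head_size] tensors, mirroring the
    flash_attn_varlen_func call above with causal=False and
    cu_seqlens_q == cu_seqlens_k."""
    out = torch.empty_like(query)
    for i in range(len(cu_seqlens) - 1):
        s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        # [seq, heads, dim] -> [heads, seq, dim] for SDPA.
        q = query[s:e].transpose(0, 1)
        k = key[s:e].transpose(0, 1)
        v = value[s:e].transpose(0, 1)
        o = F.scaled_dot_product_attention(q, k, v, scale=scale,
                                           is_causal=False)
        out[s:e] = o.transpose(0, 1)
    return out

# Two packed sequences of lengths 3 and 5; 2 heads, head_size 4.
cu = torch.tensor([0, 3, 8])
q = torch.randn(8, 2, 4)
k = torch.randn(8, 2, 4)
v = torch.randn(8, 2, 4)
ref = encoder_attention_reference(q, k, v, cu, scale=4 ** -0.5)
```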

@@ -59,6 +59,8 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
 
+    causal: bool = True
+
 
 M = TypeVar("M")
@@ -395,6 +397,7 @@ def make_local_attention_virtual_batches(
         max_query_len=seqlens_q_local.max(),
         block_table_tensor=block_table_local,
         slot_mapping=common_attn_metadata.slot_mapping,
+        causal=True,
     )
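
Editor's note: local-attention virtual batches are a decode-path construct, so they pin `causal=True` explicitly. The net effect of the whole commit is that encoder-only models can be served by the FlashAttention backend with `causal=False` and no KV cache. A hypothetical usage sketch; the model choice is illustrative, and the `task="embed"` / `embed()` API follows vLLM's pooling-model docs and may differ across versions:

```python
from vllm import LLM

# Hypothetical: a BERT-style encoder-only embedding model. With this
# change its self-attention runs bidirectionally through the
# FlashAttention backend, with no KV cache allocated for decoding.
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")
outputs = llm.embed(["The quick brown fox"])
```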