[v1] Add encoder-only/cross attention support to Triton Attention backend (#31406)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py
2026-01-06 00:00:23 +08:00
committed by GitHub
parent 911d38ed99
commit 6aa5b18e1d
6 changed files with 627 additions and 14 deletions


@@ -13,6 +13,7 @@ from vllm.attention.backends.abstract import (
     AttentionType,
     MultipleOf,
 )
+from vllm.attention.ops.triton_prefill_attention import context_attention_fwd
 from vllm.attention.ops.triton_reshape_and_cache_flash import (
     triton_reshape_and_cache_flash,
 )
@@ -309,6 +310,16 @@ class TritonAttentionBackend(AttentionBackend):
     def supports_sink(cls) -> bool:
         return True
 
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        """TritonAttention supports all attention types."""
+        return attn_type in (
+            AttentionType.DECODER,
+            AttentionType.ENCODER,
+            AttentionType.ENCODER_ONLY,
+            AttentionType.ENCODER_DECODER,
+        )
+
     @classmethod
     def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
         return True
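
The new supports_attn_type classmethod lets backend selection check attention-type support up front instead of failing later in the impl's constructor. A minimal sketch of such a capability check, assuming only the classmethod added above (the pick_backend helper below is hypothetical, not part of this diff):

    from vllm.attention.backends.abstract import AttentionType

    def pick_backend(candidates, attn_type: str):
        # Return the first backend class that reports support for attn_type.
        for backend_cls in candidates:
            if backend_cls.supports_attn_type(attn_type):
                return backend_cls
        raise ValueError(f"No backend supports attention type {attn_type!r}")

    # e.g. pick_backend([TritonAttentionBackend], AttentionType.ENCODER_ONLY)
    # now succeeds, where the Triton backend previously handled only
    # DECODER and ENCODER_DECODER.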
@@ -341,6 +352,8 @@ class TritonAttentionImpl(AttentionImpl):
         self.alibi_slopes = alibi_slopes
         if sliding_window is None:
             self.sliding_window = (-1, -1)
+        elif attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY):
+            self.sliding_window = (sliding_window - 1, sliding_window - 1)
+        else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
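
The window tuple follows the (left, right) convention used by flash-attention-style kernels: each query token may attend that many positions backward and forward, with -1 meaning unbounded. A worked example of the three branches above, assuming that convention (the variable names are illustrative):

    sliding_window = 4

    full_attention = (-1, -1)                                  # sliding_window is None
    encoder_window = (sliding_window - 1, sliding_window - 1)  # (3, 3): bidirectional
    decoder_window = (sliding_window - 1, 0)                   # (3, 0): causal

    # A query at position i sees key positions j with i - left <= j <= i + right,
    # so the encoder branch yields a symmetric 7-token window, while the
    # decoder branch keeps the usual causal 4-token window (3 back plus self).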
@@ -352,10 +365,6 @@ class TritonAttentionImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
-            raise NotImplementedError(
-                "Encoder self-attention is not implemented for TritonAttentionImpl"
-            )
-
         self.attn_type = attn_type
         self.fp8_dtype = current_platform.fp8_dtype()
@@ -417,6 +426,21 @@ class TritonAttentionImpl(AttentionImpl):
         # performance to make sure it does not introduce any overhead.
         num_actual_tokens = attn_metadata.num_actual_tokens
 
+        # Handle encoder attention differently - no KV cache needed
+        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(
+                query[:num_actual_tokens],
+                key[:num_actual_tokens],
+                value[:num_actual_tokens],
+                output[:num_actual_tokens],
+                attn_metadata,
+                layer,
+            )
+
+        # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(1)
         if (
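
On the cached paths, kv_cache.unbind(1) splits the combined cache tensor into key and value views along dim 1. A small illustration, assuming a flash-style layout of (num_blocks, 2, block_size, num_kv_heads, head_size); the concrete layout is defined elsewhere in the backend and this sketch is not taken from the diff:

    import torch

    kv_cache = torch.zeros(8, 2, 16, 4, 64)      # assumed layout; dim 1 holds K and V
    key_cache, value_cache = kv_cache.unbind(1)  # two views, no copy
    assert key_cache.shape == (8, 16, 4, 64)
    assert value_cache.shape == (8, 16, 4, 64)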
@@ -495,3 +519,48 @@ class TritonAttentionImpl(AttentionImpl):
             )
 
         return output
+
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: TritonAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # FP8 KV-cache quantization is not supported for encoder attention
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError(
+                "quantization is not supported for encoder attention"
+            )
+
+        # Use encoder-specific metadata for sequence information
+        query_start_loc = attn_metadata.query_start_loc
+        seq_lens = attn_metadata.seq_lens
+        max_query_len = attn_metadata.max_query_len
+
+        # Call the Triton prefill attention kernel directly on Q, K, V
+        context_attention_fwd(
+            q=query,
+            k=key,
+            v=value,
+            o=output,
+            b_start_loc=query_start_loc,
+            b_seq_len=seq_lens,
+            max_input_len=max_query_len,
+            is_causal=False,  # Encoder attention is bidirectional
+            sliding_window_q=self.sliding_window[0],
+            sliding_window_k=self.sliding_window[1],
+        )
+
+        return output
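
For reference, a sketch of the varlen metadata this encoder path consumes: tokens from all sequences are packed along dim 0, query_start_loc holds cumulative start offsets, and seq_lens the per-sequence lengths. The values below are illustrative only:

    import torch

    # Two packed encoder sequences of lengths 5 and 3 -> 8 tokens total.
    seq_lens = torch.tensor([5, 3], dtype=torch.int32)
    query_start_loc = torch.tensor([0, 5, 8], dtype=torch.int32)  # cumsum, leading 0
    max_query_len = int(seq_lens.max())  # 5

    num_tokens, num_heads, head_size = 8, 4, 64
    q = torch.randn(num_tokens, num_heads, head_size, device="cuda", dtype=torch.float16)
    k, v, out = q.clone(), q.clone(), torch.empty_like(q)

    # Mirrors the call above: bidirectional (non-causal) attention within each
    # sequence, no sliding window. Commented out since it needs a CUDA device
    # and the vLLM Triton kernel at runtime.
    # context_attention_fwd(q=q, k=k, v=v, o=out,
    #                       b_start_loc=query_start_loc, b_seq_len=seq_lens,
    #                       max_input_len=max_query_len, is_causal=False,
    #                       sliding_window_q=-1, sliding_window_k=-1)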