Add attention sink in attention backends (#22320)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Woosuk Kwon
2025-08-05 22:37:21 -07:00
committed by GitHub
parent dd16bdc798
commit 6e20924350
7 changed files with 176 additions and 45 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
from functools import cache
from typing import ClassVar, Optional
import torch
@@ -13,7 +14,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -193,6 +193,15 @@ class TritonAttentionBackend(AttentionBackend):
return TritonAttentionMetadataBuilder
@cache
def use_aiter_unified_attention() -> bool:
"""Check if aiter unified attention should be used."""
# VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set
# to 1 as default
return envs.VLLM_ROCM_USE_AITER \
and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
class TritonAttentionImpl(AttentionImpl):
def __init__(
@@ -207,6 +216,7 @@ class TritonAttentionImpl(AttentionImpl):
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -240,6 +250,29 @@ class TritonAttentionImpl(AttentionImpl):
self.force_prefill_decode_attn = \
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
if not self.force_prefill_decode_attn:
# If not using prefill decode attention, we use the Triton
# unified attention implementation.
if use_aiter_unified_attention():
logger.info_once(
"Using aiter unified attention for TritonAttentionImpl")
from aiter.ops.triton.unified_attention import (
unified_attention)
self.unified_attention = unified_attention
else:
logger.info_once(
"Using vllm unified attention for TritonAttentionImpl")
from vllm.attention.ops.triton_unified_attention import (
unified_attention)
self.unified_attention = unified_attention
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")
def forward(
self,
layer: torch.nn.Module,
@@ -342,28 +375,31 @@ class TritonAttentionImpl(AttentionImpl):
if use_prefill_decode_attn:
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale)
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
sinks=self.sinks,
)
else:
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
unified_attention(
self.unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
@@ -381,6 +417,7 @@ class TritonAttentionImpl(AttentionImpl):
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
)
return output