Add attention sink in attention backends (#22320)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
@@ -373,6 +373,7 @@ class FlashAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -410,6 +411,14 @@ class FlashAttentionImpl(AttentionImpl):
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device.")
 
+        self.sinks = sinks
+        if self.sinks is not None:
+            assert self.vllm_flash_attn_version == 3, (
+                "Sinks are only supported in FlashAttention 3")
+            assert self.sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                "heads in the layer")
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -534,6 +543,7 @@ class FlashAttentionImpl(AttentionImpl):
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
                 num_splits=attn_metadata.max_num_splits,
+                s_aux=self.sinks,
            )
            return output
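For context on what the new `s_aux` argument carries: an attention sink adds one learned per-head logit to every query's softmax, so it absorbs probability mass in the denominator without contributing a value vector. A minimal reference sketch of that computation (illustrative only; the standalone function, names, and shapes are assumptions, not this commit's code):

import torch

def attention_with_sinks(q, k, v, sinks, scale):
    # q: [heads, q_len, dim], k/v: [heads, kv_len, dim], sinks: [heads]
    logits = torch.einsum("hqd,hkd->hqk", q, k) * scale
    # One extra "sink" logit per head, broadcast over query positions.
    sink = sinks.view(-1, 1, 1).expand(-1, q.shape[1], 1)
    # The sink column joins the softmax but is dropped before mixing values,
    # so it only scales the remaining attention weights down.
    probs = torch.softmax(torch.cat([logits, sink], dim=-1), dim=-1)
    return torch.einsum("hqk,hkd->hqd", probs[..., :-1], v)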
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with PagedAttention and Triton prefix prefill."""
 from dataclasses import dataclass
+from functools import cache
 from typing import ClassVar, Optional
 
 import torch
@@ -13,7 +14,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.ops.chunked_prefill_paged_decode import (
     chunked_prefill_paged_decode)
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -193,6 +193,15 @@ class TritonAttentionBackend(AttentionBackend):
         return TritonAttentionMetadataBuilder
 
 
+@cache
+def use_aiter_unified_attention() -> bool:
+    """Check if aiter unified attention should be used."""
+    # VLLM_ROCM_USE_AITER_MHA needs to be set to 0 as well, since it
+    # defaults to 1.
+    return envs.VLLM_ROCM_USE_AITER \
+        and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
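Because the helper above is wrapped in functools.cache, the environment flags are read once per process and the result is reused; flipping the variables after the first call has no effect. A small illustration of that caching behavior (the EXAMPLE_FLAG variable is made up for the sketch, not a vLLM setting):

import os
from functools import cache

@cache
def flag_enabled() -> bool:
    # Evaluated only on the first call; the result is memoized afterwards.
    return os.environ.get("EXAMPLE_FLAG", "0") == "1"

os.environ["EXAMPLE_FLAG"] = "1"
assert flag_enabled() is True   # computed and cached
os.environ["EXAMPLE_FLAG"] = "0"
assert flag_enabled() is True   # cached result; the later change is ignored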
+
+
 class TritonAttentionImpl(AttentionImpl):
 
     def __init__(
@@ -207,6 +216,7 @@ class TritonAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -240,6 +250,29 @@ class TritonAttentionImpl(AttentionImpl):
         self.force_prefill_decode_attn = \
             envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
 
+        if not self.force_prefill_decode_attn:
+            # If not using prefill decode attention, we use the Triton
+            # unified attention implementation.
+            if use_aiter_unified_attention():
+                logger.info_once(
+                    "Using aiter unified attention for TritonAttentionImpl")
+                from aiter.ops.triton.unified_attention import (
+                    unified_attention)
+                self.unified_attention = unified_attention
+            else:
+                logger.info_once(
+                    "Using vllm unified attention for TritonAttentionImpl")
+                from vllm.attention.ops.triton_unified_attention import (
+                    unified_attention)
+                self.unified_attention = unified_attention
+
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                f"heads in the layer. Sinks shape: {sinks.shape}, "
+                f"num_heads: {num_heads}.")
+
     def forward(
         self,
         layer: torch.nn.Module,
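As with the FlashAttention backend, `sinks` is expected to hold one value per attention head. A sketch of how a caller might supply it (assumed usage; the names and initialization are illustrative, not part of this commit):

import torch

num_heads = 8
# Typically a learned per-head parameter owned by the model.
sinks = torch.nn.Parameter(torch.zeros(num_heads))

# Mirrors the shape check in __init__ above.
assert sinks.shape[0] == num_heads, (
    "Sinks must have the same number of heads as the number of "
    f"heads in the layer. Sinks shape: {tuple(sinks.shape)}, "
    f"num_heads: {num_heads}.")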
@@ -342,28 +375,31 @@ class TritonAttentionImpl(AttentionImpl):
 
         if use_prefill_decode_attn:
             # Compute attention and update output up to `num_actual_tokens`.
-            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
-                                         key=key[:num_actual_tokens],
-                                         value=value[:num_actual_tokens],
-                                         output=output[:num_actual_tokens],
-                                         kv_cache_dtype=self.kv_cache_dtype,
-                                         key_cache=key_cache,
-                                         value_cache=value_cache,
-                                         block_table=block_table,
-                                         query_start_loc=cu_seqlens_q,
-                                         seq_lens=seqused_k,
-                                         max_seq_len=max_seqlen_k,
-                                         max_query_len=max_seqlen_q,
-                                         k_scale=layer._k_scale,
-                                         v_scale=layer._v_scale,
-                                         alibi_slopes=self.alibi_slopes,
-                                         sliding_window=self.sliding_window[0],
-                                         sm_scale=self.scale)
+            chunked_prefill_paged_decode(
+                query=query[:num_actual_tokens],
+                key=key[:num_actual_tokens],
+                value=value[:num_actual_tokens],
+                output=output[:num_actual_tokens],
+                kv_cache_dtype=self.kv_cache_dtype,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_table=block_table,
+                query_start_loc=cu_seqlens_q,
+                seq_lens=seqused_k,
+                max_seq_len=max_seqlen_k,
+                max_query_len=max_seqlen_q,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
+                alibi_slopes=self.alibi_slopes,
+                sliding_window=self.sliding_window[0],
+                sm_scale=self.scale,
+                sinks=self.sinks,
+            )
 
         else:
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
 
-            unified_attention(
+            self.unified_attention(
                 q=query[:num_actual_tokens],
                 k=key_cache,
                 v=value_cache,
@@ -381,6 +417,7 @@ class TritonAttentionImpl(AttentionImpl):
                 q_descale=None,  # Not supported
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
+                sinks=self.sinks,
             )
 
         return output
@@ -254,7 +254,11 @@ def get_kv_cache_layout():
     # Override with format specified by the user.
     cache_layout = envs.VLLM_KV_CACHE_LAYOUT
     if cache_layout is None:
-        cache_layout = get_kv_connector_cache_layout()
+        if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
+                or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+            cache_layout = "HND"
+        else:
+            cache_layout = get_kv_connector_cache_layout()
     else:
         logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
             "detected. Setting KV cache layout to %s.", cache_layout)
@@ -272,7 +276,9 @@ def set_kv_cache_layout(cache_layout: str):
 class PerLayerParameters:
     """
     Currently, FlashInfer backend only support models in which all layers share
-    the same values for the following hyperparameters.
+    the same values for the following hyperparameters. Should not be used for
+    the trtllm-gen backend, since it supports different values for these
+    hyperparameters.
     """
 
     window_left: int
@@ -310,7 +316,8 @@ def get_per_layer_parameters(
 def infer_global_hyperparameters(
         per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
     """
-    Currently, FlashInfer backend only support models in which all layers share
+    Currently, FlashInfer backends other than trtllm-gen
+    only support models in which all layers share
     the same values for the following hyperparameters:
     - `window_left`
     - `logits_soft_cap`
@@ -324,15 +331,20 @@ def infer_global_hyperparameters(
 
     param_sets = list(per_layer_params.values())
     global_params = param_sets[0]
-    for params in param_sets:
-        if params.window_left != global_params.window_left:
-            raise ValueError(
-                "Window left is not the same for all layers. One potential fix "
-                "is to set disable_sliding_window=True")
-        assert params == global_params, (
-            "FlashInfer backend currently only supports models in which all "
-            "layers share the same values for the following hyperparameters: "
-            "`window_left`, `logits_soft_cap`, `sm_scale`.")
+
+    # trtllm attention doesn't need global hyper params so disable the check
+    if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
+            and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+        for params in param_sets:
+            if params.window_left != global_params.window_left:
+                raise ValueError(
+                    "Window left is not the same for all layers. " \
+                    "One potential fix is to set disable_sliding_window=True")
+            assert params == global_params, (
+                "FlashInfer backend currently only supports models in which "
+                "all layers share the same values for the following "
+                "hyperparameters: `window_left`, `logits_soft_cap`, "
+                "`sm_scale`.")
 
     return global_params
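Illustrative usage of the uniformity check (stand-in types and values, not vLLM's real `PerLayerParameters`): when the TRTLLM flags are unset, layers that disagree on `window_left` are rejected with the hint to disable sliding window.

from dataclasses import dataclass

@dataclass
class _Params:  # hypothetical stand-in for PerLayerParameters
    window_left: int
    logits_soft_cap: float
    sm_scale: float

def check_uniform(per_layer: dict) -> _Params:
    # Same shape of check as the loop above, without the env-flag gating.
    params = list(per_layer.values())
    first = params[0]
    for p in params:
        if p.window_left != first.window_left:
            raise ValueError("Window left is not the same for all layers. One "
                             "potential fix is to set disable_sliding_window=True")
        assert p == first
    return first

layers = {"layers.0.attn": _Params(-1, 0.0, 0.125),
          "layers.1.attn": _Params(128, 0.0, 0.125)}
try:
    check_uniform(layers)
except ValueError as err:
    print(err)  # mismatched window_left is reported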