Add attention sink in attention backends (#22320)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
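For context on what the kernels below implement: an attention sink here is a learnable per-head logit that takes part in the softmax normalization but contributes no value vector, so it absorbs probability mass without affecting the weighted sum. A minimal reference sketch of that math in plain PyTorch (written for this note, not taken from the diff; tensor names are illustrative):

import torch

def attention_with_sink(q, k, v, sink, scale):
    # q: [num_heads, head_dim]; k, v: [seq_len, num_heads, head_dim]
    # sink: [num_heads], one learnable logit per query head
    logits = torch.einsum("hd,shd->hs", q, k) * scale   # [num_heads, seq_len]
    logits = torch.cat([sink[:, None], logits], dim=1)  # prepend the sink column
    probs = torch.softmax(logits, dim=1)                # sink absorbs some mass
    return torch.einsum("hs,shd->hd", probs[:, 1:], v)  # sink column carries no value

The kernels in this commit get the same effect without materializing the extra column, by seeding the online-softmax running maximum with the sink logit.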
@@ -28,6 +28,7 @@ def kernel_paged_attention_2d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
         value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -95,7 +96,17 @@ def kernel_paged_attention_2d(

     block_table_offset = seq_idx * block_table_stride

-    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        M = tl.full([num_queries_per_kv_padded],
+                    float("-inf"),
+                    dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_head_idx,
+            mask=head_mask,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
     acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
                    dtype=tl.float32)
@@ -223,6 +234,8 @@ def chunked_prefill_paged_decode(
         alibi_slopes=None,
         sliding_window=None,
         sm_scale=None,
+        # Optional tensor for sinks
+        sinks=None,
 ):

     if sm_scale is None:
@@ -253,6 +266,7 @@ def chunked_prefill_paged_decode(
             sliding_window=sliding_window,
             sm_scale=sm_scale,
             skip_decode=True,
+            sinks=sinks,
         )

     block_size = value_cache.shape[3]
@@ -281,11 +295,17 @@ def chunked_prefill_paged_decode(
     num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                     16)

-    use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
-                                                 block_size,
-                                                 num_queries_per_kv,
-                                                 max_seq_len, sliding_window,
-                                                 kv_cache_dtype, alibi_slopes)
+    use_custom = use_rocm_custom_paged_attention(
+        query.dtype,
+        head_size,
+        block_size,
+        num_queries_per_kv,
+        max_seq_len,
+        sliding_window,
+        kv_cache_dtype,
+        alibi_slopes,
+        sinks,
+    )
     if use_custom:
         _PARTITION_SIZE_ROCM = 256
         max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
@@ -334,6 +354,7 @@ def chunked_prefill_paged_decode(
             query_ptr=query,
             key_cache_ptr=key_cache,
             value_cache_ptr=value_cache,
+            sink_ptr=sinks,
             block_tables_ptr=block_table,
             seq_lens_ptr=seq_lens,
             alibi_slopes_ptr=alibi_slopes,
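The kernel change above folds the sink into the streaming-softmax state rather than adding a key column: the running maximum M starts at the sink logit and the running denominator L starts at 1.0, so the sink ends up contributing exp(sink - M_final) to the normalizer while never touching the output accumulator. A small sketch of that recurrence (illustrative only; variable names follow the kernel, and a simple block loop stands in for the real tiled loop):

import torch

def streaming_softmax_with_sink(scores, values, sink):
    # scores: [seq_len] attention logits for one query/head pair
    # values: [seq_len, head_dim]; sink: scalar sink logit
    m = torch.tensor(float(sink))      # M seeded with the sink logit
    l = torch.tensor(1.0)              # L seeded with exp(sink - sink) == 1.0
    acc = torch.zeros(values.shape[1])
    for blk_scores, blk_values in zip(scores.split(16), values.split(16)):
        m_new = torch.maximum(m, blk_scores.max())
        alpha = torch.exp(m - m_new)   # rescale previously accumulated state
        p = torch.exp(blk_scores - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ blk_values
        m = m_new
    return acc / l                     # the sink inflates l, shrinking all weights

With sink_ptr left as None the seed stays at -inf, and the initial 1.0 in L is scaled away on the first block, so the existing no-sink path is effectively unchanged.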
@@ -38,6 +38,7 @@ def _fwd_kernel(Q,
                 V,
                 K_cache,
                 V_cache,
+                sink_ptr,
                 B_Loc,
                 sm_scale,
                 k_scale,
@@ -126,7 +127,15 @@ def _fwd_kernel(Q,
                     other=0.0)  # [M,D]

     # initialize pointer to m and l
-    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        m_i = tl.load(
+            sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
+            mask=(offs_m < cur_batch_query_len),
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]

@@ -732,7 +741,8 @@ def context_attention_fwd(q,
                           alibi_slopes=None,
                           sliding_window=None,
                           sm_scale=None,
-                          skip_decode=False):
+                          skip_decode=False,
+                          sinks=None):

     q_dtype_is_f32 = q.dtype is torch.float32

@@ -781,6 +791,7 @@ def context_attention_fwd(q,
         sliding_window = 0

     if alibi_slopes is not None:
+        assert sinks is None, "Sinks arg is not supported with alibi"
         # need to reduce num. blocks when using fp32
         # due to increased use of GPU shared memory
         # if q.dtype is torch.float32:
@@ -843,7 +854,7 @@ def context_attention_fwd(q,
     max_seq_len = 0 if max_seq_len is None else max_seq_len
     extra_kargs = {}
     if current_platform.is_rocm():
-        extra_kargs = {"kpack": 2, "waves_per_eu": 2}
+        extra_kargs = {"kpack": 1, "waves_per_eu": 2}

     grid = lambda META: (batch, head,
                          triton.cdiv(max_input_len, META["BLOCK_M"]))
@@ -853,6 +864,7 @@ def context_attention_fwd(q,
         v,
         k_cache,
         v_cache,
+        sinks,
         b_loc,
         sm_scale,
         k_scale,
@@ -853,6 +864,7 @@ def context_attention_fwd(q,
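As the assertion above encodes, the prefix-prefill path accepts sinks only when ALiBi is disabled, and expects one logit per query head. A hypothetical call shape under those assumptions (names and values are illustrative, not taken from the diff):

import torch

num_query_heads = 32
# One learnable sink logit per query head, typically loaded with the model weights.
sinks = torch.zeros(num_query_heads, dtype=torch.float32)

# context_attention_fwd(..., alibi_slopes=None, sinks=sinks)    # supported
# context_attention_fwd(..., alibi_slopes=slopes, sinks=sinks)  # trips the assert above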
@@ -52,6 +52,7 @@ def kernel_unified_attention_2d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
         value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -131,7 +132,15 @@ def kernel_unified_attention_2d(

     block_table_offset = seq_idx * block_table_stride

-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

@@ -292,6 +301,7 @@ def kernel_unified_attention_3d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
         value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -383,7 +393,15 @@ def kernel_unified_attention_3d(

     block_table_offset = seq_idx * block_table_stride

-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None or segm_idx != 0:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

@@ -627,6 +645,8 @@ def unified_attention(
     v_descale,
     alibi_slopes=None,
     qq_bias=None,
+    # Optional tensor for sinks
+    sinks=None,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
@@ -635,6 +655,10 @@ def unified_attention(
     assert q.element_size() >= 2 or block_size >= 32, \
         "Block size must be at least 32 for fp8"

+    if sinks is not None:
+        assert sinks.shape[0] == q.shape[1], \
+            "Sinks must be num_query_heads size"
+
     use_alibi_slopes = alibi_slopes is not None
     use_qq_bias = qq_bias is not None

@@ -669,6 +693,7 @@ def unified_attention(
            query_ptr=q,
            key_cache_ptr=k,
            value_cache_ptr=v,
+           sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seqused_k,
            alibi_slopes_ptr=alibi_slopes,
@@ -741,6 +766,7 @@ def unified_attention(
            query_ptr=q,
            key_cache_ptr=k,
            value_cache_ptr=v,
+           sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seqused_k,
            alibi_slopes_ptr=alibi_slopes,
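In the 3D (split-sequence) kernel the sink is seeded only when segm_idx == 0, because the per-segment partial results are later merged with a log-sum-exp style reduction and the sink term must be counted exactly once. A sketch of such a merge under those assumptions (the reduction kernel itself is not part of this diff; names are illustrative):

import torch

def merge_segments(m_segs, l_segs, acc_segs):
    # m_segs, l_segs: [num_segments]; acc_segs: [num_segments, head_dim]
    m = m_segs.max()
    scale = torch.exp(m_segs - m)            # rescale each segment to the global max
    l = (l_segs * scale).sum()               # segment 0 already carries the sink term
    acc = (acc_segs * scale[:, None]).sum(dim=0)
    return acc / l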
vllm/envs.py
@@ -17,6 +17,7 @@ if TYPE_CHECKING:
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
+    VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -151,6 +152,8 @@ if TYPE_CHECKING:
     VLLM_LOOPBACK_IP: str = ""
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
+    VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
+    VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False


 def get_default_cache_root():
@@ -326,6 +329,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
      ("true", "1")),

+    # Use AITER triton unified attention for V1 attention
+    "VLLM_USE_AITER_UNIFIED_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
@@ -1022,9 +1031,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),

-    # If set to 1, use the TRTLLM Attention backend in flashinfer.
-    "VLLM_USE_TRTLLM_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
+    # If set to 1, use the TRTLLM Context Attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_CONTEXT_ATTENTION":
+    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0"))),
+
+    # If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_DECODE_ATTENTION":
+    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0"))),

     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
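The new flags are parsed like the other boolean environment variables in this file. A hedged usage example (set before vLLM is imported; the values shown are illustrative):

import os

# Opt into the AITER Triton unified attention path
# (VLLM_ROCM_USE_AITER must also be enabled, per the check in the Triton backend below).
os.environ["VLLM_USE_AITER_UNIFIED_ATTENTION"] = "1"

# Route FlashInfer context/decode attention through the TRTLLM backends, which also
# selects the HND KV-cache layout (see get_kv_cache_layout later in this diff).
os.environ["VLLM_USE_TRTLLM_CONTEXT_ATTENTION"] = "1"
os.environ["VLLM_USE_TRTLLM_DECODE_ATTENTION"] = "1"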
@@ -373,6 +373,7 @@ class FlashAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -410,6 +411,14 @@ class FlashAttentionImpl(AttentionImpl):
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device.")

+        self.sinks = sinks
+        if self.sinks is not None:
+            assert self.vllm_flash_attn_version == 3, (
+                "Sinks are only supported in FlashAttention 3")
+            assert self.sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                "heads in the layer")
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -534,6 +543,7 @@ class FlashAttentionImpl(AttentionImpl):
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
             num_splits=attn_metadata.max_num_splits,
+            s_aux=self.sinks,
         )
         return output
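For the FlashAttention backend the constraints are enforced at construction time: sinks require FlashAttention 3 and one value per attention head, and are then forwarded to the kernel as s_aux. A hypothetical illustration of a sinks tensor that would satisfy the added checks (values are placeholders, not from the diff):

import torch

num_heads = 64
sinks = torch.zeros(num_heads, dtype=torch.bfloat16)  # one learnable logit per head

assert sinks.shape[0] == num_heads  # mirrors the shape check added in __init__
# FlashAttentionImpl(..., sinks=sinks) then passes the tensor through as s_aux.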
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with PagedAttention and Triton prefix prefill."""
 from dataclasses import dataclass
+from functools import cache
 from typing import ClassVar, Optional

 import torch
@@ -13,7 +14,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.ops.chunked_prefill_paged_decode import (
     chunked_prefill_paged_decode)
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -193,6 +193,15 @@ class TritonAttentionBackend(AttentionBackend):
         return TritonAttentionMetadataBuilder


+@cache
+def use_aiter_unified_attention() -> bool:
+    """Check if aiter unified attention should be used."""
+    # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set
+    # to 1 as default
+    return envs.VLLM_ROCM_USE_AITER \
+        and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
+
+
 class TritonAttentionImpl(AttentionImpl):

     def __init__(
@@ -207,6 +216,7 @@ class TritonAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -240,6 +250,29 @@
         self.force_prefill_decode_attn = \
             envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

+        if not self.force_prefill_decode_attn:
+            # If not using prefill decode attention, we use the Triton
+            # unified attention implementation.
+            if use_aiter_unified_attention():
+                logger.info_once(
+                    "Using aiter unified attention for TritonAttentionImpl")
+                from aiter.ops.triton.unified_attention import (
+                    unified_attention)
+                self.unified_attention = unified_attention
+            else:
+                logger.info_once(
+                    "Using vllm unified attention for TritonAttentionImpl")
+                from vllm.attention.ops.triton_unified_attention import (
+                    unified_attention)
+                self.unified_attention = unified_attention
+
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                f"heads in the layer. Sinks shape: {sinks.shape}, "
+                f"num_heads: {num_heads}.")
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -342,28 +375,31 @@ class TritonAttentionImpl(AttentionImpl):

         if use_prefill_decode_attn:
             # Compute attention and update output up to `num_actual_tokens`.
-            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
-                                         key=key[:num_actual_tokens],
-                                         value=value[:num_actual_tokens],
-                                         output=output[:num_actual_tokens],
-                                         kv_cache_dtype=self.kv_cache_dtype,
-                                         key_cache=key_cache,
-                                         value_cache=value_cache,
-                                         block_table=block_table,
-                                         query_start_loc=cu_seqlens_q,
-                                         seq_lens=seqused_k,
-                                         max_seq_len=max_seqlen_k,
-                                         max_query_len=max_seqlen_q,
-                                         k_scale=layer._k_scale,
-                                         v_scale=layer._v_scale,
-                                         alibi_slopes=self.alibi_slopes,
-                                         sliding_window=self.sliding_window[0],
-                                         sm_scale=self.scale)
+            chunked_prefill_paged_decode(
+                query=query[:num_actual_tokens],
+                key=key[:num_actual_tokens],
+                value=value[:num_actual_tokens],
+                output=output[:num_actual_tokens],
+                kv_cache_dtype=self.kv_cache_dtype,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_table=block_table,
+                query_start_loc=cu_seqlens_q,
+                seq_lens=seqused_k,
+                max_seq_len=max_seqlen_k,
+                max_query_len=max_seqlen_q,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
+                alibi_slopes=self.alibi_slopes,
+                sliding_window=self.sliding_window[0],
+                sm_scale=self.scale,
+                sinks=self.sinks,
+            )

         else:
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

-            unified_attention(
+            self.unified_attention(
                 q=query[:num_actual_tokens],
                 k=key_cache,
                 v=value_cache,
@@ -381,6 +417,7 @@ class TritonAttentionImpl(AttentionImpl):
                 q_descale=None,  # Not supported
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
+                sinks=self.sinks,
             )

         return output
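The Triton backend now resolves its attention implementation once at construction: the prefill/decode split path when VLLM_V1_USE_PREFILL_DECODE_ATTENTION is set, otherwise the AITER or vLLM Triton unified kernel depending on use_aiter_unified_attention(). A standalone sketch of that selection order (written for this note under the assumption that the same environment flags are used; it is not code from the diff, and the defaults shown are illustrative):

import os
from functools import cache

def _flag(name: str) -> bool:
    return os.getenv(name, "False").lower() in ("true", "1")

@cache
def pick_attention_impl() -> str:
    # Mirrors the priority order used by TritonAttentionImpl.__init__ above.
    if _flag("VLLM_V1_USE_PREFILL_DECODE_ATTENTION"):
        return "chunked_prefill_paged_decode"
    if _flag("VLLM_ROCM_USE_AITER") and _flag("VLLM_USE_AITER_UNIFIED_ATTENTION"):
        return "aiter.ops.triton.unified_attention"
    return "vllm.attention.ops.triton_unified_attention"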
@@ -254,7 +254,11 @@ def get_kv_cache_layout():
         # Override with format specified by the user.
         cache_layout = envs.VLLM_KV_CACHE_LAYOUT
         if cache_layout is None:
-            cache_layout = get_kv_connector_cache_layout()
+            if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
+                    or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+                cache_layout = "HND"
+            else:
+                cache_layout = get_kv_connector_cache_layout()
         else:
             logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
                 "detected. Setting KV cache layout to %s.", cache_layout)
@@ -272,7 +276,9 @@ def set_kv_cache_layout(cache_layout: str):
 class PerLayerParameters:
     """
     Currently, FlashInfer backend only support models in which all layers share
-    the same values for the following hyperparameters.
+    the same values for the following hyperparameters. Should not be used for
+    trtllm-gen backend since it supports different values for the following
+    hyperparameters.
     """

     window_left: int
@@ -310,7 +316,8 @@ def get_per_layer_parameters(
 def infer_global_hyperparameters(
         per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
     """
-    Currently, FlashInfer backend only support models in which all layers share
+    Currently, FlashInfer backend other than trtllm-gen
+    only support models in which all layers share
     the same values for the following hyperparameters:
     - `window_left`
    - `logits_soft_cap`
@@ -324,15 +331,20 @@ def infer_global_hyperparameters(

     param_sets = list(per_layer_params.values())
     global_params = param_sets[0]
-    for params in param_sets:
-        if params.window_left != global_params.window_left:
-            raise ValueError(
-                "Window left is not the same for all layers. One potential fix "
-                "is to set disable_sliding_window=True")
-        assert params == global_params, (
-            "FlashInfer backend currently only supports models in which all "
-            "layers share the same values for the following hyperparameters: "
-            "`window_left`, `logits_soft_cap`, `sm_scale`.")
+
+    # trtllm attention doesn't need global hyper params so disable the check
+    if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
+            and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+        for params in param_sets:
+            if params.window_left != global_params.window_left:
+                raise ValueError(
+                    "Window left is not the same for all layers. " \
+                    "One potential fix is to set disable_sliding_window=True")
+            assert params == global_params, (
+                "FlashInfer backend currently only supports models in which all"
+                "layers share the same values "
+                "for the following hyperparameters:"
+                "`window_left`, `logits_soft_cap`, `sm_scale`.")

     return global_params
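The practical effect of the last hunk: per-layer hyperparameter mismatches still raise for the regular FlashInfer path, but the check is skipped entirely when either TRTLLM flag is set, since trtllm-gen can handle per-layer values. A small illustration of the guarded check under those assumptions (simplified; LayerParams is a stand-in dataclass, not the vLLM class):

from dataclasses import dataclass

@dataclass
class LayerParams:                     # stand-in for PerLayerParameters
    window_left: int
    logits_soft_cap: float
    sm_scale: float

def check_global(params_list, use_trtllm: bool):
    global_params = params_list[0]
    if not use_trtllm:                 # the new env-flag guard
        for p in params_list:
            if p != global_params:
                raise ValueError("layers disagree on attention hyperparameters")
    return global_params

layers = [LayerParams(-1, 0.0, 0.125), LayerParams(128, 0.0, 0.125)]
check_global(layers, use_trtllm=True)          # accepted: check disabled
# check_global(layers, use_trtllm=False)       # would raise ValueError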