Add attention sink in attention backends (#22320)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Author: Woosuk Kwon
Date: 2025-08-05 22:37:21 -07:00 (committed by GitHub)
Parent: dd16bdc798
Commit: 6e20924350
7 changed files with 176 additions and 45 deletions
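Background for the diffs below: an attention "sink" is a per-head logit that joins the softmax over the key scores but contributes no value, so it can absorb probability mass. The kernels in this commit implement it by seeding the online-softmax running max `M` with the sink logit and the running denominator `L` with 1.0, instead of `-inf`/1.0. A minimal NumPy sketch (function names are illustrative, not part of this commit) showing that the two formulations agree:

```python
import numpy as np

def softmax_with_sink_reference(scores, sink):
    # Reference: treat the sink as one extra logit whose value vector is zero.
    logits = np.concatenate(([sink], scores))
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return probs[1:]  # weights over the real keys only

def softmax_with_sink_online(scores, sink):
    # Online-softmax form used by the kernels: start with M = sink, L = 1.
    m, l = sink, 1.0
    w = np.zeros_like(scores, dtype=np.float64)
    for j, s in enumerate(scores):
        m_new = max(m, s)
        alpha = np.exp(m - m_new)      # rescale previously accumulated statistics
        w *= alpha
        l = l * alpha + np.exp(s - m_new)
        w[j] = np.exp(s - m_new)
        m = m_new
    return w / l

scores = np.random.default_rng(0).standard_normal(8)
assert np.allclose(softmax_with_sink_reference(scores, 0.7),
                   softmax_with_sink_online(scores, 0.7))
```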

@@ -28,6 +28,7 @@ def kernel_paged_attention_2d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
         value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -95,7 +96,17 @@ def kernel_paged_attention_2d(
     block_table_offset = seq_idx * block_table_stride

-    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        M = tl.full([num_queries_per_kv_padded],
+                    float("-inf"),
+                    dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_head_idx,
+            mask=head_mask,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
     L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
     acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
                    dtype=tl.float32)
@@ -223,6 +234,8 @@ def chunked_prefill_paged_decode(
         alibi_slopes=None,
         sliding_window=None,
         sm_scale=None,
+        # Optional tensor for sinks
+        sinks=None,
 ):
     if sm_scale is None:
@@ -253,6 +266,7 @@ def chunked_prefill_paged_decode(
             sliding_window=sliding_window,
             sm_scale=sm_scale,
             skip_decode=True,
+            sinks=sinks,
         )

     block_size = value_cache.shape[3]
@@ -281,11 +295,17 @@ def chunked_prefill_paged_decode(
     num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                     16)

-    use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
-                                                 block_size,
-                                                 num_queries_per_kv,
-                                                 max_seq_len, sliding_window,
-                                                 kv_cache_dtype, alibi_slopes)
+    use_custom = use_rocm_custom_paged_attention(
+        query.dtype,
+        head_size,
+        block_size,
+        num_queries_per_kv,
+        max_seq_len,
+        sliding_window,
+        kv_cache_dtype,
+        alibi_slopes,
+        sinks,
+    )
     if use_custom:
         _PARTITION_SIZE_ROCM = 256
         max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
@@ -334,6 +354,7 @@ def chunked_prefill_paged_decode(
         query_ptr=query,
         key_cache_ptr=key_cache,
         value_cache_ptr=value_cache,
+        sink_ptr=sinks,
         block_tables_ptr=block_table,
         seq_lens_ptr=seq_lens,
         alibi_slopes_ptr=alibi_slopes,

@@ -38,6 +38,7 @@ def _fwd_kernel(Q,
                 V,
                 K_cache,
                 V_cache,
+                sink_ptr,
                 B_Loc,
                 sm_scale,
                 k_scale,
@@ -126,7 +127,15 @@ def _fwd_kernel(Q,
                 other=0.0)  # [M,D]

     # initialize pointer to m and l
-    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        m_i = tl.load(
+            sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
+            mask=(offs_m < cur_batch_query_len),
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
     l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]
@@ -732,7 +741,8 @@ def context_attention_fwd(q,
                           alibi_slopes=None,
                           sliding_window=None,
                           sm_scale=None,
-                          skip_decode=False):
+                          skip_decode=False,
+                          sinks=None):

     q_dtype_is_f32 = q.dtype is torch.float32
@@ -781,6 +791,7 @@ def context_attention_fwd(q,
         sliding_window = 0

     if alibi_slopes is not None:
+        assert sinks is None, "Sinks arg is not supported with alibi"
         # need to reduce num. blocks when using fp32
         # due to increased use of GPU shared memory
         # if q.dtype is torch.float32:
@@ -843,7 +854,7 @@ def context_attention_fwd(q,
     max_seq_len = 0 if max_seq_len is None else max_seq_len
     extra_kargs = {}
     if current_platform.is_rocm():
-        extra_kargs = {"kpack": 2, "waves_per_eu": 2}
+        extra_kargs = {"kpack": 1, "waves_per_eu": 2}

     grid = lambda META: (batch, head,
                          triton.cdiv(max_input_len, META["BLOCK_M"]))
@@ -853,6 +864,7 @@ def context_attention_fwd(q,
         v,
         k_cache,
         v_cache,
+        sinks,
         b_loc,
         sm_scale,
         k_scale,

@@ -52,6 +52,7 @@ def kernel_unified_attention_2d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
         value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -131,7 +132,15 @@ def kernel_unified_attention_2d(
     block_table_offset = seq_idx * block_table_stride

-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
@@ -292,6 +301,7 @@ def kernel_unified_attention_3d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
         value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -383,7 +393,15 @@ def kernel_unified_attention_3d(
     block_table_offset = seq_idx * block_table_stride

-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None or segm_idx != 0:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
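In the split (3D) kernel just above, the sink is folded in only when `segm_idx == 0`; the other segments start from `-inf`, so the later reduction across segments counts the sink exactly once. A NumPy sketch of why that works (the merge helper mirrors the usual online-softmax combination, not this file's exact reduction code):

```python
import numpy as np

def merge_partials(m1, l1, acc1, m2, l2, acc2):
    # Combine two online-softmax partial states by rescaling to a common max.
    m = max(m1, m2)
    l = l1 * np.exp(m1 - m) + l2 * np.exp(m2 - m)
    acc = acc1 * np.exp(m1 - m) + acc2 * np.exp(m2 - m)
    return m, l, acc

def partial(logits, vals):
    # Per-segment statistics: running max, denominator, unnormalized output.
    m = logits.max()
    p = np.exp(logits - m)
    return m, p.sum(), p @ vals

rng = np.random.default_rng(0)
scores, values, sink = rng.standard_normal(8), rng.standard_normal((8, 4)), 0.5

# Segment 0 includes the sink (with a zero value row); segment 1 does not.
m0, l0, acc0 = partial(np.concatenate(([sink], scores[:4])),
                       np.vstack([np.zeros((1, 4)), values[:4]]))
m1, l1, acc1 = partial(scores[4:], values[4:])
m, l, acc = merge_partials(m0, l0, acc0, m1, l1, acc1)

p = np.exp(scores - scores.max())
expected = p @ values / (p.sum() + np.exp(sink - scores.max()))
assert np.allclose(acc / l, expected)
```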
@@ -627,6 +645,8 @@ def unified_attention(
     v_descale,
     alibi_slopes=None,
     qq_bias=None,
+    # Optional tensor for sinks
+    sinks=None,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
@@ -635,6 +655,10 @@ def unified_attention(
     assert q.element_size() >= 2 or block_size >= 32, \
         "Block size must be at least 32 for fp8"

+    if sinks is not None:
+        assert sinks.shape[0] == q.shape[1], \
+            "Sinks must be num_query_heads size"
+
     use_alibi_slopes = alibi_slopes is not None
     use_qq_bias = qq_bias is not None
@@ -669,6 +693,7 @@ def unified_attention(
            query_ptr=q,
            key_cache_ptr=k,
            value_cache_ptr=v,
+           sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seqused_k,
            alibi_slopes_ptr=alibi_slopes,
@@ -741,6 +766,7 @@ def unified_attention(
            query_ptr=q,
            key_cache_ptr=k,
            value_cache_ptr=v,
+           sink_ptr=sinks,
            block_tables_ptr=block_table,
            seq_lens_ptr=seqused_k,
            alibi_slopes_ptr=alibi_slopes,

@@ -17,6 +17,7 @@ if TYPE_CHECKING:
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
+    VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -151,6 +152,8 @@ if TYPE_CHECKING:
     VLLM_LOOPBACK_IP: str = ""
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
+    VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
+    VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False


 def get_default_cache_root():
@@ -326,6 +329,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
      ("true", "1")),

+    # Use AITER triton unified attention for V1 attention
+    "VLLM_USE_AITER_UNIFIED_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
@@ -1022,9 +1031,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_CUDNN_PREFILL":
     lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),

-    # If set to 1, use the TRTLLM Attention backend in flashinfer.
-    "VLLM_USE_TRTLLM_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
+    # If set to 1, use the TRTLLM Context Attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_CONTEXT_ATTENTION":
+    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0"))),
+
+    # If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_DECODE_ATTENTION":
+    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0"))),

     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
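A hedged usage sketch for the new flags (the model name and launch style are illustrative only; since the values are read from the process environment, exporting them in the shell before starting vLLM works just as well):

```python
import os

# Opt in to the flashinfer TRTLLM attention paths added in this commit.
os.environ["VLLM_USE_TRTLLM_CONTEXT_ATTENTION"] = "1"
os.environ["VLLM_USE_TRTLLM_DECODE_ATTENTION"] = "1"
# Or, on ROCm, try the AITER Triton unified attention kernel
# (VLLM_ROCM_USE_AITER must also be enabled; see the backend change below).
# os.environ["VLLM_ROCM_USE_AITER"] = "1"
# os.environ["VLLM_USE_AITER_UNIFIED_ATTENTION"] = "1"

from vllm import LLM  # noqa: E402

llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration
```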

@@ -373,6 +373,7 @@ class FlashAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -410,6 +411,14 @@ class FlashAttentionImpl(AttentionImpl):
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device.")

+        self.sinks = sinks
+        if self.sinks is not None:
+            assert self.vllm_flash_attn_version == 3, (
+                "Sinks are only supported in FlashAttention 3")
+            assert self.sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                "heads in the layer")
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -534,6 +543,7 @@ class FlashAttentionImpl(AttentionImpl):
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
             num_splits=attn_metadata.max_num_splits,
+            s_aux=self.sinks,
         )

         return output
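Here the sinks reach FlashAttention 3 through its `s_aux` argument. Assuming `s_aux` carries the same per-head sink semantics as the Triton kernels above (one extra logit in the softmax denominator, no value contribution), its observable effect is that the attention weights over the real keys no longer sum to 1. A small PyTorch sketch of that effect for a single query token (shapes and values are illustrative):

```python
import torch

torch.manual_seed(0)
q = torch.randn(1, 4, 64)            # [tokens, heads, head_size]
k = torch.randn(16, 4, 64)           # [keys, heads, head_size]
v = torch.randn(16, 4, 64)
sinks = torch.full((4,), 4.0)        # one (large) sink logit per head

scores = torch.einsum("qhd,khd->hqk", q, k) / 64 ** 0.5
denom = scores.exp().sum(-1) + sinks[:, None].exp()   # exp(sink) joins the sum
weights = scores.exp() / denom[..., None]
out = torch.einsum("hqk,khd->qhd", weights, v)        # output is down-weighted

print(weights.sum(-1))   # each entry < 1: probability mass absorbed by the sink
```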

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with PagedAttention and Triton prefix prefill."""
 from dataclasses import dataclass
+from functools import cache
 from typing import ClassVar, Optional

 import torch
@@ -13,7 +14,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.ops.chunked_prefill_paged_decode import (
     chunked_prefill_paged_decode)
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -193,6 +193,15 @@ class TritonAttentionBackend(AttentionBackend):
         return TritonAttentionMetadataBuilder


+@cache
+def use_aiter_unified_attention() -> bool:
+    """Check if aiter unified attention should be used."""
+    # VLLM_ROCM_USE_AITER_MHA needs to be set to 0 as well, since it
+    # defaults to 1.
+    return envs.VLLM_ROCM_USE_AITER \
+        and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
+
+
 class TritonAttentionImpl(AttentionImpl):

     def __init__(
@@ -207,6 +216,7 @@ class TritonAttentionImpl(AttentionImpl):
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[int] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -240,6 +250,29 @@ class TritonAttentionImpl(AttentionImpl):
         self.force_prefill_decode_attn = \
             envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

+        if not self.force_prefill_decode_attn:
+            # If not using prefill decode attention, we use the Triton
+            # unified attention implementation.
+            if use_aiter_unified_attention():
+                logger.info_once(
+                    "Using aiter unified attention for TritonAttentionImpl")
+                from aiter.ops.triton.unified_attention import (
+                    unified_attention)
+                self.unified_attention = unified_attention
+            else:
+                logger.info_once(
+                    "Using vllm unified attention for TritonAttentionImpl")
+                from vllm.attention.ops.triton_unified_attention import (
+                    unified_attention)
+                self.unified_attention = unified_attention
+
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                f"heads in the layer. Sinks shape: {sinks.shape}, "
+                f"num_heads: {num_heads}.")
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -342,28 +375,31 @@ class TritonAttentionImpl(AttentionImpl):
         if use_prefill_decode_attn:
             # Compute attention and update output up to `num_actual_tokens`.
-            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
-                                         key=key[:num_actual_tokens],
-                                         value=value[:num_actual_tokens],
-                                         output=output[:num_actual_tokens],
-                                         kv_cache_dtype=self.kv_cache_dtype,
-                                         key_cache=key_cache,
-                                         value_cache=value_cache,
-                                         block_table=block_table,
-                                         query_start_loc=cu_seqlens_q,
-                                         seq_lens=seqused_k,
-                                         max_seq_len=max_seqlen_k,
-                                         max_query_len=max_seqlen_q,
-                                         k_scale=layer._k_scale,
-                                         v_scale=layer._v_scale,
-                                         alibi_slopes=self.alibi_slopes,
-                                         sliding_window=self.sliding_window[0],
-                                         sm_scale=self.scale)
+            chunked_prefill_paged_decode(
+                query=query[:num_actual_tokens],
+                key=key[:num_actual_tokens],
+                value=value[:num_actual_tokens],
+                output=output[:num_actual_tokens],
+                kv_cache_dtype=self.kv_cache_dtype,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_table=block_table,
+                query_start_loc=cu_seqlens_q,
+                seq_lens=seqused_k,
+                max_seq_len=max_seqlen_k,
+                max_query_len=max_seqlen_q,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
+                alibi_slopes=self.alibi_slopes,
+                sliding_window=self.sliding_window[0],
+                sm_scale=self.scale,
+                sinks=self.sinks,
+            )
         else:
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

-            unified_attention(
+            self.unified_attention(
                 q=query[:num_actual_tokens],
                 k=key_cache,
                 v=value_cache,
@@ -381,6 +417,7 @@ class TritonAttentionImpl(AttentionImpl):
                 q_descale=None,  # Not supported
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
+                sinks=self.sinks,
             )

         return output

@@ -254,7 +254,11 @@ def get_kv_cache_layout():
     # Override with format specified by the user.
     cache_layout = envs.VLLM_KV_CACHE_LAYOUT
     if cache_layout is None:
-        cache_layout = get_kv_connector_cache_layout()
+        if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
+                or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+            cache_layout = "HND"
+        else:
+            cache_layout = get_kv_connector_cache_layout()
     else:
         logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
                          "detected. Setting KV cache layout to %s.", cache_layout)
@@ -272,7 +276,9 @@ def set_kv_cache_layout(cache_layout: str):
 class PerLayerParameters:
     """
     Currently, FlashInfer backend only support models in which all layers share
-    the same values for the following hyperparameters.
+    the same values for the following hyperparameters. Should not be used for
+    the trtllm-gen backend, since it supports per-layer values for these
+    hyperparameters.
     """

     window_left: int
@@ -310,7 +316,8 @@ def get_per_layer_parameters(
 def infer_global_hyperparameters(
         per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
     """
-    Currently, FlashInfer backend only support models in which all layers share
+    Currently, FlashInfer backends other than trtllm-gen
+    only support models in which all layers share
     the same values for the following hyperparameters:
     - `window_left`
     - `logits_soft_cap`
@@ -324,15 +331,20 @@ def infer_global_hyperparameters(
     param_sets = list(per_layer_params.values())
     global_params = param_sets[0]
-    for params in param_sets:
-        if params.window_left != global_params.window_left:
-            raise ValueError(
-                "Window left is not the same for all layers. One potential fix "
-                "is to set disable_sliding_window=True")
-        assert params == global_params, (
-            "FlashInfer backend currently only supports models in which all "
-            "layers share the same values for the following hyperparameters: "
-            "`window_left`, `logits_soft_cap`, `sm_scale`.")
+
+    # trtllm attention doesn't need global hyperparameters, so skip the check.
+    if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
+            and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+        for params in param_sets:
+            if params.window_left != global_params.window_left:
+                raise ValueError(
+                    "Window left is not the same for all layers. "
+                    "One potential fix is to set disable_sliding_window=True")
+            assert params == global_params, (
+                "FlashInfer backend currently only supports models in which "
+                "all layers share the same values for the following "
+                "hyperparameters: `window_left`, `logits_soft_cap`, "
+                "`sm_scale`.")

     return global_params