Add attention sink in attention backends (#22320)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Woosuk Kwon
2025-08-05 22:37:21 -07:00
committed by GitHub
parent dd16bdc798
commit 6e20924350
7 changed files with 176 additions and 45 deletions

View File

@@ -254,7 +254,11 @@ def get_kv_cache_layout():
# Override with format specified by the user.
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
if cache_layout is None:
cache_layout = get_kv_connector_cache_layout()
if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
cache_layout = "HND"
else:
cache_layout = get_kv_connector_cache_layout()
else:
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
"detected. Setting KV cache layout to %s.", cache_layout)
@@ -272,7 +276,9 @@ def set_kv_cache_layout(cache_layout: str):
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
the same values for the following hyperparameters. Should not be used for
trtllm-gen backend since it supports different values for the following
hyperparameters.
"""
window_left: int
@@ -310,7 +316,8 @@ def get_per_layer_parameters(
def infer_global_hyperparameters(
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
    Currently, FlashInfer backend only support models in which all layers share
    Currently, FlashInfer backends other than trtllm-gen
    only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
@@ -324,15 +331,20 @@ def infer_global_hyperparameters(
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
if params.window_left != global_params.window_left:
raise ValueError(
"Window left is not the same for all layers. One potential fix "
"is to set disable_sliding_window=True")
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")
# trtllm attention doesn't need global hyper params so disable the check
if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
for params in param_sets:
if params.window_left != global_params.window_left:
raise ValueError(
"Window left is not the same for all layers. " \
"One potential fix is to set disable_sliding_window=True")
            assert params == global_params, (
                "FlashInfer backend currently only supports models in which all "
                "layers share the same values "
                "for the following hyperparameters: "
                "`window_left`, `logits_soft_cap`, `sm_scale`.")
return global_params