Add attention sink in attention backends (#22320)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com> Co-authored-by: simon-mo <xmo@berkeley.edu> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
@@ -254,7 +254,11 @@ def get_kv_cache_layout():
|
||||
# Override with format specified by the user.
|
||||
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
|
||||
if cache_layout is None:
|
||||
cache_layout = get_kv_connector_cache_layout()
|
||||
if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
|
||||
or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
|
||||
cache_layout = "HND"
|
||||
else:
|
||||
cache_layout = get_kv_connector_cache_layout()
|
||||
else:
|
||||
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
|
||||
"detected. Setting KV cache layout to %s.", cache_layout)
|
||||
@@ -272,7 +276,9 @@ def set_kv_cache_layout(cache_layout: str):
|
||||
class PerLayerParameters:
|
||||
"""
|
||||
Currently, FlashInfer backend only support models in which all layers share
|
||||
the same values for the following hyperparameters.
|
||||
the same values for the following hyperparameters. Should not be used for
|
||||
trtllm-gen backend since it supports different values for the following
|
||||
hyperparameters.
|
||||
"""
|
||||
|
||||
window_left: int
|
||||
@@ -310,7 +316,8 @@ def get_per_layer_parameters(
|
||||
def infer_global_hyperparameters(
|
||||
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
|
||||
"""
|
||||
Currently, FlashInfer backend only support models in which all layers share
|
||||
Currently, FlashInfer backend other than trtllm-gen
|
||||
only support models in which all layers share
|
||||
the same values for the following hyperparameters:
|
||||
- `window_left`
|
||||
- `logits_soft_cap`
|
||||
@@ -324,15 +331,20 @@ def infer_global_hyperparameters(
|
||||
|
||||
param_sets = list(per_layer_params.values())
|
||||
global_params = param_sets[0]
|
||||
for params in param_sets:
|
||||
if params.window_left != global_params.window_left:
|
||||
raise ValueError(
|
||||
"Window left is not the same for all layers. One potential fix "
|
||||
"is to set disable_sliding_window=True")
|
||||
assert params == global_params, (
|
||||
"FlashInfer backend currently only supports models in which all "
|
||||
"layers share the same values for the following hyperparameters: "
|
||||
"`window_left`, `logits_soft_cap`, `sm_scale`.")
|
||||
|
||||
# trtllm attention doesn't need global hyper params so disable the check
|
||||
if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
|
||||
and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
|
||||
for params in param_sets:
|
||||
if params.window_left != global_params.window_left:
|
||||
raise ValueError(
|
||||
"Window left is not the same for all layers. " \
|
||||
"One potential fix is to set disable_sliding_window=True")
|
||||
assert params == global_params, (
|
||||
"FlashInfer backend currently only supports models in which all"
|
||||
"layers share the same values "
|
||||
"for the following hyperparameters:"
|
||||
"`window_left`, `logits_soft_cap`, `sm_scale`.")
|
||||
|
||||
return global_params
|
||||
|
||||
|
||||
Reference in New Issue
Block a user