Fix trtllm-gen attention env and add attention sink (#22378)

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Lain <fusiyuan2000@hotmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Lain
2025-08-06 18:07:41 -07:00
committed by GitHub
parent 5c7cc33f4d
commit 9a3835aaa9
5 changed files with 21 additions and 28 deletions

View File

@@ -215,6 +215,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self._cascade_wrapper = None # Wrapper for cascade attention
# Global hyperparameters shared by all attention layers
# TODO: discard this for trtllm-gen backend
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
@@ -523,16 +524,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
head_dim = self.kv_cache_spec.head_size
# currently prefill trtllm attention does not support fp8 kv cache
# trtllm may not support sliding window
prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
and not cache_dtype.startswith("fp8")
and use_trtllm_attention(
prefill_use_trtllm = use_trtllm_attention(
num_prefill_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim))
decode_use_trtllm = (self.global_hyperparameters.window_left == -1
and use_trtllm_attention(
num_qo_heads, num_kv_heads, head_dim)
decode_use_trtllm = use_trtllm_attention(
num_decode_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim))
num_qo_heads, num_kv_heads, head_dim)
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
@@ -793,6 +790,8 @@ class FlashInferImpl(AttentionImpl):
batch_size=attn_metadata.num_prefills,
cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
window_left=window_left,
sinks=self.sinks,
out=output[num_decode_tokens:],
)
@@ -839,6 +838,8 @@ class FlashInferImpl(AttentionImpl):
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=layer._k_scale_float * self.scale,
bmm2_scale=layer._v_scale_float,
window_left=window_left,
sinks=self.sinks,
out=output[:num_decode_tokens],
)
return output_padded

View File

@@ -254,8 +254,7 @@ def get_kv_cache_layout():
# Override with format specified by the user.
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
if cache_layout is None:
if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
if envs.VLLM_USE_TRTLLM_ATTENTION:
cache_layout = "HND"
else:
cache_layout = get_kv_connector_cache_layout()
@@ -333,8 +332,7 @@ def infer_global_hyperparameters(
global_params = param_sets[0]
# trtllm attention doesn't need global hyper params so disable the check
if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
if not envs.VLLM_USE_TRTLLM_ATTENTION:
for params in param_sets:
if params.window_left != global_params.window_left:
raise ValueError(