Fix trtllm-gen attention env and add attention sink (#22378)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Lain <fusiyuan2000@hotmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
@@ -215,6 +215,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
+        # TODO: discard this for trtllm-gen backend
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
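For context, the "global hyperparameters" collected here are per-layer attention settings (such as the sliding-window extent, logits soft cap, and softmax scale) that the FlashInfer planner assumes are uniform across layers. A minimal sketch of the idea follows; the dataclass, fields, and simplified body are illustrative, not the exact vLLM definitions.

```python
from dataclasses import dataclass
from typing import Optional


# Illustration only: vLLM's actual per-layer parameter container may differ.
@dataclass
class PerLayerParameters:
    window_left: int                  # -1 means "no sliding window"
    logits_soft_cap: Optional[float]  # None means "no capping"
    sm_scale: float                   # softmax scale, usually head_dim ** -0.5


def infer_global_hyperparameters(
        param_sets: list[PerLayerParameters]) -> PerLayerParameters:
    # The planner can only honor one set of values, so every layer is
    # expected to agree; the first layer's parameters stand in for all.
    return param_sets[0]
```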
@@ -523,16 +524,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         head_dim = self.kv_cache_spec.head_size
 
-        # currently prefill trtllm attention does not support fp8 kv cache
-        # trtllm may not support sliding window
-        prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
-                              and not cache_dtype.startswith("fp8")
-                              and use_trtllm_attention(
-                                  num_prefill_tokens, max_seq_len, cache_dtype,
-                                  num_qo_heads, num_kv_heads, head_dim))
-        decode_use_trtllm = (self.global_hyperparameters.window_left == -1
-                             and use_trtllm_attention(
-                                 num_decode_tokens, max_seq_len, cache_dtype,
-                                 num_qo_heads, num_kv_heads, head_dim))
+        prefill_use_trtllm = use_trtllm_attention(
+            num_prefill_tokens, max_seq_len, cache_dtype,
+            num_qo_heads, num_kv_heads, head_dim)
+        decode_use_trtllm = use_trtllm_attention(
+            num_decode_tokens, max_seq_len, cache_dtype,
+            num_qo_heads, num_kv_heads, head_dim)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
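With this change the call sites no longer gate on sliding window or fp8 KV cache; the implication is that `use_trtllm_attention` itself decides whether the trtllm-gen kernels apply. The helper below is a hedged sketch of that shape of API, matching the argument order used at the call sites; its body is an assumption, not vLLM's actual logic.

```python
import os


def use_trtllm_attention(num_tokens: int, max_seq_len: int,
                         kv_cache_dtype: str, num_qo_heads: int,
                         num_kv_heads: int, head_dim: int) -> bool:
    """Sketch only: decide whether to dispatch to the trtllm-gen kernels."""
    # Respect an explicit user override via the consolidated env variable.
    forced = os.environ.get("VLLM_USE_TRTLLM_ATTENTION")
    if forced is not None:
        return forced in ("1", "true", "True")
    # Hypothetical capability gating. The point of the diff is that
    # sliding-window and fp8-KV-cache support are now the kernel's problem,
    # so the caller only supplies shapes and dtype and gets a yes/no back.
    if num_qo_heads % num_kv_heads != 0:
        return False
    return head_dim in (64, 128, 256) and max_seq_len > 0 and num_tokens > 0
```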
@@ -793,6 +790,8 @@ class FlashInferImpl(AttentionImpl):
                 batch_size=attn_metadata.num_prefills,
                 cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                 cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
+                window_left=window_left,
+                sinks=self.sinks,
                 out=output[num_decode_tokens:],
             )
@@ -839,6 +838,8 @@ class FlashInferImpl(AttentionImpl):
                 max_seq_len=attn_metadata.max_seq_len,
                 bmm1_scale=layer._k_scale_float * self.scale,
                 bmm2_scale=layer._v_scale_float,
+                window_left=window_left,
+                sinks=self.sinks,
                 out=output[:num_decode_tokens],
             )
         return output_padded
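Both the prefill and decode paths now forward `window_left` and `sinks` to the kernels. For readers unfamiliar with attention sinks: each head learns a sink logit that participates in the softmax normalization but has no value vector, so it can absorb attention mass instead of forcing it onto real tokens. A small, self-contained sketch of that math in plain PyTorch (not the trtllm-gen kernel, and single-head for brevity):

```python
import torch


def attention_with_sink(q, k, v, sink_logit, scale):
    """Single-head attention where `sink_logit` (a scalar tensor) joins the
    softmax denominator but contributes no value vector."""
    logits = (q @ k.T) * scale                          # [Tq, Tk]
    sink_col = sink_logit.expand(logits.shape[0], 1)    # [Tq, 1]
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)
    # The sink column's probability mass is simply dropped: it attends nowhere.
    return probs[:, :-1] @ v                            # [Tq, D]


q, k, v = torch.randn(4, 64), torch.randn(6, 64), torch.randn(6, 64)
out = attention_with_sink(q, k, v, torch.tensor(0.5), scale=64 ** -0.5)
```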
@@ -254,8 +254,7 @@ def get_kv_cache_layout():
     # Override with format specified by the user.
     cache_layout = envs.VLLM_KV_CACHE_LAYOUT
     if cache_layout is None:
-        if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
-                or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+        if envs.VLLM_USE_TRTLLM_ATTENTION:
             cache_layout = "HND"
         else:
             cache_layout = get_kv_connector_cache_layout()
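The rationale for forcing "HND" is that the trtllm-gen kernels expect the paged KV cache with the head axis ahead of the token axis, whereas "NHD" puts tokens first. A shapes-only illustration below; the block/head counts and the stacked K/V axis are made up for the example and do not reflect vLLM's exact cache shape.

```python
import torch

num_blocks, block_size, num_kv_heads, head_dim = 8, 16, 4, 128

# "NHD": token positions within a block come before the KV heads.
kv_nhd = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_dim)

# "HND": KV heads come before token positions -- the layout chosen when
# VLLM_USE_TRTLLM_ATTENTION is set.
kv_hnd = torch.empty(num_blocks, 2, num_kv_heads, block_size, head_dim)

# Switching layouts is a transpose of the head and token axes.
assert kv_nhd.transpose(2, 3).shape == kv_hnd.shape
```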
@@ -333,8 +332,7 @@ def infer_global_hyperparameters(
     global_params = param_sets[0]
 
     # trtllm attention doesn't need global hyper params so disable the check
-    if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
-            and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
+    if not envs.VLLM_USE_TRTLLM_ATTENTION:
         for params in param_sets:
             if params.window_left != global_params.window_left:
                 raise ValueError(
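Skipping the check is consistent with the earlier hunks: once `window_left` (and `sinks`) are passed per call, layers are free to differ and a single global value is no longer load-bearing for the trtllm path. A toy version of the validation the flag now bypasses; the function name and env lookup are illustrative.

```python
import os


def check_window_consistency(window_lefts: list[int]) -> None:
    """Raise if layers disagree on window_left, unless trtllm attention is on."""
    if os.environ.get("VLLM_USE_TRTLLM_ATTENTION") == "1":
        return  # per-layer windows ride along on the kernel arguments instead
    if len(set(window_lefts)) > 1:
        raise ValueError(
            "FlashInfer planning assumes one window_left for all layers, "
            f"got {sorted(set(window_lefts))}")


check_window_consistency([-1, -1, -1])     # uniform: fine
# check_window_consistency([-1, 127])      # mixed: raises without the env flag
```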