From 747431044df6b15c7b359b5720cc7368c662c232 Mon Sep 17 00:00:00 2001
From: cong-or
Date: Mon, 9 Mar 2026 03:40:12 +0000
Subject: [PATCH] feat(attention): extract KV-cache update from FlexAttention backend (#36263)

Signed-off-by: cong-or
---
 vllm/v1/attention/backends/flex_attention.py | 36 +++++++++++++++++++++++++-----------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 687e2ba1d..2f67a2d53 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -82,6 +82,8 @@ class FlexAttentionBackend(AttentionBackend):
     ]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "FLEX_ATTENTION"
@@ -827,6 +829,29 @@ class FlexAttentionImpl(AttentionImpl):
         assert tensor.ndim == 3
         return tensor[None, :, :, :]
 
+    def do_kv_cache_update(
+        self,
+        layer: torch.nn.Module,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ) -> None:
+        if self.attn_type == AttentionType.ENCODER_ONLY:
+            return
+
+        key_cache, value_cache = kv_cache.unbind(0)
+        torch.ops._C_cache_ops.reshape_and_cache_flash(
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slot_mapping,
+            self.kv_cache_dtype,
+            layer._k_scale,
+            layer._v_scale,
+        )
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -908,17 +933,6 @@ class FlexAttentionImpl(AttentionImpl):
         assert self.attn_type == AttentionType.DECODER
 
         key_cache, value_cache = kv_cache.unbind(0)
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
-
         # View out the block_size dim
         key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
         value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
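
Reviewer note (not part of the patch): the extracted do_kv_cache_update() wraps
vLLM's fused reshape_and_cache_flash kernel, which scatters each token's K/V
into the paged cache at the flat slot given by slot_mapping; the ENCODER_ONLY
early-return skips the write because encoder-only attention keeps no cache.
Below is a minimal pure-PyTorch sketch of that scatter, with assumed toy
shapes inferred from the unbind(0)/view calls in this diff (the real fused op
also handles dtype conversion and the layer._k_scale / layer._v_scale
quantization scales):

    import torch

    # Assumed paged-cache layout, inferred from this diff:
    # kv_cache is (2, num_blocks, block_size, num_kv_heads, head_size).
    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
    num_tokens = 5

    key = torch.randn(num_tokens, num_kv_heads, head_size)
    value = torch.randn(num_tokens, num_kv_heads, head_size)
    kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)
    key_cache, value_cache = kv_cache.unbind(0)

    # slot_mapping[i] = block_idx * block_size + block_offset for token i.
    slot_mapping = torch.tensor([0, 1, 17, 18, 33])

    # The fused kernel performs this scatter in a single pass; the views
    # here write through to kv_cache without copying.
    key_cache.view(-1, num_kv_heads, head_size)[slot_mapping] = key
    value_cache.view(-1, num_kv_heads, head_size)[slot_mapping] = value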
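
Also not part of the patch: a hypothetical sketch of the caller-side contract
implied by the new forward_includes_kv_cache_update = False flag. The names
here (StubFlexImpl, run_attention) are illustrative stand-ins, not vLLM APIs;
the real dispatch lives in vLLM's model runner:

    class StubFlexImpl:
        # Mirrors the backend flag added in this diff: forward() no longer
        # writes the KV cache itself.
        forward_includes_kv_cache_update = False

        def do_kv_cache_update(self, layer, key, value, kv_cache, slot_mapping):
            print("cache written ahead of forward()")

        def forward(self, layer, query, key, value, kv_cache, slot_mapping):
            # After this change, forward() only reads the cache;
            # placeholder output for the sketch.
            return query

    def run_attention(impl, layer, q, k, v, kv_cache, slot_mapping):
        # When a backend opts out of in-forward cache updates, the runner
        # performs the write explicitly before calling forward().
        if not impl.forward_includes_kv_cache_update:
            impl.do_kv_cache_update(layer, k, v, kv_cache, slot_mapping)
        return impl.forward(layer, q, k, v, kv_cache, slot_mapping)

Factoring the write out of forward() lets a caller schedule the cache update
independently of the attention computation, e.g. ahead of the forward call,
without changing forward()'s signature.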