From ebed80a7c8c652ff43b5bd910c8fe35d73bfa786 Mon Sep 17 00:00:00 2001
From: Dor Huri <92430368+dorhuri123@users.noreply.github.com>
Date: Fri, 6 Mar 2026 02:22:43 +0200
Subject: [PATCH] [Performance] Extract KV-cache update from TreeAttention
 backend (#35384)

Signed-off-by: dorhuri123
---
 vllm/v1/attention/backends/tree_attn.py | 47 +++++++++++++++----------
 1 file changed, 28 insertions(+), 19 deletions(-)

diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index 48082b3a9..2e85109c8 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -31,6 +31,7 @@ logger = init_logger(__name__)
 class TreeAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    forward_includes_kv_cache_update: bool = False
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
@@ -326,6 +327,33 @@ class TreeAttentionImpl(AttentionImpl):
                 "TreeAttentionImpl."
             )
 
+    def do_kv_cache_update(
+        self,
+        layer: torch.nn.Module,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ) -> None:
+        key_cache, value_cache = kv_cache.unbind(0)
+
+        # Reshape the input keys and values and store them in the cache.
+        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+        # not padded. However, we don't need to do key[:num_actual_tokens]
+        # and value[:num_actual_tokens] because the reshape_and_cache_flash
+        # op uses the slot_mapping's shape to determine the number of
+        # actual tokens.
+        ops.reshape_and_cache_flash(
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slot_mapping,
+            self.kv_cache_dtype,
+            layer._k_scale,
+            layer._v_scale,
+        )
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -361,26 +389,7 @@ class TreeAttentionImpl(AttentionImpl):
             # Profiling run.
             return output.fill_(0)
 
-        # Cache the input KVs.
         key_cache, value_cache = kv_cache.unbind(0)
-        if self.kv_sharing_target_layer_name is None:
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-            # not padded. However, we don't need to do key[:num_actual_tokens]
-            # and value[:num_actual_tokens] because the reshape_and_cache_flash
-            # op uses the slot_mapping's shape to determine the number of
-            # actual tokens.
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
 
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_decode_tokens = attn_metadata.num_decode_tokens
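
A minimal usage sketch (not part of the patch) of how a caller might consume
the new hooks. Only forward_includes_kv_cache_update, do_kv_cache_update, and
attn_metadata.slot_mapping come from the change above; the driver function
run_attention and the exact shape of the forward() call are illustrative
assumptions, not vLLM's actual runner code.

    import torch

    def run_attention(backend_cls, impl, layer, query, key, value,
                      kv_cache: torch.Tensor, attn_metadata,
                      output: torch.Tensor) -> torch.Tensor:
        # Hypothetical driver: when the backend reports that forward() no
        # longer performs the KV-cache write, the caller issues the write
        # explicitly before running attention.
        if not backend_cls.forward_includes_kv_cache_update:
            impl.do_kv_cache_update(
                layer=layer,
                key=key,
                value=value,
                kv_cache=kv_cache,
                slot_mapping=attn_metadata.slot_mapping,
            )
        # With the cache already written, forward() only computes attention.
        return impl.forward(layer, query, key, value, kv_cache,
                            attn_metadata, output=output)

Because the flag is a class attribute, a runner can select the code path once
per backend rather than per call, and the decoupled cache write gives it room
to schedule that update separately from the attention kernel, which appears to
be the performance angle the subject line refers to.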