From 98217b09f9ce22429ce35badfa1d50e1f4fe4137 Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Thu, 26 Feb 2026 22:29:01 +0100
Subject: [PATCH] [Performance] Extract KV cache update op from flashinfer forward (#35422)

Signed-off-by: ElizaWszola
---
 vllm/v1/attention/backends/flashinfer.py | 62 ++++++++++++++----------
 1 file changed, 37 insertions(+), 25 deletions(-)

diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 26d372c11..80297720d 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -381,6 +381,8 @@ class FlashInferBackend(AttentionBackend):
             return "HND"
         return None
 
+    forward_includes_kv_cache_update: bool = False
+
 
 @dataclass
 class FIPrefill:
@@ -1330,32 +1332,15 @@ class FlashInferImpl(AttentionImpl):
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        if self.kv_sharing_target_layer_name is None:
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-            # not padded. However, we don't need to do key[:num_actual_tokens]
-            # and value[:num_actual_tokens] because the reshape_and_cache_flash
-            # op uses the slot_mapping's shape to determine the number of
-            # actual tokens.
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                kv_cache[:, 0],
-                kv_cache[:, 1],
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
+        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
+        # to process the cache when the kv_cache_dtype is fp8
+        if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
+            "fp8"
+        ):
+            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.kv_cache_dtype
             )
-
-            # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
-            # to process the cache when the kv_cache_dtype is fp8
-            if self.kv_cache_dtype.startswith("fp8"):
-                torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-                    self.kv_cache_dtype
-                )
-                kv_cache = kv_cache.view(torch_dtype)
+            kv_cache = kv_cache.view(torch_dtype)
 
         # Inputs and outputs may be padded for CUDA graphs
         query = query[:num_actual_tokens]
@@ -1599,6 +1584,33 @@ class FlashInferImpl(AttentionImpl):
                 )
         return output_padded
 
+    def do_kv_cache_update(
+        self,
+        layer: torch.nn.Module,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ) -> None:
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                kv_cache[:, 0],
+                kv_cache[:, 1],
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+
 
 def fast_plan_decode(
     self,  # decode wrapper