From d9c77308776b4d31f03fad8d4129a3d539154166 Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Mon, 2 Mar 2026 16:43:19 +0100
Subject: [PATCH] [Performance] Extract kv update ops from MLA attention
 backends (#34627)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: ElizaWszola
Signed-off-by: Luka Govedič
Co-authored-by: Di Wu
Co-authored-by: Luka Govedič
---
 vllm/config/compilation.py                  |  1 +
 .../layers/attention/mla_attention.py       | 94 ++++++++++++++++---
 vllm/v1/attention/backend.py                | 44 +++++++++
 3 files changed, 128 insertions(+), 11 deletions(-)

diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 54dbf24f5..64332d2e8 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -1007,6 +1007,7 @@ class CompilationConfig:
             # https://github.com/vllm-project/vllm/issues/33267
             if not self.use_inductor_graph_partition:
                 self.splitting_ops.append("vllm::unified_kv_cache_update")
+                self.splitting_ops.append("vllm::unified_mla_kv_cache_update")
 
         elif len(self.splitting_ops) == 0:
             if (
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index f6e7ab85d..820755b9c 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -434,7 +434,19 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             if isinstance(attn_metadata, dict):
                 attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+            slot_mapping = forward_context.slot_mapping
+            assert isinstance(slot_mapping, dict), (
+                f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
+            )
+            self.impl.do_kv_cache_update(
+                kv_c_normed,
+                k_pe,
+                self_kv_cache,
+                slot_mapping.get(self.layer_name),
+                self.kv_cache_dtype,
+                self._k_scale,
+            )
 
             if self.attn_backend.accept_output_buffer:
                 output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 self.forward_impl(
@@ -451,6 +463,13 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                 )
         else:
+            kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
+                kv_c_normed,
+                k_pe,
+                self.layer_name,
+                self.kv_cache_dtype,
+                self._k_scale,
+            )
             if self.attn_backend.accept_output_buffer:
                 output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 torch.ops.vllm.unified_mla_attention_with_output(
@@ -459,6 +478,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     k_pe,
                     output,
                     self.layer_name,
+                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                 )
                 return output
             else:
@@ -467,6 +487,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     kv_c_normed,
                     k_pe,
                     self.layer_name,
+                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                 )
 
     def forward_impl(
@@ -520,17 +541,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         k_c_normed = k_c_normed[:num_actual_toks, ...]
         k_pe = k_pe[:num_actual_toks, ...]
 
-        # write the latent and rope to kv cache
-        if kv_cache.numel() > 0:
-            ops.concat_and_cache_mla(
-                k_c_normed,
-                k_pe.squeeze(1),
-                kv_cache,
-                attn_metadata.slot_mapping.flatten(),
-                kv_cache_dtype=self.kv_cache_dtype,
-                scale=self._k_scale,
-            )
-
         if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla":
             kv_cache = kv_cache.view(current_platform.fp8_dtype())
 
@@ -827,7 +837,12 @@ def unified_mla_attention(
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> torch.Tensor:
+    # kv_cache_dummy_dep is not used but accepting it creates a data dependency
+    # that ensures torch.compile preserves ordering between KV cache update and
+    # attention forward.
+    del kv_cache_dummy_dep
     attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
     output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
 
@@ -839,6 +854,7 @@ def unified_mla_attention_fake(
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> torch.Tensor:
     return torch.empty_like(q).contiguous()
 
@@ -852,6 +868,56 @@ direct_register_custom_op(
 )
 
 
+def unified_mla_kv_cache_update(
+    kv_c_normed: torch.Tensor,
+    k_pe: torch.Tensor,
+    layer_name: str,
+    kv_cache_dtype: str,
+    k_scale: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Returns a dummy that is passed to unified_attention to signal a side effect and
+    the data dependency between them to ensure torch.compile preserves ordering.
+    """
+    forward_context = get_forward_context()
+    attn_layer = forward_context.no_compile_layers[layer_name]
+    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
+
+    slot_mapping = forward_context.slot_mapping
+    assert isinstance(slot_mapping, dict), (
+        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
+    )
+    layer_slot_mapping = slot_mapping.get(layer_name)
+    if layer_slot_mapping is not None:
+        attn_layer.impl.do_kv_cache_update(
+            kv_c_normed,
+            k_pe,
+            kv_cache,
+            layer_slot_mapping,
+            kv_cache_dtype,
+            k_scale,
+        )
+
+    return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
+
+
+def unified_mla_kv_cache_update_fake(
+    kv_c_normed: torch.Tensor,
+    k_pe: torch.Tensor,
+    layer_name: str,
+    kv_cache_dtype: str,
+    k_scale: torch.Tensor,
+) -> torch.Tensor:
+    return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
+
+
+direct_register_custom_op(
+    op_name="unified_mla_kv_cache_update",
+    op_func=unified_mla_kv_cache_update,
+    fake_impl=unified_mla_kv_cache_update_fake,
+)
+
+
 @maybe_transfer_kv_layer
 def unified_mla_attention_with_output(
     q: torch.Tensor,
@@ -861,7 +927,12 @@ def unified_mla_attention_with_output(
     layer_name: str,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> None:
+    # kv_cache_dummy_dep is not used but accepting it creates a data dependency
+    # that ensures torch.compile preserves ordering between KV cache update and
+    # attention forward.
+    del kv_cache_dummy_dep
    attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
     layer.forward_impl(
         q,
@@ -883,6 +954,7 @@ def unified_mla_attention_with_output_fake(
     layer_name: str,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> None:
     return
 
diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py
index 43fa59911..585ad1d79 100644
--- a/vllm/v1/attention/backend.py
+++ b/vllm/v1/attention/backend.py
@@ -811,6 +811,28 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
         """MQA-style decode forward pass."""
         raise NotImplementedError
 
+    def do_kv_cache_update(
+        self,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: torch.Tensor,
+    ) -> None:
+        if kv_cache.numel() == 0:
+            return
+        from vllm import _custom_ops as ops
+
+        ops.concat_and_cache_mla(
+            kv_c_normed,
+            k_pe.squeeze(1),
+            kv_cache,
+            slot_mapping.flatten(),
+            kv_cache_dtype=kv_cache_dtype,
+            scale=k_scale,
+        )
+
 
 class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
     """Sparse MLA attention implementation with only forward_mqa method.
@@ -856,6 +878,28 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
         """MQA-style decode forward pass."""
         raise NotImplementedError
 
+    def do_kv_cache_update(
+        self,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: torch.Tensor,
+    ) -> None:
+        if kv_cache.numel() == 0:
+            return
+        from vllm import _custom_ops as ops
+
+        ops.concat_and_cache_mla(
+            kv_c_normed,
+            k_pe.squeeze(1),
+            kv_cache,
+            slot_mapping.flatten(),
+            kv_cache_dtype=kv_cache_dtype,
+            scale=k_scale,
+        )
+
 
 def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
     return kv_cache_dtype.startswith("fp8")
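Note on the kv_cache_dummy_dep pattern used above: once the KV cache write is
extracted into its own custom op, torch.compile sees no data flow between that
op (a pure side effect on the cache) and the attention op that reads the cache,
so it could in principle reorder the two. The patch therefore has the update op
return a zero-sized dummy tensor and threads it into the attention op as an
otherwise-unused argument, creating an explicit dependency edge. The standalone
sketch below illustrates the same pattern outside vLLM; the names (the "demo"
namespace, cache_write, attend) are illustrative assumptions, not part of the
patch, and it assumes PyTorch 2.4+ for torch.library.custom_op.

    # Standalone sketch of the dummy-dependency pattern (not vLLM code).
    import torch
    from torch.library import custom_op


    @custom_op("demo::cache_write", mutates_args={"cache"})
    def cache_write(kv: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
        # Side effect: write new entries into the cache, then return a
        # zero-sized dummy tensor that later ops can take as an input.
        cache.copy_(kv)
        return kv.new_empty(0)


    @cache_write.register_fake
    def _(kv, cache):
        return kv.new_empty(0)


    @custom_op("demo::attend", mutates_args=())
    def attend(
        q: torch.Tensor, cache: torch.Tensor, dummy_dep: torch.Tensor | None = None
    ) -> torch.Tensor:
        # dummy_dep is unused; accepting it gives this op a data dependency on
        # cache_write, so the compiler cannot reorder or drop the cache write.
        del dummy_dep
        return q @ cache.transpose(-1, -2)


    @attend.register_fake
    def _(q, cache, dummy_dep=None):
        return q.new_empty(q.shape[0], cache.shape[0])


    @torch.compile
    def forward(q: torch.Tensor, kv: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
        # Thread the dummy through, mirroring kv_cache_dummy_dep in the patch.
        dep = torch.ops.demo.cache_write(kv, cache)
        return torch.ops.demo.attend(q, cache, dummy_dep=dep)


    q, kv, cache = torch.randn(4, 8), torch.randn(16, 8), torch.empty(16, 8)
    out = forward(q, kv, cache)  # (4, 16) scores against the freshly written cache

Because the dummy is zero-sized, threading it through costs essentially nothing
at runtime; it exists only so the compiled graph keeps the write ordered before
the read, which is the same reason the patch also adds
"vllm::unified_mla_kv_cache_update" to splitting_ops when inductor graph
partitioning is disabled.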