[Performance] Extract kv update ops from MLA attention backends (#34627)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Di Wu <dw2761@nyu.edu>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: ElizaWszola
Date: 2026-03-02 16:43:19 +01:00
Committed by: GitHub
Parent: ada4f4fadd
Commit: d9c7730877

3 changed files with 128 additions and 11 deletions


@@ -1007,6 +1007,7 @@ class CompilationConfig:
             # https://github.com/vllm-project/vllm/issues/33267
             if not self.use_inductor_graph_partition:
                 self.splitting_ops.append("vllm::unified_kv_cache_update")
+                self.splitting_ops.append("vllm::unified_mla_kv_cache_update")
         elif len(self.splitting_ops) == 0:
             if (

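Ops listed in splitting_ops become boundaries for vLLM's piecewise compilation: the graph is split at each such call so the op runs eagerly between captured pieces, which is what keeps the cache update outside CUDA graphs. Below is a minimal sketch of setting this up explicitly from user code, assuming vLLM's public CompilationConfig surface; the two-op list is illustrative, not the full default set.

# Hedged sketch: pass explicit splitting ops to vLLM's compilation config.
# Field and op names follow this commit; treat the exact values as illustrative.
from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # assumption: any MLA-based model
    compilation_config=CompilationConfig(
        # Ops listed here become split points between compiled pieces; the
        # diff above appends vllm::unified_mla_kv_cache_update automatically
        # when inductor graph partition is disabled.
        splitting_ops=[
            "vllm::unified_kv_cache_update",
            "vllm::unified_mla_kv_cache_update",
        ],
    ),
)

The attention-layer changes below supply the op that this config splits on.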

@@ -434,7 +434,19 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             if isinstance(attn_metadata, dict):
                 attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+            slot_mapping = forward_context.slot_mapping
+            assert isinstance(slot_mapping, dict), (
+                f"Expected slot_mapping to be a dict, got {type(slot_mapping)}."
+            )
+            self.impl.do_kv_cache_update(
+                kv_c_normed,
+                k_pe,
+                self_kv_cache,
+                slot_mapping.get(self.layer_name),
+                self.kv_cache_dtype,
+                self._k_scale,
+            )
             if self.attn_backend.accept_output_buffer:
                 output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 self.forward_impl(
@@ -451,6 +463,13 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                 )
         else:
+            kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
+                kv_c_normed,
+                k_pe,
+                self.layer_name,
+                self.kv_cache_dtype,
+                self._k_scale,
+            )
             if self.attn_backend.accept_output_buffer:
                 output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 torch.ops.vllm.unified_mla_attention_with_output(
@@ -459,6 +478,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     k_pe,
                     output,
                     self.layer_name,
+                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                 )
                 return output
             else:
@@ -467,6 +487,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     kv_c_normed,
                     k_pe,
                     self.layer_name,
+                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                 )

     def forward_impl(
@@ -520,17 +541,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         k_c_normed = k_c_normed[:num_actual_toks, ...]
         k_pe = k_pe[:num_actual_toks, ...]

-        # write the latent and rope to kv cache
-        if kv_cache.numel() > 0:
-            ops.concat_and_cache_mla(
-                k_c_normed,
-                k_pe.squeeze(1),
-                kv_cache,
-                attn_metadata.slot_mapping.flatten(),
-                kv_cache_dtype=self.kv_cache_dtype,
-                scale=self._k_scale,
-            )
-
         if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla":
             kv_cache = kv_cache.view(current_platform.fp8_dtype())
@@ -827,7 +837,12 @@ def unified_mla_attention(
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> torch.Tensor:
+    # kv_cache_dummy_dep is not used, but accepting it creates a data
+    # dependency that ensures torch.compile preserves the ordering between
+    # the KV cache update and the attention forward.
+    del kv_cache_dummy_dep
     attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
     output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
@@ -839,6 +854,7 @@ def unified_mla_attention_fake(
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> torch.Tensor:
     return torch.empty_like(q).contiguous()
@@ -852,6 +868,56 @@ direct_register_custom_op(
 )


+def unified_mla_kv_cache_update(
+    kv_c_normed: torch.Tensor,
+    k_pe: torch.Tensor,
+    layer_name: str,
+    kv_cache_dtype: str,
+    k_scale: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Update the MLA KV cache for the given layer and return a dummy tensor.
+    The dummy is passed to unified_mla_attention to signal the side effect
+    and create a data dependency between the two ops, ensuring torch.compile
+    preserves their ordering.
+    """
+    forward_context = get_forward_context()
+    attn_layer = forward_context.no_compile_layers[layer_name]
+    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
+    slot_mapping = forward_context.slot_mapping
+    assert isinstance(slot_mapping, dict), (
+        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}."
+    )
+    layer_slot_mapping = slot_mapping.get(layer_name)
+    if layer_slot_mapping is not None:
+        attn_layer.impl.do_kv_cache_update(
+            kv_c_normed,
+            k_pe,
+            kv_cache,
+            layer_slot_mapping,
+            kv_cache_dtype,
+            k_scale,
+        )
+    return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
+
+
+def unified_mla_kv_cache_update_fake(
+    kv_c_normed: torch.Tensor,
+    k_pe: torch.Tensor,
+    layer_name: str,
+    kv_cache_dtype: str,
+    k_scale: torch.Tensor,
+) -> torch.Tensor:
+    return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
+
+
+direct_register_custom_op(
+    op_name="unified_mla_kv_cache_update",
+    op_func=unified_mla_kv_cache_update,
+    fake_impl=unified_mla_kv_cache_update_fake,
+)


 @maybe_transfer_kv_layer
 def unified_mla_attention_with_output(
     q: torch.Tensor,
@@ -861,7 +927,12 @@ def unified_mla_attention_with_output(
     layer_name: str,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> None:
+    # kv_cache_dummy_dep is not used, but accepting it creates a data
+    # dependency that ensures torch.compile preserves the ordering between
+    # the KV cache update and the attention forward.
+    del kv_cache_dummy_dep
     attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
     layer.forward_impl(
         q,
@@ -883,6 +954,7 @@ def unified_mla_attention_with_output_fake(
     layer_name: str,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> None:
     return
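The kv_cache_dummy_dep plumbing above is the heart of the change: the cache write goes through state fetched by layer name from the forward context, so the compiler sees no dataflow between the update op and the attention op and would otherwise be free to reorder or dead-code-eliminate the update. Below is a self-contained sketch of the same pattern using the stock torch.library.custom_op API rather than vLLM's direct_register_custom_op helper; the demo namespace, the _KV_CACHE dict, and the op bodies are hypothetical stand-ins.

import torch
from typing import Optional

_KV_CACHE: dict[str, torch.Tensor] = {}  # hypothetical stand-in for the forward context

@torch.library.custom_op("demo::kv_cache_update", mutates_args=())
def kv_cache_update(x: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Side effect the compiler cannot see: the cache is fetched by name,
    # not passed in as an argument, and mutates_args=() claims purity.
    _KV_CACHE[layer_name].copy_(x)
    # Zero-element dummy whose only job is to carry a dependency edge.
    return torch.empty(0, device=x.device, dtype=x.dtype)

@kv_cache_update.register_fake
def _(x: torch.Tensor, layer_name: str) -> torch.Tensor:
    return torch.empty(0, device=x.device, dtype=x.dtype)

@torch.library.custom_op("demo::attention", mutates_args=())
def attention(
    q: torch.Tensor,
    layer_name: str,
    kv_cache_dummy_dep: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    del kv_cache_dummy_dep  # unused at runtime; it only orders the graph
    return q + _KV_CACHE[layer_name]  # stand-in for a real attention kernel

@attention.register_fake
def _(q, layer_name, kv_cache_dummy_dep=None):
    return torch.empty_like(q)

@torch.compile
def forward(q: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    dep = torch.ops.demo.kv_cache_update(x, "layer0")
    # Consuming `dep` gives the compiler a reason to keep the update
    # and to run it before the attention op.
    return torch.ops.demo.attention(q, "layer0", kv_cache_dummy_dep=dep)

_KV_CACHE["layer0"] = torch.zeros(4, 8)
out = forward(torch.randn(4, 8), torch.ones(4, 8))

The zero-element dummy costs essentially nothing to produce, but because the attention op consumes it, every trace keeps the update and keeps it ordered before the read.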


@@ -811,6 +811,28 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
         """MQA-style decode forward pass."""
         raise NotImplementedError

+    def do_kv_cache_update(
+        self,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: torch.Tensor,
+    ) -> None:
+        # Write the latent (compressed KV) and rope parts to the KV cache.
+        if kv_cache.numel() == 0:
+            return
+
+        from vllm import _custom_ops as ops
+
+        ops.concat_and_cache_mla(
+            kv_c_normed,
+            k_pe.squeeze(1),
+            kv_cache,
+            slot_mapping.flatten(),
+            kv_cache_dtype=kv_cache_dtype,
+            scale=k_scale,
+        )
+

 class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
     """Sparse MLA attention implementation with only forward_mqa method.
@@ -856,6 +878,28 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
         """MQA-style decode forward pass."""
         raise NotImplementedError

+    def do_kv_cache_update(
+        self,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: torch.Tensor,
+    ) -> None:
+        # Write the latent (compressed KV) and rope parts to the KV cache.
+        if kv_cache.numel() == 0:
+            return
+
+        from vllm import _custom_ops as ops
+
+        ops.concat_and_cache_mla(
+            kv_c_normed,
+            k_pe.squeeze(1),
+            kv_cache,
+            slot_mapping.flatten(),
+            kv_cache_dtype=kv_cache_dtype,
+            scale=k_scale,
+        )
+

 def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
     return kv_cache_dtype.startswith("fp8")
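For reference, the semantics both do_kv_cache_update copies delegate to via ops.concat_and_cache_mla can be written as an unquantized pure-PyTorch sketch. The real kernel additionally applies k_scale when quantizing to fp8 cache dtypes (including the fp8_ds_mla layout), and the negative-slot padding convention below is an assumption, not something this diff shows.

import torch

def concat_and_cache_mla_ref(
    kv_c_normed: torch.Tensor,   # [num_tokens, kv_lora_rank]
    k_pe: torch.Tensor,          # [num_tokens, rope_dim] (already squeezed)
    kv_cache: torch.Tensor,      # [..., kv_lora_rank + rope_dim]
    slot_mapping: torch.Tensor,  # [num_tokens], flat slot index per token
) -> None:
    # MLA stores one row per token: the compressed latent and the rope key
    # side by side, instead of full per-head K/V tensors.
    entries = torch.cat([kv_c_normed, k_pe], dim=-1)
    flat = kv_cache.view(-1, entries.shape[-1])
    # Assumption: padded tokens carry a negative slot id and are skipped.
    valid = slot_mapping >= 0
    flat[slot_mapping[valid]] = entries[valid].to(flat.dtype)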