[Performance] Extract kv update ops from MLA attention backends (#34627)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Di Wu <dw2761@nyu.edu>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: ElizaWszola
Date: 2026-03-02 16:43:19 +01:00
Committed by: GitHub
Parent: ada4f4fadd
Commit: d9c7730877
3 changed files with 128 additions and 11 deletions


@@ -811,6 +811,28 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
        """MQA-style decode forward pass."""
        raise NotImplementedError

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
    ) -> None:
        # Nothing to do if the KV cache has not been allocated yet
        # (e.g. during the initial memory-profiling run).
        if kv_cache.numel() == 0:
            return

        from vllm import _custom_ops as ops

        # Concatenate the compressed latent KV and the rotary positional
        # key, then write them into the paged cache at the mapped slots.
        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )
class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
    """Sparse MLA attention implementation with only forward_mqa method.
@@ -856,6 +878,28 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
        """MQA-style decode forward pass."""
        raise NotImplementedError

    def do_kv_cache_update(
        self,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
    ) -> None:
        # Nothing to do if the KV cache has not been allocated yet
        # (e.g. during the initial memory-profiling run).
        if kv_cache.numel() == 0:
            return

        from vllm import _custom_ops as ops

        # Concatenate the compressed latent KV and the rotary positional
        # key, then write them into the paged cache at the mapped slots.
        ops.concat_and_cache_mla(
            kv_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            slot_mapping.flatten(),
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )
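
Both base classes ship the same default body, so a backend whose kernel fuses the cache write can simply override the hook. A hypothetical sketch, where my_fused_cache_update is a placeholder for a backend-specific kernel rather than a real vLLM op:

    class FusedMLABackendImpl(MLAAttentionImpl):
        # Hypothetical subclass; my_fused_cache_update stands in for a
        # backend-specific fused kernel and is not a real vLLM op.
        def do_kv_cache_update(
            self,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            kv_cache: torch.Tensor,
            slot_mapping: torch.Tensor,
            kv_cache_dtype: str,
            k_scale: torch.Tensor,
        ) -> None:
            if kv_cache.numel() == 0:
                return
            my_fused_cache_update(kv_c_normed, k_pe, kv_cache, slot_mapping)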
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype.startswith("fp8")
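
For reference, the helper keys off the dtype-string prefix alone; a quick sanity check, where "fp8_e4m3" and "auto" are assumed examples of vLLM kv-cache dtype strings:

    assert is_quantized_kv_cache("fp8")          # prefix match -> quantized
    assert is_quantized_kv_cache("fp8_e4m3")     # assumed fp8 variant string
    assert not is_quantized_kv_cache("auto")     # unquantized default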