[Performance] Extract kv update ops from MLA attention backends (#34627)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Di Wu <dw2761@nyu.edu>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -811,6 +811,28 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""MQA-style decode forward pass."""
|
||||
raise NotImplementedError
|
||||
|
||||
def do_kv_cache_update(
    self,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
) -> None:
    """Forward new KV entries to the ``concat_and_cache_mla`` custom op.

    Returns immediately, without importing or calling the op, when
    ``kv_cache`` contains zero elements.

    Args:
        kv_c_normed: Passed to the op unchanged as its first argument.
        k_pe: Passed to the op after ``squeeze(1)``.
        kv_cache: Cache tensor handed to the op; presumably written in
            place by it -- TODO confirm against the op's implementation.
        slot_mapping: Passed to the op after being flattened to 1-D.
        kv_cache_dtype: Forwarded as the op's ``kv_cache_dtype`` keyword.
        k_scale: Forwarded as the op's ``scale`` keyword.
    """
    # An empty cache leaves nothing to update.
    if not kv_cache.numel():
        return

    # Deferred import: only reached when there is a non-empty cache.
    from vllm import _custom_ops as ops

    k_pe_squeezed = k_pe.squeeze(1)
    flat_slot_mapping = slot_mapping.flatten()
    ops.concat_and_cache_mla(
        kv_c_normed,
        k_pe_squeezed,
        kv_cache,
        flat_slot_mapping,
        kv_cache_dtype=kv_cache_dtype,
        scale=k_scale,
    )
|
||||
|
||||
|
||||
class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""Sparse MLA attention implementation with only forward_mqa method.
|
||||
@@ -856,6 +878,28 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""MQA-style decode forward pass."""
|
||||
raise NotImplementedError
|
||||
|
||||
def do_kv_cache_update(
    self,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
) -> None:
    """Forward new KV entries to the ``concat_and_cache_mla`` custom op.

    Returns immediately, without importing or calling the op, when
    ``kv_cache`` contains zero elements.

    Args:
        kv_c_normed: Passed to the op unchanged as its first argument.
        k_pe: Passed to the op after ``squeeze(1)``.
        kv_cache: Cache tensor handed to the op; presumably written in
            place by it -- TODO confirm against the op's implementation.
        slot_mapping: Passed to the op after being flattened to 1-D.
        kv_cache_dtype: Forwarded as the op's ``kv_cache_dtype`` keyword.
        k_scale: Forwarded as the op's ``scale`` keyword.
    """
    # An empty cache leaves nothing to update.
    if not kv_cache.numel():
        return

    # Deferred import: only reached when there is a non-empty cache.
    from vllm import _custom_ops as ops

    k_pe_squeezed = k_pe.squeeze(1)
    flat_slot_mapping = slot_mapping.flatten()
    ops.concat_and_cache_mla(
        kv_c_normed,
        k_pe_squeezed,
        kv_cache,
        flat_slot_mapping,
        kv_cache_dtype=kv_cache_dtype,
        scale=k_scale,
    )
|
||||
|
||||
|
||||
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return whether ``kv_cache_dtype`` names a quantized KV cache.

    The check is purely textual: any dtype string beginning with
    ``"fp8"`` counts as quantized.
    """
    has_fp8_prefix = kv_cache_dtype.startswith("fp8")
    return has_fp8_prefix
|
||||
|
||||
Reference in New Issue
Block a user