[Performance] Extract kv update ops from MLA attention backends (#34627)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Di Wu <dw2761@nyu.edu>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -1007,6 +1007,7 @@ class CompilationConfig:
             # https://github.com/vllm-project/vllm/issues/33267
             if not self.use_inductor_graph_partition:
                 self.splitting_ops.append("vllm::unified_kv_cache_update")
+                self.splitting_ops.append("vllm::unified_mla_kv_cache_update")
         elif len(self.splitting_ops) == 0:
             if (
@@ -434,7 +434,19 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             if isinstance(attn_metadata, dict):
                 attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+            slot_mapping = forward_context.slot_mapping

+            assert isinstance(slot_mapping, dict), (
+                f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
+            )
+            self.impl.do_kv_cache_update(
+                kv_c_normed,
+                k_pe,
+                self_kv_cache,
+                slot_mapping.get(self.layer_name),
+                self.kv_cache_dtype,
+                self._k_scale,
+            )
             if self.attn_backend.accept_output_buffer:
                 output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 self.forward_impl(
@@ -451,6 +463,13 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                 )
         else:
+            kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
+                kv_c_normed,
+                k_pe,
+                self.layer_name,
+                self.kv_cache_dtype,
+                self._k_scale,
+            )
             if self.attn_backend.accept_output_buffer:
                 output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 torch.ops.vllm.unified_mla_attention_with_output(
@@ -459,6 +478,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     k_pe,
                     output,
                     self.layer_name,
+                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                 )
                 return output
             else:
@@ -467,6 +487,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                     kv_c_normed,
                     k_pe,
                     self.layer_name,
+                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                 )

     def forward_impl(
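
Taken together, the forward() changes above create two symmetric paths: in the eager branch the layer writes the KV cache directly via self.impl.do_kv_cache_update(...), while in the compiled branch the write goes through the new torch.ops.vllm.unified_mla_kv_cache_update op, whose dummy return value is threaded into the attention ops as kv_cache_dummy_dep. A standalone sketch of this ordering trick follows the unified_mla_attention hunk below.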
@@ -520,17 +541,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         k_c_normed = k_c_normed[:num_actual_toks, ...]
         k_pe = k_pe[:num_actual_toks, ...]

-        # write the latent and rope to kv cache
-        if kv_cache.numel() > 0:
-            ops.concat_and_cache_mla(
-                k_c_normed,
-                k_pe.squeeze(1),
-                kv_cache,
-                attn_metadata.slot_mapping.flatten(),
-                kv_cache_dtype=self.kv_cache_dtype,
-                scale=self._k_scale,
-            )
-
         if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla":
             kv_cache = kv_cache.view(current_platform.fp8_dtype())
@@ -827,7 +837,12 @@ def unified_mla_attention(
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> torch.Tensor:
+    # kv_cache_dummy_dep is not used but accepting it creates a data dependency
+    # that ensures torch.compile preserves ordering between KV cache update and
+    # attention forward.
+    del kv_cache_dummy_dep
     attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
     output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
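
The comment in the hunk above describes the core trick of this change. Below is a minimal, self-contained sketch of the pattern, not vLLM code: the demo::* op names, the shapes, and the plain torch.library.custom_op registration (PyTorch >= 2.4) are all illustrative assumptions.

import torch
from typing import Optional

# Side-effectful cache write: mutates kv_cache in place and returns a
# zero-element dummy tensor whose only job is to carry a data dependency.
@torch.library.custom_op("demo::kv_update", mutates_args=("kv_cache",))
def kv_update(kv_cache: torch.Tensor, new_kv: torch.Tensor) -> torch.Tensor:
    kv_cache.copy_(new_kv)
    return new_kv.new_empty(0)

@kv_update.register_fake
def _(kv_cache, new_kv):
    return new_kv.new_empty(0)

# Consumer: accepting the dummy records "attend reads something kv_update
# produced" in the captured graph, so the write cannot be reordered past it.
@torch.library.custom_op("demo::attend", mutates_args=())
def attend(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    dummy_dep: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return q @ kv_cache.T

@attend.register_fake
def _(q, kv_cache, dummy_dep=None):
    return q.new_empty(q.shape[0], kv_cache.shape[0])

def step(q, kv_cache, new_kv):
    dep = kv_update(kv_cache, new_kv)
    # Dropping `dep` here would leave the compiler free to reorder the
    # cache write relative to the read inside attend.
    return attend(q, kv_cache, dummy_dep=dep)

out = torch.compile(step)(torch.randn(2, 8), torch.zeros(4, 8), torch.randn(4, 8))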
@@ -839,6 +854,7 @@ def unified_mla_attention_fake(
     kv_c_normed: torch.Tensor,
     k_pe: torch.Tensor,
     layer_name: str,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> torch.Tensor:
     return torch.empty_like(q).contiguous()
@@ -852,6 +868,56 @@ direct_register_custom_op(
 )


+def unified_mla_kv_cache_update(
+    kv_c_normed: torch.Tensor,
+    k_pe: torch.Tensor,
+    layer_name: str,
+    kv_cache_dtype: str,
+    k_scale: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Returns a dummy that is passed to unified_attention to signal a side effect and
+    the data dependency between them to ensure torch.compile preserves ordering.
+    """
+    forward_context = get_forward_context()
+    attn_layer = forward_context.no_compile_layers[layer_name]
+    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
+
+    slot_mapping = forward_context.slot_mapping
+    assert isinstance(slot_mapping, dict), (
+        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
+    )
+    layer_slot_mapping = slot_mapping.get(layer_name)
+    if layer_slot_mapping is not None:
+        attn_layer.impl.do_kv_cache_update(
+            kv_c_normed,
+            k_pe,
+            kv_cache,
+            layer_slot_mapping,
+            kv_cache_dtype,
+            k_scale,
+        )
+
+    return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
+
+
+def unified_mla_kv_cache_update_fake(
+    kv_c_normed: torch.Tensor,
+    k_pe: torch.Tensor,
+    layer_name: str,
+    kv_cache_dtype: str,
+    k_scale: torch.Tensor,
+) -> torch.Tensor:
+    return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
+
+
+direct_register_custom_op(
+    op_name="unified_mla_kv_cache_update",
+    op_func=unified_mla_kv_cache_update,
+    fake_impl=unified_mla_kv_cache_update_fake,
+)
+
+
 @maybe_transfer_kv_layer
 def unified_mla_attention_with_output(
     q: torch.Tensor,
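
Note the fake (meta) implementation above: during torch.compile tracing, unified_mla_kv_cache_update_fake performs no cache traffic and simply returns a zero-element tensor on the right device and dtype, giving the compiler a well-typed placeholder for the dependency edge.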
@@ -861,7 +927,12 @@ def unified_mla_attention_with_output(
     layer_name: str,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> None:
+    # kv_cache_dummy_dep is not used but accepting it creates a data dependency
+    # that ensures torch.compile preserves ordering between KV cache update and
+    # attention forward.
+    del kv_cache_dummy_dep
     attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
     layer.forward_impl(
         q,
@@ -883,6 +954,7 @@ def unified_mla_attention_with_output_fake(
     layer_name: str,
     output_scale: torch.Tensor | None = None,
     output_block_scale: torch.Tensor | None = None,
+    kv_cache_dummy_dep: torch.Tensor | None = None,
 ) -> None:
     return
@@ -811,6 +811,28 @@ class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
         """MQA-style decode forward pass."""
         raise NotImplementedError

+    def do_kv_cache_update(
+        self,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: torch.Tensor,
+    ) -> None:
+        if kv_cache.numel() == 0:
+            return
+        from vllm import _custom_ops as ops
+
+        ops.concat_and_cache_mla(
+            kv_c_normed,
+            k_pe.squeeze(1),
+            kv_cache,
+            slot_mapping.flatten(),
+            kv_cache_dtype=kv_cache_dtype,
+            scale=k_scale,
+        )
+

 class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
     """Sparse MLA attention implementation with only forward_mqa method.
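
The default do_kv_cache_update delegates to the fused ops.concat_and_cache_mla kernel. As rough orientation, a pure-PyTorch sketch of what that kernel is expected to do for the unquantized cache dtype might look as follows; the shapes, the flat cache view, and the negative-slot padding convention are assumptions for illustration, and the real kernel additionally handles fp8 quantization with k_scale.

import torch

def concat_and_cache_mla_ref(
    kv_c: torch.Tensor,          # [num_tokens, kv_lora_rank] latent part
    k_pe: torch.Tensor,          # [num_tokens, rope_dim] rotary part
    kv_cache: torch.Tensor,      # [num_slots, kv_lora_rank + rope_dim]
    slot_mapping: torch.Tensor,  # [num_tokens]; negative entries mark padding
) -> None:
    valid = slot_mapping >= 0                     # skip padded tokens
    slots = slot_mapping[valid].long()
    entries = torch.cat((kv_c[valid], k_pe[valid]), dim=-1)
    kv_cache[slots] = entries.to(kv_cache.dtype)  # scatter rows into the cache

Hoisting this write into do_kv_cache_update lets MLAAttentionImpl and SparseMLAAttentionImpl share one implementation, and lets the compiled path reach it through the custom op above.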
@@ -856,6 +878,28 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
         """MQA-style decode forward pass."""
         raise NotImplementedError

+    def do_kv_cache_update(
+        self,
+        kv_c_normed: torch.Tensor,
+        k_pe: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: torch.Tensor,
+    ) -> None:
+        if kv_cache.numel() == 0:
+            return
+        from vllm import _custom_ops as ops
+
+        ops.concat_and_cache_mla(
+            kv_c_normed,
+            k_pe.squeeze(1),
+            kv_cache,
+            slot_mapping.flatten(),
+            kv_cache_dtype=kv_cache_dtype,
+            scale=k_scale,
+        )
+

 def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
     return kv_cache_dtype.startswith("fp8")