move checks out of unified_kv_cache_update custom op (#33943)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-02-07 07:30:09 -06:00
committed by GitHub
parent ce9b3cd3e9
commit de3869bb4d
7 changed files with 79 additions and 100 deletions

View File

@@ -422,9 +422,15 @@ class Attention(nn.Module, AttentionLayerBase):
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size_v)
kv_cache_dummy_dep = None
if self.use_direct_call:
kv_cache_dummy_dep = None
if not self.attn_backend.forward_includes_kv_cache_update:
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = unified_kv_cache_update(
key, value, self.layer_name
)
@@ -437,10 +443,12 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
else:
kv_cache_dummy_dep = None
if not self.attn_backend.forward_includes_kv_cache_update and (
# torch can only dispatch custom op if a tensor is passed
key is not None or value is not None
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name

View File

@@ -136,6 +136,9 @@ def create_cross_attention_backend(
if (
not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None
and layer.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping