move checks out of unified_kv_cache_update custom op (#33943)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -422,9 +422,15 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size_v)
|
||||
kv_cache_dummy_dep = None
|
||||
if self.use_direct_call:
|
||||
kv_cache_dummy_dep = None
|
||||
if not self.attn_backend.forward_includes_kv_cache_update:
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if (
|
||||
not self.attn_backend.forward_includes_kv_cache_update
|
||||
and self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
kv_cache_dummy_dep = unified_kv_cache_update(
|
||||
key, value, self.layer_name
|
||||
)
|
||||
@@ -437,10 +443,12 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
else:
|
||||
kv_cache_dummy_dep = None
|
||||
if not self.attn_backend.forward_includes_kv_cache_update and (
|
||||
# torch can only dispatch custom op if a tensor is passed
|
||||
key is not None or value is not None
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if (
|
||||
not self.attn_backend.forward_includes_kv_cache_update
|
||||
and self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
|
||||
key, value, self.layer_name
|
||||
|
||||
@@ -136,6 +136,9 @@ def create_cross_attention_backend(
|
||||
if (
|
||||
not underlying_attn_backend.forward_includes_kv_cache_update
|
||||
and attn_metadata is not None
|
||||
and layer.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
self.do_kv_cache_update(
|
||||
layer, key, value, kv_cache, attn_metadata.slot_mapping
|
||||
|
||||
@@ -172,6 +172,9 @@ def create_whisper_attention_backend_with_block_pooling(
|
||||
if (
|
||||
not underlying_attn_backend.forward_includes_kv_cache_update
|
||||
and attn_metadata is not None
|
||||
and layer.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
self.do_kv_cache_update(
|
||||
layer, key, value, kv_cache, attn_metadata.slot_mapping
|
||||
|
||||
@@ -771,16 +771,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is not None
|
||||
or key is None
|
||||
or value is None
|
||||
):
|
||||
return
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
|
||||
@@ -196,23 +196,14 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
):
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
@@ -383,45 +383,35 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Get the actual block_size from value_cache
|
||||
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
# Determine if it is a power of 2
|
||||
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
|
||||
|
||||
# Get the actual block_size from value_cache
|
||||
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
# Determine if it is a power of 2
|
||||
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
|
||||
|
||||
if is_pow2:
|
||||
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
|
||||
# force using our modified Triton logic
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
if is_pow2:
|
||||
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
|
||||
# force using our modified Triton logic
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
@@ -579,26 +579,20 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user