diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py
index 170de6a87..8c3ff3cc4 100644
--- a/vllm/model_executor/layers/attention/attention.py
+++ b/vllm/model_executor/layers/attention/attention.py
@@ -422,9 +422,15 @@ class Attention(nn.Module, AttentionLayerBase):
             key = key.view(-1, self.num_kv_heads, self.head_size)
         if value is not None:
             value = value.view(-1, self.num_kv_heads, self.head_size_v)
+        kv_cache_dummy_dep = None
         if self.use_direct_call:
-            kv_cache_dummy_dep = None
-            if not self.attn_backend.forward_includes_kv_cache_update:
+            # Skip this if sharing KV cache with an earlier attention layer.
+            if (
+                not self.attn_backend.forward_includes_kv_cache_update
+                and self.kv_sharing_target_layer_name is None
+                and key is not None
+                and value is not None
+            ):
                 kv_cache_dummy_dep = unified_kv_cache_update(
                     key, value, self.layer_name
                 )
@@ -437,10 +443,12 @@ class Attention(nn.Module, AttentionLayerBase):
                 kv_cache_dummy_dep=kv_cache_dummy_dep,
             )
         else:
-            kv_cache_dummy_dep = None
-            if not self.attn_backend.forward_includes_kv_cache_update and (
-                # torch can only dispatch custom op if a tensor is passed
-                key is not None or value is not None
+            # Skip this if sharing KV cache with an earlier attention layer.
+            if (
+                not self.attn_backend.forward_includes_kv_cache_update
+                and self.kv_sharing_target_layer_name is None
+                and key is not None
+                and value is not None
             ):
                 kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                     key, value, self.layer_name
diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py
index 6a829db26..9333b35e6 100644
--- a/vllm/model_executor/layers/attention/cross_attention.py
+++ b/vllm/model_executor/layers/attention/cross_attention.py
@@ -136,6 +136,9 @@ def create_cross_attention_backend(
             if (
                 not underlying_attn_backend.forward_includes_kv_cache_update
                 and attn_metadata is not None
+                and layer.kv_sharing_target_layer_name is None
+                and key is not None
+                and value is not None
             ):
                 self.do_kv_cache_update(
                     layer, key, value, kv_cache, attn_metadata.slot_mapping
diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py
index 4f2d4c07c..c43c00840 100644
--- a/vllm/model_executor/models/whisper_causal.py
+++ b/vllm/model_executor/models/whisper_causal.py
@@ -172,6 +172,9 @@ def create_whisper_attention_backend_with_block_pooling(
             if (
                 not underlying_attn_backend.forward_includes_kv_cache_update
                 and attn_metadata is not None
+                and layer.kv_sharing_target_layer_name is None
+                and key is not None
+                and value is not None
             ):
                 self.do_kv_cache_update(
                     layer, key, value, kv_cache, attn_metadata.slot_mapping
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 927572531..e786ab3bc 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -771,16 +771,6 @@ class FlashAttentionImpl(AttentionImpl):
             # we use direct Q, K, V tensors without caching
             return
 
-        # key and value may be None in the case of cross attention. They are
-        # calculated once based on the output from the encoder and then cached
-        # in KV cache.
-        if (
-            self.kv_sharing_target_layer_name is not None
-            or key is None
-            or value is None
-        ):
-            return
-
         key_cache, value_cache = kv_cache.unbind(0)
 
         # Reshape the input keys and values and store them in the cache.
diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
index 027012eb1..3d8a660c9 100644
--- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
@@ -196,23 +196,14 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
     ):
         key_cache, value_cache = kv_cache.unbind(0)
 
-        # key and value may be None in the case of cross attention. They are
-        # calculated once based on the output from the encoder and then cached
-        # in KV cache.
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
+        # Reshape the input keys and values and store them in the cache.
+        ops.reshape_and_cache_flash(
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slot_mapping,
+            self.kv_cache_dtype,
+            layer._k_scale,
+            layer._v_scale,
+        )
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index 366f08ccc..0b9889c13 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -383,45 +383,35 @@ class RocmAttentionImpl(AttentionImpl):
            kv_cache, self.num_kv_heads, self.head_size
        )
 
-        # key and value may be None in the case of cross attention. They are
-        # calculated once based on the output from the encoder and then cached
-        # in KV cache.
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
+        # Reshape the input keys and values and store them in the cache.
+        # Get the actual block_size from value_cache
+        # value_cache shape: [num_blocks, num_heads, head_size, block_size]
+        block_size = value_cache.shape[3]
+        # Determine if it is a power of 2
+        is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
 
-            # Get the actual block_size from value_cache
-            # value_cache shape: [num_blocks, num_heads, head_size, block_size]
-            block_size = value_cache.shape[3]
-            # Determine if it is a power of 2
-            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
-
-            if is_pow2:
-                # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
-                PagedAttention.write_to_paged_cache(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-            else:
-                # Case B: Non-standard blocks (e.g., 544 in Qwen3),
-                # force using our modified Triton logic
-                triton_reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
+        if is_pow2:
+            # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
+            PagedAttention.write_to_paged_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+        else:
+            # Case B: Non-standard blocks (e.g., 544 in Qwen3),
+            # force using our modified Triton logic
+            triton_reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 96429e29b..c0987dbe4 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -579,26 +579,20 @@ class TritonAttentionImpl(AttentionImpl):
             # For decoder and cross-attention, use KV cache as before
             key_cache, value_cache = kv_cache.unbind(1)
 
-            if (
-                self.kv_sharing_target_layer_name is None
-                and key is not None
-                and value is not None
-            ):
-                # Reshape the input keys and values and store them in the cache.
-                # Skip this if sharing KV cache with an earlier attention layer.
-                if self.kv_cache_dtype.startswith("fp8"):
-                    key_cache = key_cache.view(self.fp8_dtype)
-                    value_cache = value_cache.view(self.fp8_dtype)
-                # triton kernel does not support uint8 kv_cache
-                # (because some explicit casts (e.g. float8_e4m3fnuz)
-                # are not supported)
-                triton_reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
+            # Reshape the input keys and values and store them in the cache.
+            if self.kv_cache_dtype.startswith("fp8"):
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)
+            # triton kernel does not support uint8 kv_cache
+            # (because some explicit casts (e.g. float8_e4m3fnuz)
+            # are not supported)
+            triton_reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
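As a reviewer aid, the following is a minimal, self-contained sketch of the block-size dispatch that the rocm_attn.py hunk above now runs unconditionally: the block size is read from the last dimension of value_cache, a bit trick checks whether it is a power of two, and that decides between the native HIP paged-cache kernel and the Triton fallback. The helper choose_cache_writer and the strings it returns are illustrative only and are not part of the vLLM API.

```python
import torch


def choose_cache_writer(value_cache: torch.Tensor) -> str:
    """Illustrative helper (not vLLM API): pick the KV-cache write path.

    value_cache is laid out as [num_blocks, num_heads, head_size, block_size],
    matching the comment in the rocm_attn.py hunk above.
    """
    block_size = value_cache.shape[3]
    # A positive integer n is a power of two iff n & (n - 1) == 0.
    is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
    if is_pow2:
        # Standard block sizes (16, 32, 64, ...) use the native HIP C++ kernel.
        return "PagedAttention.write_to_paged_cache"
    # Non-standard block sizes (e.g. 544) fall back to the Triton kernel.
    return "triton_reshape_and_cache_flash"


# Example: a dummy cache with block_size=544 selects the Triton path.
dummy = torch.empty(1, 8, 128, 544)
print(choose_cache_writer(dummy))  # -> triton_reshape_and_cache_flash
```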