diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
index 9589c3128..027012eb1 100644
--- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
 )
-from vllm.v1.attention.backend import AttentionType
+from vllm.v1.attention.backend import AttentionLayer, AttentionType
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.rocm_attn import (
     RocmAttentionBackend,
@@ -24,6 +24,8 @@ logger = init_logger(__name__)
 class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     accept_output_buffer: bool = True
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "ROCM_AITER_UNIFIED_ATTN"
@@ -142,27 +144,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
 
         key_cache, value_cache = kv_cache.unbind(0)
 
-        # key and value may be None in the case of cross attention. They are
-        # calculated once based on the output from the encoder and then cached
-        # in KV cache.
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
-
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)
@@ -204,3 +185,34 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
         )
 
         return output
+
+    def do_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ):
+        key_cache, value_cache = kv_cache.unbind(0)
+
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index f033ad146..366f08ccc 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -18,6 +18,7 @@ from vllm.v1.attention.backend import (
     AttentionBackend,
     AttentionCGSupport,
     AttentionImpl,
+    AttentionLayer,
     AttentionMetadataBuilder,
     AttentionType,
     CommonAttentionMetadata,
@@ -193,6 +194,8 @@ class RocmAttentionBackend(AttentionBackend):
             "FlexAttention backend which supports all head sizes."
         )
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "ROCM_ATTN"
@@ -330,49 +333,6 @@ class RocmAttentionImpl(AttentionImpl):
             kv_cache, self.num_kv_heads, self.head_size
         )
 
-        # key and value may be None in the case of cross attention. They are
-        # calculated once based on the output from the encoder and then cached
-        # in KV cache.
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-
-            # Get the actual block_size from value_cache
-            # value_cache shape: [num_blocks, num_heads, head_size, block_size]
-            block_size = value_cache.shape[3]
-            # Determine if it is a power of 2
-            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
-
-            if is_pow2:
-                # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
-                PagedAttention.write_to_paged_cache(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-            else:
-                # Case B: Non-standard blocks (e.g., 544 in Qwen3),
-                # force using our modified Triton logic
-                triton_reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)
@@ -410,3 +370,58 @@ class RocmAttentionImpl(AttentionImpl):
         )
 
         return output
+
+    def do_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ):
+        key_cache, value_cache = PagedAttention.split_kv_cache(
+            kv_cache, self.num_kv_heads, self.head_size
+        )
+
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+
+            # Get the actual block_size from value_cache
+            # value_cache shape: [num_blocks, num_heads, head_size, block_size]
+            block_size = value_cache.shape[3]
+            # Determine if it is a power of 2
+            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
+
+            if is_pow2:
+                # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
+                PagedAttention.write_to_paged_cache(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
+            else:
+                # Case B: Non-standard blocks (e.g., 544 in Qwen3),
+                # force using our modified Triton logic
+                triton_reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index a38b553d8..d091a4e96 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -21,6 +21,7 @@ from vllm.v1.attention.backend import (
     AttentionBackend,
     AttentionCGSupport,
     AttentionImpl,
+    AttentionLayer,
     AttentionMetadataBuilder,
     AttentionType,
     CommonAttentionMetadata,
@@ -271,6 +272,8 @@ class TritonAttentionBackend(AttentionBackend):
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
         return [MultipleOf(16)]
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "TRITON_ATTN"
@@ -461,31 +464,6 @@ class TritonAttentionImpl(AttentionImpl):
 
         # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(1)
-
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            if self.kv_cache_dtype.startswith("fp8"):
-                key_cache = key_cache.view(self.fp8_dtype)
-                value_cache = value_cache.view(self.fp8_dtype)
-            # triton kernel does not support uint8 kv_cache
-            # (because some explicit casts (e.g. float8_e4m3fnuz)
-            # are not supported)
-            triton_reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
-
         if self.kv_cache_dtype.startswith("fp8"):
             if key_cache.dtype != self.fp8_dtype:
                 key_cache = key_cache.view(self.fp8_dtype)
@@ -585,3 +563,38 @@ class TritonAttentionImpl(AttentionImpl):
             sliding_window_k=self.sliding_window[1],
         )
         return output
+
+    def do_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ):
+        # For decoder and cross-attention, use KV cache as before
+        key_cache, value_cache = kv_cache.unbind(1)
+
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            if self.kv_cache_dtype.startswith("fp8"):
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)
+            # triton kernel does not support uint8 kv_cache
+            # (because some explicit casts (e.g. float8_e4m3fnuz)
+            # are not supported)
+            triton_reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
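Usage note (not part of the patch): these backends now set forward_includes_kv_cache_update to False and expose do_kv_cache_update(), which suggests the KV-cache write is meant to be issued as a separate step before forward(). The sketch below is a minimal illustration of that split under stated assumptions: only do_kv_cache_update, forward_includes_kv_cache_update, and the backend/impl classes come from the diff above; the helper name write_kv_then_attend and the forward() argument order are assumptions for illustration, not code from this patch.

# Minimal sketch of the runner-side flow implied by
# forward_includes_kv_cache_update = False. Hypothetical helper; not vLLM code.
import torch


def write_kv_then_attend(
    backend,           # e.g. TritonAttentionBackend (carries the new flag)
    impl,              # the matching *AttentionImpl instance from the patch
    layer,             # attention layer object providing _k_scale / _v_scale
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,     # per-step metadata carrying a slot_mapping tensor
    output: torch.Tensor,
):
    # forward() no longer writes the KV cache for these backends, so the
    # cache update has to be issued explicitly beforehand.
    if not backend.forward_includes_kv_cache_update:
        impl.do_kv_cache_update(
            layer,
            key,
            value,
            kv_cache,
            attn_metadata.slot_mapping,
        )
    # Attention then only reads the already-updated cache (the argument
    # order shown here is assumed, not taken from the diff).
    return impl.forward(
        layer, query, key, value, kv_cache, attn_metadata, output=output
    )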