[ROCm] Enabling forward_includes_kv_cache on ROCm MHA backends (#33106)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Author: Gregory Shtrasberg
Date: 2026-01-28 00:36:14 -06:00
Committed by: GitHub
Parent: 36d450e3b8
Commit: 22ad649501
3 changed files with 130 additions and 90 deletions

File 1 of 3: RocmAiterUnifiedAttentionBackend / RocmAiterUnifiedAttentionImpl

@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
 )
-from vllm.v1.attention.backend import AttentionType
+from vllm.v1.attention.backend import AttentionLayer, AttentionType
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.rocm_attn import (
     RocmAttentionBackend,
@@ -24,6 +24,8 @@ logger = init_logger(__name__)
 class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     accept_output_buffer: bool = True
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "ROCM_AITER_UNIFIED_ATTN"
@@ -142,27 +144,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
         key_cache, value_cache = kv_cache.unbind(0)
 
-        # key and value may be None in the case of cross attention. They are
-        # calculated once based on the output from the encoder and then cached
-        # in KV cache.
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
-
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)
@@ -204,3 +185,34 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
         )
         return output
+
+    def do_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ):
+        key_cache, value_cache = kv_cache.unbind(0)
+
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
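
Note on the mechanics: with forward_includes_kv_cache_update left False, a backend signals that forward() no longer writes K/V into the paged cache, and the caller is expected to invoke do_kv_cache_update() first. The sketch below illustrates that split under stated assumptions: the stub classes are hypothetical, the layer argument and quantization scales are dropped, and a plain scatter stands in for ops.reshape_and_cache_flash.

    import torch


    class StubBackend:
        # False => the runner, not forward(), owns the cache write.
        forward_includes_kv_cache_update: bool = False


    class StubImpl:
        def do_kv_cache_update(self, key, value, kv_cache, slot_mapping):
            # Simplified stand-in for ops.reshape_and_cache_flash: scatter
            # each token's K/V row into its flat cache slot. unbind() returns
            # views of kv_cache, so these assignments write through.
            key_cache, value_cache = kv_cache.unbind(0)
            key_cache[slot_mapping] = key
            value_cache[slot_mapping] = value


    num_slots, head_dim = 16, 8
    kv_cache = torch.zeros(2, num_slots, head_dim)
    key, value = torch.randn(3, head_dim), torch.randn(3, head_dim)
    slot_mapping = torch.tensor([4, 5, 6])

    impl = StubImpl()
    if not StubBackend.forward_includes_kv_cache_update:
        # The runner performs the cache update before the attention kernel.
        impl.do_kv_cache_update(key, value, kv_cache, slot_mapping)
    assert torch.equal(kv_cache[0, 4:7], key)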

File 2 of 3: RocmAttentionBackend / RocmAttentionImpl

@@ -18,6 +18,7 @@ from vllm.v1.attention.backend import (
     AttentionBackend,
     AttentionCGSupport,
     AttentionImpl,
+    AttentionLayer,
     AttentionMetadataBuilder,
     AttentionType,
     CommonAttentionMetadata,
@@ -193,6 +194,8 @@ class RocmAttentionBackend(AttentionBackend):
         "FlexAttention backend which supports all head sizes."
     )
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "ROCM_ATTN"
@@ -330,49 +333,6 @@ class RocmAttentionImpl(AttentionImpl):
             kv_cache, self.num_kv_heads, self.head_size
         )
 
-        # key and value may be None in the case of cross attention. They are
-        # calculated once based on the output from the encoder and then cached
-        # in KV cache.
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            # Get the actual block_size from value_cache
-            # value_cache shape: [num_blocks, num_heads, head_size, block_size]
-            block_size = value_cache.shape[3]
-            # Determine if it is a power of 2
-            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
-            if is_pow2:
-                # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
-                PagedAttention.write_to_paged_cache(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-            else:
-                # Case B: Non-standard blocks (e.g., 544 in Qwen3),
-                # force using our modified Triton logic
-                triton_reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)
@@ -410,3 +370,58 @@ class RocmAttentionImpl(AttentionImpl):
         )
         return output
+
+    def do_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ):
+        key_cache, value_cache = PagedAttention.split_kv_cache(
+            kv_cache, self.num_kv_heads, self.head_size
+        )
+
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # Get the actual block_size from value_cache
+            # value_cache shape: [num_blocks, num_heads, head_size, block_size]
+            block_size = value_cache.shape[3]
+            # Determine if it is a power of 2
+            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
+            if is_pow2:
+                # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
+                PagedAttention.write_to_paged_cache(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
+            else:
+                # Case B: Non-standard blocks (e.g., 544 in Qwen3),
+                # force using our modified Triton logic
+                triton_reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
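
Aside on the power-of-two dispatch above: block_size & (block_size - 1) == 0 relies on a power of two having exactly one set bit; subtracting 1 clears that bit and sets all lower bits, so the AND is zero only for powers of two. A standalone check with the same logic:

    def is_pow2(block_size: int) -> bool:
        # A power of two has a single set bit; n & (n - 1) clears it.
        return block_size > 0 and (block_size & (block_size - 1)) == 0

    # 16/32/64 take the native HIP C++ path; 544 (= 2**5 * 17, the Qwen3
    # block size cited above) falls through to the modified Triton path.
    assert all(is_pow2(n) for n in (16, 32, 64))
    assert not is_pow2(544)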

File 3 of 3: TritonAttentionBackend / TritonAttentionImpl

@@ -21,6 +21,7 @@ from vllm.v1.attention.backend import (
     AttentionBackend,
     AttentionCGSupport,
     AttentionImpl,
+    AttentionLayer,
     AttentionMetadataBuilder,
     AttentionType,
     CommonAttentionMetadata,
@@ -271,6 +272,8 @@ class TritonAttentionBackend(AttentionBackend):
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
         return [MultipleOf(16)]
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "TRITON_ATTN"
@@ -461,31 +464,6 @@
             # For decoder and cross-attention, use KV cache as before
             key_cache, value_cache = kv_cache.unbind(1)
 
-            if (
-                self.kv_sharing_target_layer_name is None
-                and key is not None
-                and value is not None
-            ):
-                # Reshape the input keys and values and store them in the cache.
-                # Skip this if sharing KV cache with an earlier attention layer.
-                if self.kv_cache_dtype.startswith("fp8"):
-                    key_cache = key_cache.view(self.fp8_dtype)
-                    value_cache = value_cache.view(self.fp8_dtype)
-                # triton kernel does not support uint8 kv_cache
-                # (because some explicit casts (e.g. float8_e4m3fnuz)
-                # are not supported)
-                triton_reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-
             if self.kv_cache_dtype.startswith("fp8"):
                 if key_cache.dtype != self.fp8_dtype:
                     key_cache = key_cache.view(self.fp8_dtype)
@@ -585,3 +563,38 @@
             sliding_window_k=self.sliding_window[1],
         )
         return output
+
+    def do_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ):
+        # For decoder and cross-attention, use KV cache as before
+        key_cache, value_cache = kv_cache.unbind(1)
+
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            if self.kv_cache_dtype.startswith("fp8"):
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)
+            # triton kernel does not support uint8 kv_cache
+            # (because some explicit casts (e.g. float8_e4m3fnuz)
+            # are not supported)
+            triton_reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
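
A note on the view(self.fp8_dtype) pattern shared by all three backends: Tensor.view(dtype) reinterprets the cache's existing storage under a new element type without copying, which is how a uint8-allocated cache is exposed to kernels as fp8. A minimal sketch, assuming torch.float8_e4m3fnuz (the ROCm fp8 flavor named in the Triton comment above) as a plausible value of self.fp8_dtype, on a PyTorch build with fp8 dtypes:

    import torch

    raw = torch.zeros(4, dtype=torch.uint8)   # uint8-backed kv-cache storage
    fp8 = raw.view(torch.float8_e4m3fnuz)     # same bytes, fp8 element type
    assert fp8.data_ptr() == raw.data_ptr()   # no copy: both view one buffer
    assert fp8.shape == raw.shape             # both dtypes are 1 byte wide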