[ROCm] Enabling forward_includes_kv_cache on ROCm MHA backends (#33106)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
36d450e3b8
commit
22ad649501
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.v1.attention.backend import AttentionType
|
||||
from vllm.v1.attention.backend import AttentionLayer, AttentionType
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.rocm_attn import (
|
||||
RocmAttentionBackend,
|
||||
@@ -24,6 +24,8 @@ logger = init_logger(__name__)
|
||||
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_UNIFIED_ATTN"
|
||||
@@ -142,27 +144,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
@@ -204,3 +185,34 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
@@ -193,6 +194,8 @@ class RocmAttentionBackend(AttentionBackend):
|
||||
"FlexAttention backend which supports all head sizes."
|
||||
)
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_ATTN"
|
||||
@@ -330,49 +333,6 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
|
||||
# Get the actual block_size from value_cache
|
||||
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
# Determine if it is a power of 2
|
||||
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
|
||||
|
||||
if is_pow2:
|
||||
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
|
||||
# force using our modified Triton logic
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
@@ -410,3 +370,58 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
|
||||
# Get the actual block_size from value_cache
|
||||
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
# Determine if it is a power of 2
|
||||
is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
|
||||
|
||||
if is_pow2:
|
||||
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
|
||||
# force using our modified Triton logic
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
@@ -271,6 +272,8 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_ATTN"
|
||||
@@ -461,31 +464,6 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
if key_cache.dtype != self.fp8_dtype:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
@@ -585,3 +563,38 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
sliding_window_k=self.sliding_window[1],
|
||||
)
|
||||
return output
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user