[ROCm] Enabling forward_includes_kv_cache on ROCm MHA backends (#33106)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Commit: 22ad649501 (parent: 36d450e3b8)
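In short: the three ROCm MHA backends touched below (ROCM_AITER_UNIFIED_ATTN, ROCM_ATTN, TRITON_ATTN) now declare `forward_includes_kv_cache_update: bool = False` and move their KV-cache write out of `forward()` into a new `do_kv_cache_update()` method. The calling side is not part of this diff; the sketch below only illustrates, as an assumption, how a caller could dispatch on the flag (the function name and argument list here are hypothetical):

```python
# Hypothetical caller-side dispatch, assuming the flag semantics visible in this
# diff: when forward_includes_kv_cache_update is False, the KV-cache write is
# issued as a separate step and forward() only reads the already-updated cache.
def run_attention(backend, impl, layer, query, key, value, kv_cache, attn_metadata):
    if not backend.forward_includes_kv_cache_update:
        # Write K/V into the paged cache explicitly (signature matches the
        # do_kv_cache_update() added in this diff).
        impl.do_kv_cache_update(
            layer, key, value, kv_cache, attn_metadata.slot_mapping
        )
    # Schematic forward call; the real signature is not shown in this diff.
    return impl.forward(layer, query, key, value, kv_cache, attn_metadata)
```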
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
 )
-from vllm.v1.attention.backend import AttentionType
+from vllm.v1.attention.backend import AttentionLayer, AttentionType
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.rocm_attn import (
     RocmAttentionBackend,

@@ -24,6 +24,8 @@ logger = init_logger(__name__)
 class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     accept_output_buffer: bool = True
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "ROCM_AITER_UNIFIED_ATTN"

@@ -142,27 +144,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
 
         key_cache, value_cache = kv_cache.unbind(0)
 
-        # key and value may be None in the case of cross attention. They are
-        # calculated once based on the output from the encoder and then cached
-        # in KV cache.
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
-
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)

@@ -204,3 +185,34 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
         )
 
         return output
+
+    def do_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ):
+        key_cache, value_cache = kv_cache.unbind(0)
+
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )

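For reference, `kv_cache.unbind(0)` above returns views rather than copies, so `do_kv_cache_update()` writes straight into the backing cache tensor. A minimal standalone illustration (the shapes here are placeholders, not the backend's actual cache layout):

```python
import torch

# Placeholder layout: K and V stacked along dim 0 of one cache tensor
# (assumed [2, num_blocks, block_size, num_kv_heads, head_size]).
kv_cache = torch.zeros(2, 4, 16, 8, 64)
key_cache, value_cache = kv_cache.unbind(0)   # two views sharing storage with kv_cache

key_cache[0, 0, 0, 0] = 1.0                   # writing the view updates the original tensor
assert kv_cache[0, 0, 0, 0, 0].item() == 1.0
```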
@@ -18,6 +18,7 @@ from vllm.v1.attention.backend import (
     AttentionBackend,
     AttentionCGSupport,
     AttentionImpl,
+    AttentionLayer,
     AttentionMetadataBuilder,
     AttentionType,
     CommonAttentionMetadata,

@@ -193,6 +194,8 @@ class RocmAttentionBackend(AttentionBackend):
                 "FlexAttention backend which supports all head sizes."
             )
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "ROCM_ATTN"

@@ -330,49 +333,6 @@ class RocmAttentionImpl(AttentionImpl):
             kv_cache, self.num_kv_heads, self.head_size
         )
 
-        # key and value may be None in the case of cross attention. They are
-        # calculated once based on the output from the encoder and then cached
-        # in KV cache.
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-
-            # Get the actual block_size from value_cache
-            # value_cache shape: [num_blocks, num_heads, head_size, block_size]
-            block_size = value_cache.shape[3]
-            # Determine if it is a power of 2
-            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
-
-            if is_pow2:
-                # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
-                PagedAttention.write_to_paged_cache(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-            else:
-                # Case B: Non-standard blocks (e.g., 544 in Qwen3),
-                # force using our modified Triton logic
-                triton_reshape_and_cache_flash(
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    attn_metadata.slot_mapping,
-                    self.kv_cache_dtype,
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)

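The block-size dispatch being moved out of `forward()` here relies on the standard bit trick for powers of two: a positive integer n is a power of two iff `n & (n - 1) == 0`. A quick standalone check (the values are just examples; 544 is the non-standard block size called out in the comment above):

```python
def is_power_of_two(n: int) -> bool:
    # n & (n - 1) clears the lowest set bit; the result is 0 only when n has a
    # single set bit, i.e. n is a power of two.
    return n > 0 and (n & (n - 1)) == 0

assert all(is_power_of_two(n) for n in (16, 32, 64, 128))
assert not is_power_of_two(544)   # 544 = 512 + 32 -> falls back to the Triton path
```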
@@ -410,3 +370,58 @@ class RocmAttentionImpl(AttentionImpl):
         )
 
         return output
+
+    def do_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ):
+        key_cache, value_cache = PagedAttention.split_kv_cache(
+            kv_cache, self.num_kv_heads, self.head_size
+        )
+
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+
+            # Get the actual block_size from value_cache
+            # value_cache shape: [num_blocks, num_heads, head_size, block_size]
+            block_size = value_cache.shape[3]
+            # Determine if it is a power of 2
+            is_pow2 = block_size > 0 and (block_size & (block_size - 1) == 0)
+
+            if is_pow2:
+                # Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
+                PagedAttention.write_to_paged_cache(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
+            else:
+                # Case B: Non-standard blocks (e.g., 544 in Qwen3),
+                # force using our modified Triton logic
+                triton_reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )

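Note that `slot_mapping` is now passed in explicitly instead of being read from `attn_metadata`. In vLLM's paged KV cache it identifies, per token, the flat cache slot to write. A rough sketch of how such a mapping is typically derived (the block-table layout below is an assumption for illustration, not taken from this diff):

```python
import torch

def build_slot_mapping(block_table: torch.Tensor, positions: torch.Tensor, block_size: int) -> torch.Tensor:
    # Assumed convention: slot = physical_block_number * block_size + offset_in_block.
    block_idx = positions // block_size
    block_offset = positions % block_size
    return block_table[block_idx] * block_size + block_offset

block_table = torch.tensor([7, 3, 9])         # logical block -> physical block (example values)
positions = torch.tensor([0, 1, 16, 17, 33])  # token positions within the sequence
print(build_slot_mapping(block_table, positions, block_size=16))
# tensor([112, 113,  48,  49, 145])
```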
@@ -21,6 +21,7 @@ from vllm.v1.attention.backend import (
     AttentionBackend,
     AttentionCGSupport,
     AttentionImpl,
+    AttentionLayer,
     AttentionMetadataBuilder,
     AttentionType,
     CommonAttentionMetadata,

@@ -271,6 +272,8 @@ class TritonAttentionBackend(AttentionBackend):
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
         return [MultipleOf(16)]
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "TRITON_ATTN"

@@ -461,31 +464,6 @@ class TritonAttentionImpl(AttentionImpl):
 
         # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(1)
 
-        if (
-            self.kv_sharing_target_layer_name is None
-            and key is not None
-            and value is not None
-        ):
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            if self.kv_cache_dtype.startswith("fp8"):
-                key_cache = key_cache.view(self.fp8_dtype)
-                value_cache = value_cache.view(self.fp8_dtype)
-                # triton kernel does not support uint8 kv_cache
-                # (because some explicit casts (e.g. float8_e4m3fnuz)
-                # are not supported)
-            triton_reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
-
         if self.kv_cache_dtype.startswith("fp8"):
             if key_cache.dtype != self.fp8_dtype:
                 key_cache = key_cache.view(self.fp8_dtype)

@@ -585,3 +563,38 @@ class TritonAttentionImpl(AttentionImpl):
             sliding_window_k=self.sliding_window[1],
         )
         return output
+
+    def do_kv_cache_update(
+        self,
+        layer: AttentionLayer,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ):
+        # For decoder and cross-attention, use KV cache as before
+        key_cache, value_cache = kv_cache.unbind(1)
+
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            if self.kv_cache_dtype.startswith("fp8"):
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)
+                # triton kernel does not support uint8 kv_cache
+                # (because some explicit casts (e.g. float8_e4m3fnuz)
+                # are not supported)
+            triton_reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )

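On the fp8 path above, `.view(self.fp8_dtype)` reinterprets the byte-backed cache as an 8-bit float dtype without copying, which is what lets the Triton kernel avoid the unsupported uint8 casts mentioned in the comment. A standalone illustration (using `torch.float8_e4m3fn`; the ROCm backends use the `fnuz` variant as noted above):

```python
import torch

raw = torch.randint(0, 256, (8,), dtype=torch.uint8)   # cache allocated as raw bytes
fp8 = raw.view(torch.float8_e4m3fn)                     # reinterpret in place: same storage, same 1-byte itemsize

assert fp8.dtype == torch.float8_e4m3fn
assert fp8.data_ptr() == raw.data_ptr()                 # no copy was made
```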