[ROCm] Enabling forward_includes_kv_cache on ROCm MHA backends (#33106)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
36d450e3b8
commit
22ad649501
@@ -21,6 +21,7 @@ from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
CommonAttentionMetadata,
|
||||
@@ -271,6 +272,8 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_ATTN"
|
||||
@@ -461,31 +464,6 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
if key_cache.dtype != self.fp8_dtype:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
@@ -585,3 +563,38 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
sliding_window_k=self.sliding_window[1],
|
||||
)
|
||||
return output
|
||||
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
# For decoder and cross-attention, use KV cache as before
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user