[ROCm] Support for Whisper v1 with Aiter Unified Attention and Aiter Flash Attention (#28376)
Signed-off-by: apinge <Tong.Qiu2@amd.com>
This commit is contained in:
@@ -517,12 +517,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
|
||||
raise NotImplementedError(
|
||||
"Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl"
|
||||
"Encoder self-attention is not implemented for FlashAttentionImpl"
|
||||
)
|
||||
|
||||
def extend_forward(
|
||||
@@ -678,7 +675,14 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping
|
||||
@@ -704,8 +708,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
|
||||
# decode:extend:prefill
|
||||
query = query[:num_actual_tokens]
|
||||
key = key[:num_actual_tokens]
|
||||
value = value[:num_actual_tokens]
|
||||
if key is not None:
|
||||
key = key[:num_actual_tokens]
|
||||
if value is not None:
|
||||
value = value[:num_actual_tokens]
|
||||
|
||||
output_actual_tokens = output[:num_actual_tokens]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user