Fix IndexError with encoder-decoder models when using Custom Paged Attention (#33112)
Signed-off-by: sstamenk <strahinja.stamenkovic@amd.com>
commit c568581ff3 (parent 2d7053438a), committed via GitHub
@@ -330,7 +330,14 @@ class RocmAttentionImpl(AttentionImpl):
                 kv_cache, self.num_kv_heads, self.head_size
             )
 
-            if self.kv_sharing_target_layer_name is None:
+            # key and value may be None in the case of cross attention. They are
+            # calculated once based on the output from the encoder and then cached
+            # in KV cache.
+            if (
+                self.kv_sharing_target_layer_name is None
+                and key is not None
+                and value is not None
+            ):
                 # Reshape the input keys and values and store them in the cache.
                 # Skip this if sharing KV cache with an earlier attention layer.
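
For context on the guard above: in an encoder-decoder model, the cross-attention K/V are computed once from the encoder output and written to the KV cache; on subsequent decode steps the layer receives key=None and value=None and reads only from the cache. A minimal standalone sketch of that pattern (hypothetical helper for illustration, not vLLM's API):

    import torch

    def write_cross_attn_kv(kv_cache, key, value, kv_sharing_target_layer_name=None):
        # Mirrors the patched guard: only write K/V into the cache when they
        # are actually provided (i.e. on the first cross-attention pass).
        if (kv_sharing_target_layer_name is None
                and key is not None
                and value is not None):
            kv_cache[0].copy_(key)
            kv_cache[1].copy_(value)

    kv_cache = torch.zeros(2, 4, 8)            # [K/V, num_tokens, head_size]
    write_cross_attn_kv(kv_cache, torch.randn(4, 8), torch.randn(4, 8))  # encoder pass: cached
    write_cross_attn_kv(kv_cache, None, None)  # later decode steps: no-op
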
@@ -382,8 +389,8 @@ class RocmAttentionImpl(AttentionImpl):
         # Compute attention and update output up to `num_actual_tokens`.
         chunked_prefill_paged_decode(
             query=query[:num_actual_tokens],
-            key=key[:num_actual_tokens],
-            value=value[:num_actual_tokens],
+            key=key[:num_actual_tokens] if key is not None else None,
+            value=value[:num_actual_tokens] if value is not None else None,
             output=output[:num_actual_tokens],
             kv_cache_dtype=self.kv_cache_dtype,
             key_cache=key_cache,
@@ -302,8 +302,9 @@ def chunked_prefill_paged_decode(
     block_size = value_cache.shape[3]
     num_seqs = len(seq_lens)
     num_query_heads = query.shape[1]
-    num_kv_heads = key.shape[1]
-    num_queries_per_kv = query.shape[1] // key.shape[1]
+    # key may be None in cross-attention decode (already cached from encoder)
+    num_kv_heads = key.shape[1] if key is not None else key_cache.shape[1]
+    num_queries_per_kv = num_query_heads // num_kv_heads
     head_size = query.shape[2]
 
     # Conversion of FP8 Tensor from uint8 storage to
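
The helper-side half of the fix falls back to the cache layout for the KV-head count when key is absent. A short sketch, assuming the paged key_cache layout used by this kernel, [num_blocks, num_kv_heads, head_size // x, block_size, x], so that key_cache.shape[1] is the KV-head dimension (the concrete shape values below are illustrative only):

    import torch

    num_blocks, num_kv_heads, head_size, block_size, x = 32, 4, 64, 16, 8
    key_cache = torch.empty(num_blocks, num_kv_heads, head_size // x, block_size, x)

    query = torch.empty(10, 16, head_size)  # [num_tokens, num_query_heads, head_size]
    key = None                              # cross-attention decode: K is already cached

    # Mirrors the patch: take the KV-head count from the cache rather than
    # from a key tensor that may be None.
    n_kv_heads = key.shape[1] if key is not None else key_cache.shape[1]
    num_queries_per_kv = query.shape[1] // n_kv_heads
    print(n_kv_heads, num_queries_per_kv)   # -> 4 4
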