Fix IndexError with encoder-decoder models when using Custom Paged Attention (#33112)

Signed-off-by: sstamenk <strahinja.stamenkovic@amd.com>
Authored by Strahinja Stamenkovic on 2026-01-27 03:33:37 +01:00
committed by GitHub
parent 2d7053438a
commit c568581ff3
2 changed files with 13 additions and 5 deletions
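
Background for the change, paraphrasing the comments introduced in the diff below: with encoder-decoder models, a cross-attention layer computes its key/value projections once from the encoder output and stores them in the KV cache, so on later decode steps the ROCm attention backend receives key=None and value=None; the code previously sliced and indexed those arguments unconditionally. A minimal, self-contained sketch of that call pattern (the ToyCrossAttention class and the shapes are illustrative only, not vLLM's API):

import torch

class ToyCrossAttention:
    """Illustration only: caches K/V from the encoder pass, then passes None."""

    def __init__(self) -> None:
        self.cached_kv: tuple[torch.Tensor, torch.Tensor] | None = None

    def forward(self, query: torch.Tensor, encoder_out: torch.Tensor | None):
        if encoder_out is not None:
            # Prefill: project the encoder output (stand-in for W_k / W_v) and cache it.
            key, value = encoder_out.clone(), encoder_out.clone()
            self.cached_kv = (key, value)
        else:
            # Decode steps: the projections are already cached, so nothing new is
            # passed down; this is the key=None / value=None case the backend
            # must tolerate.
            key, value = None, None
        return query, key, value

layer = ToyCrossAttention()
q = torch.randn(1, 8, 64)          # [tokens, heads, head_size]
enc = torch.randn(5, 8, 64)
_, k, v = layer.forward(q, enc)    # encoder pass: real tensors
_, k, v = layer.forward(q, None)   # decode step: key/value are None
print(k is None and v is None)     # True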


@@ -330,7 +330,14 @@ class RocmAttentionImpl(AttentionImpl):
             kv_cache, self.num_kv_heads, self.head_size
         )
-        if self.kv_sharing_target_layer_name is None:
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
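
The guard above skips the cache write only when both key and value are absent; regular self-attention, where both are present, still takes the original path. A short standalone sketch of the same check (write_kv and the toy tensors are hypothetical names, not vLLM's kernel):

import torch

def write_kv(key, value, key_cache, value_cache, kv_sharing_target_layer_name=None):
    # Skip the write when sharing KV with an earlier layer, or when this is a
    # cross-attention decode step whose key/value were never recomputed.
    if (
        kv_sharing_target_layer_name is None
        and key is not None
        and value is not None
    ):
        key_cache.copy_(key)
        value_cache.copy_(value)
        return True
    return False

key_cache, value_cache = torch.zeros(4, 16), torch.zeros(4, 16)
assert write_kv(torch.ones(4, 16), torch.ones(4, 16), key_cache, value_cache)  # self-attention
assert not write_kv(None, None, key_cache, value_cache)                        # cross-attn decode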
@@ -382,8 +389,8 @@ class RocmAttentionImpl(AttentionImpl):
         # Compute attention and update output up to `num_actual_tokens`.
         chunked_prefill_paged_decode(
             query=query[:num_actual_tokens],
-            key=key[:num_actual_tokens],
-            value=value[:num_actual_tokens],
+            key=key[:num_actual_tokens] if key is not None else None,
+            value=value[:num_actual_tokens] if value is not None else None,
             output=output[:num_actual_tokens],
             kv_cache_dtype=self.kv_cache_dtype,
             key_cache=key_cache,


@@ -302,8 +302,9 @@ def chunked_prefill_paged_decode(
     block_size = value_cache.shape[3]
     num_seqs = len(seq_lens)
     num_query_heads = query.shape[1]
-    num_kv_heads = key.shape[1]
-    num_queries_per_kv = query.shape[1] // key.shape[1]
+    # key may be None in cross-attention decode (already cached from encoder)
+    num_kv_heads = key.shape[1] if key is not None else key_cache.shape[1]
+    num_queries_per_kv = num_query_heads // num_kv_heads
     head_size = query.shape[2]
     # Conversion of FP8 Tensor from uint8 storage to
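
The fallback works because in the paged KV-cache layout the KV-head count sits on dimension 1 of key_cache, the same layout from which the function already reads block_size as value_cache.shape[3]. A quick check of the derivation with dummy shapes (the 5-D layout below, including the x split of head_size, is assumed here purely for illustration):

import torch

num_blocks, num_kv_heads, head_size, block_size, x = 2, 4, 64, 16, 8
key_cache = torch.zeros(num_blocks, num_kv_heads, head_size // x, block_size, x)
query = torch.zeros(10, 32, head_size)  # [num_tokens, num_query_heads, head_size]
key = None                              # cross-attention decode: no fresh keys this step

num_kv_heads_derived = key.shape[1] if key is not None else key_cache.shape[1]
num_queries_per_kv = query.shape[1] // num_kv_heads_derived
print(num_kv_heads_derived, num_queries_per_kv)  # 4 8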