From c568581ff38af13924627fdd5c2cd9d9282fbe3d Mon Sep 17 00:00:00 2001 From: Strahinja Stamenkovic Date: Tue, 27 Jan 2026 03:33:37 +0100 Subject: [PATCH] Fix IndexError with encoder-decoder models when using Custom Paged Attention (#33112) Signed-off-by: sstamenk --- vllm/v1/attention/backends/rocm_attn.py | 13 ++++++++++--- .../attention/ops/chunked_prefill_paged_decode.py | 5 +++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 73747aaed..f033ad146 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -330,7 +330,14 @@ class RocmAttentionImpl(AttentionImpl): kv_cache, self.num_kv_heads, self.head_size ) - if self.kv_sharing_target_layer_name is None: + # key and value may be None in the case of cross attention. They are + # calculated once based on the output from the encoder and then cached + # in KV cache. + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. @@ -382,8 +389,8 @@ class RocmAttentionImpl(AttentionImpl): # Compute attention and update output up to `num_actual_tokens`. chunked_prefill_paged_decode( query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], + key=key[:num_actual_tokens] if key is not None else None, + value=value[:num_actual_tokens] if value is not None else None, output=output[:num_actual_tokens], kv_cache_dtype=self.kv_cache_dtype, key_cache=key_cache, diff --git a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py index c8b25d387..2dbd8755b 100644 --- a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @@ -302,8 +302,9 @@ def chunked_prefill_paged_decode( block_size = value_cache.shape[3] num_seqs = len(seq_lens) num_query_heads = query.shape[1] - num_kv_heads = key.shape[1] - num_queries_per_kv = query.shape[1] // key.shape[1] + # key may be None in cross-attention decode (already cached from encoder) + num_kv_heads = key.shape[1] if key is not None else key_cache.shape[1] + num_queries_per_kv = num_query_heads // num_kv_heads head_size = query.shape[2] # Conversion of FP8 Tensor from uint8 storage to