[v1] Add cross-attention KV cache support for encoder-decoder models (#23664)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant
2025-08-26 14:49:06 -04:00
committed by GitHub
parent 227e231b55
commit 98aa16ff41
6 changed files with 153 additions and 14 deletions

View File

@@ -11,6 +11,7 @@ from typing_extensions import Self
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv, get_dtype_size
logger = init_logger(__name__)
@@ -211,6 +212,20 @@ class EncoderOnlyAttentionSpec(AttentionSpec):
return 0
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
"""
KV cache spec for cross-attention layers in encoder-decoder models.
"""
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
# For cross-attention, we need to cache encoder states
# Get encoder length (e.g., 1500 for Whisper).
max_encoder_len = MULTIMODAL_REGISTRY.\
get_encdec_max_encoder_len(vllm_config.model_config)
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
@dataclass
class KVCacheTensor:
"""