[v1] Add cross-attention KV cache support for encoder-decoder models (#23664)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
@@ -11,6 +11,7 @@ from typing_extensions import Self
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.utils import cdiv, get_dtype_size
 
 logger = init_logger(__name__)
@@ -211,6 +212,20 @@ class EncoderOnlyAttentionSpec(AttentionSpec):
         return 0
 
 
+@dataclass(frozen=True)
+class CrossAttentionSpec(AttentionSpec):
+    """
+    KV cache spec for cross-attention layers in encoder-decoder models.
+    """
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        # For cross-attention, we need to cache encoder states
+        # Get encoder length (e.g., 1500 for Whisper).
+        max_encoder_len = MULTIMODAL_REGISTRY.\
+            get_encdec_max_encoder_len(vllm_config.model_config)
+        return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
+
+
 @dataclass
 class KVCacheTensor:
     """
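
To make the sizing concrete, here is a minimal back-of-the-envelope sketch of what max_memory_usage_bytes computes. The Whisper-like numbers (encoder length 1500, 20 KV heads, head size 64, fp16) are illustrative assumptions, as is the page_size_bytes layout of one K plane plus one V plane per block; the real values come from AttentionSpec and the model config.

from math import ceil

# All numbers below are illustrative assumptions, not values read from vLLM.
max_encoder_len = 1500   # Whisper emits 1500 encoder positions per 30s window
block_size = 16          # tokens covered by one KV cache block (assumed)
num_kv_heads = 20        # whisper-large-v3-style KV head count (assumed)
head_size = 64           # per-head dimension (assumed)
dtype_size = 2           # bytes per element for float16

# Assumed page layout: one block stores K and V for block_size tokens.
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size

# cdiv(a, b) is ceil(a / b): blocks needed to hold the full encoder output.
num_blocks = ceil(max_encoder_len / block_size)  # cdiv(1500, 16) == 94

print(num_blocks * page_size_bytes)  # 7700480 bytes, ~7.3 MiB per layer

The key point is that, unlike decoder self-attention, this cost is fixed by the encoder's maximum output length rather than growing with the decoded sequence length.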