[v1] Add cross-attention KV cache support for encoder-decoder models (#23664)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
@@ -372,3 +372,22 @@ class MultiModalRegistry:
         )
 
         return dummy_data
+
+    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
+        """
+        Get the maximum length of the encoder input for encoder-decoder models.
+        """
+        if not model_config.is_encoder_decoder:
+            return 0
+        max_tokens = self.\
+            get_max_tokens_per_item_by_nonzero_modality(model_config)
+        if not max_tokens:
+            # TODO - this function assumes encoder-decoder models are
+            # multimodal. This will need to change when adding support for more
+            # than whisper.
+            return 0
+        assert len(max_tokens) == 1, "Encoder-decoder models are expected \
+            to implement the multimodal interface with at most one modality."
+
+        first_modality = next(iter(max_tokens))
+        return max_tokens[first_modality]
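
Illustration (not part of the commit): a minimal standalone sketch of what the new registry helper computes, with max_tokens_by_modality standing in for the registry lookup. For Whisper, the single audio modality reports a padded encoder length of 1500 positions (30 s of audio at 50 encoder frames per second):

    def get_encdec_max_encoder_len(is_encoder_decoder: bool,
                                   max_tokens_by_modality: dict[str, int]) -> int:
        # Mirrors MultiModalRegistry.get_encdec_max_encoder_len above:
        # decoder-only and non-multimodal models report 0; otherwise the
        # single modality's per-item token budget is the encoder length.
        if not is_encoder_decoder or not max_tokens_by_modality:
            return 0
        assert len(max_tokens_by_modality) == 1
        return next(iter(max_tokens_by_modality.values()))

    assert get_encdec_max_encoder_len(True, {"audio": 1500}) == 1500
    assert get_encdec_max_encoder_len(False, {}) == 0
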
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
@@ -6,7 +6,7 @@ from typing import Optional
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.core.single_type_kv_cache_manager import (
-    FullAttentionManager, get_manager_for_kv_cache_spec)
+    CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.request import Request
@@ -42,9 +42,10 @@ class KVCacheCoordinator(ABC):
             ) for i, kv_cache_group in enumerate(
                 self.kv_cache_config.kv_cache_groups))
 
-    def get_num_blocks_to_allocate(
-            self, request_id: str, num_tokens: int,
-            new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
+    def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
+                                   new_computed_blocks: tuple[
+                                       list[KVCacheBlock], ...],
+                                   num_encoder_tokens: int) -> int:
         """
         Get the number of blocks needed to be allocated for the request.
 
@@ -54,12 +55,20 @@ class KVCacheCoordinator(ABC):
                 tokens that are already allocated).
             new_computed_blocks: The new computed blocks just hitting the
                 prefix caching.
+            num_encoder_tokens: The number of encoder tokens for allocating
+                blocks for cross-attention.
 
         Returns:
             The number of blocks.
         """
         num_blocks_to_allocate = 0
         for i, manager in enumerate(self.single_type_managers):
-            num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
-                request_id, num_tokens, new_computed_blocks[i])
+            if isinstance(manager, CrossAttentionManager):
+                # For cross-attention, we issue a single static allocation
+                # of blocks based on the number of encoder input tokens.
+                num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
+                    request_id, num_encoder_tokens, [])
+            else:
+                num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
+                    request_id, num_tokens, new_computed_blocks[i])
         return num_blocks_to_allocate
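
Illustration (not part of the commit): the cross-attention branch above is a one-shot, fixed-size allocation with no prefix-cache credit (note the empty list passed in place of new_computed_blocks). The arithmetic, assuming the Whisper-like numbers used elsewhere in this commit (1500 encoder tokens, 16-token blocks):

    def cross_attention_blocks_needed(num_encoder_tokens: int,
                                      block_size: int) -> int:
        # Ceiling division, as in vllm.utils.cdiv: every encoder token needs
        # a KV slot, and no computed blocks are reused between requests.
        return -(-num_encoder_tokens // block_size)

    assert cross_attention_blocks_needed(1500, 16) == 94
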
@@ -79,8 +88,11 @@ class KVCacheCoordinator(ABC):
             manager.save_new_computed_blocks(request_id,
                                              new_computed_blocks[i])
 
-    def allocate_new_blocks(self, request_id: str,
-                            num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
+    def allocate_new_blocks(
+            self,
+            request_id: str,
+            num_tokens: int,
+            num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]:
         """
         Allocate new blocks for the request to give it at least `num_tokens`
         token slots.
@@ -89,12 +101,16 @@ class KVCacheCoordinator(ABC):
             request_id: The request ID.
             num_tokens: The total number of tokens that need a slot (including
                 tokens that are already allocated).
+            num_encoder_tokens: The number of encoder tokens for allocating
+                blocks for cross-attention.
 
         Returns:
             The new allocated blocks.
         """
         return tuple(
-            manager.allocate_new_blocks(request_id, num_tokens)
+            manager.allocate_new_blocks(
+                request_id, num_encoder_tokens if isinstance(
+                    manager, CrossAttentionManager) else num_tokens)
             for manager in self.single_type_managers)
 
     def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
@@ -187,6 +187,7 @@ class KVCacheManager:
         new_computed_blocks: Optional[KVCacheBlocks] = None,
         num_lookahead_tokens: int = 0,
         delay_cache_blocks: bool = False,
+        num_encoder_tokens: int = 0,
     ) -> Optional[KVCacheBlocks]:
         """Add slots for a request with new tokens to append.
 
@@ -253,6 +254,7 @@ class KVCacheManager:
             request_id=request.request_id,
             num_tokens=num_tokens_need_slot,
             new_computed_blocks=new_computed_block_list,
+            num_encoder_tokens=num_encoder_tokens,
         )
 
         if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
@@ -273,7 +275,7 @@ class KVCacheManager:
                                                   new_computed_block_list)
 
         new_blocks = self.coordinator.allocate_new_blocks(
-            request.request_id, num_tokens_need_slot)
+            request.request_id, num_tokens_need_slot, num_encoder_tokens)
 
         # P/D: delay caching blocks if we have to recv from
         # remote. Update state for locally cached blocks.
@@ -292,7 +294,7 @@ class KVCacheManager:
 
     def free(self, request: Request) -> None:
         """Free the blocks allocated for the request.
-        We free the blocks in reverse order so that he tail blocks are evicted
+        We free the blocks in reverse order so that the tail blocks are evicted
         first when caching is enabled.
 
         Args:
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
@@ -58,6 +58,7 @@ class Scheduler(SchedulerInterface):
         self.parallel_config = vllm_config.parallel_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
+        self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder
 
         # include_finished_set controls whether a separate set of finished
         # request ids should be included in the EngineCoreOutputs returned
@@ -83,6 +84,9 @@ class Scheduler(SchedulerInterface):
             assert len(self.kv_cache_config.kv_cache_groups) == 1, (
                 "Multiple KV cache groups are not currently supported "
                 "with KV connectors")
+            assert not self.is_encoder_decoder, (
+                "Encoder-decoder models are not currently supported "
+                "with KV connectors")
             self.connector = KVConnectorFactory.create_connector(
                 config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
 
@@ -431,6 +435,22 @@ class Scheduler(SchedulerInterface):
                                           == 0 else
                                           self.num_lookahead_tokens)
 
+            # Determine if we need to allocate cross-attention blocks.
+            if self.is_encoder_decoder and request.has_encoder_inputs:
+                # TODO(russellb): For Whisper, we know that the input is
+                # always padded to the maximum length. If we support other
+                # encoder-decoder models, this will need to be updated if we
+                # want to only allocate what is needed.
+                assert ("whisper"
+                        in self.vllm_config.model_config.model.lower()), (
+                            "Whisper is the only supported "
+                            "encoder-decoder model.")
+                num_encoder_tokens = MULTIMODAL_REGISTRY.\
+                    get_encdec_max_encoder_len(
+                        self.vllm_config.model_config)
+            else:
+                num_encoder_tokens = 0
+
             new_blocks = self.kv_cache_manager.allocate_slots(
                 request,
                 num_new_tokens + num_external_computed_tokens,
@@ -438,6 +458,7 @@ class Scheduler(SchedulerInterface):
                 new_computed_blocks,
                 num_lookahead_tokens=effective_lookahead_tokens,
                 delay_cache_blocks=load_kv_async,
+                num_encoder_tokens=num_encoder_tokens,
             )
 
             if new_blocks is None:
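
Illustration (not part of the commit): together, the two scheduler hunks above reduce to one decision that is threaded through allocate_slots. A sketch of that decision, with a hypothetical helper name:

    def pick_num_encoder_tokens(is_encoder_decoder: bool,
                                has_encoder_inputs: bool,
                                max_encoder_len: int) -> int:
        # Whisper pads every encoder input to the maximum length, so the
        # scheduler can size the one-shot cross-attention allocation up
        # front instead of inspecting the actual audio length.
        if is_encoder_decoder and has_encoder_inputs:
            return max_encoder_len
        return 0

    assert pick_num_encoder_tokens(True, True, 1500) == 1500
    assert pick_num_encoder_tokens(True, False, 1500) == 0
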
@@ -703,7 +724,21 @@ class Scheduler(SchedulerInterface):
                 # The encoder input is not needed in this step.
                 break
 
-            if start_pos + num_encoder_tokens <= num_computed_tokens:
+            if self.is_encoder_decoder and num_computed_tokens > 0:
+                assert start_pos == 0, (
+                    "Encoder input should be processed at the beginning of "
+                    "the sequence when encoder-decoder models are used.")
+                # Encoder input has already been computed
+                # The calculation here is a bit different. We don't turn encoder
+                # output into tokens that get processed by the decoder and
+                # reflected in num_computed_tokens. Instead, start_pos reflects
+                # the position where we need to ensure we calculate encoder
+                # inputs. This should always be 0 to ensure we calculate encoder
+                # inputs before running the decoder. Once we've calculated some
+                # decoder tokens (num_computed_tokens > 0), then we know we
+                # already calculated encoder inputs and can skip here.
+                continue
+            elif start_pos + num_encoder_tokens <= num_computed_tokens:
                 # The encoder input is already computed and stored
                 # in the decoder's KV cache.
                 continue
|||||||
@@ -8,8 +8,9 @@ from vllm.utils import cdiv
|
|||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||||
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||||
FullAttentionSpec, KVCacheSpec,
|
CrossAttentionSpec, FullAttentionSpec,
|
||||||
MambaSpec, SlidingWindowSpec)
|
KVCacheSpec, MambaSpec,
|
||||||
|
SlidingWindowSpec)
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
|
||||||
@@ -552,11 +553,62 @@ class MambaManager(SingleTypeKVCacheManager):
         return new_blocks
 
 
+class CrossAttentionManager(SingleTypeKVCacheManager):
+    """Manager for cross-attention KV cache in encoder-decoder models."""
+
+    def save_new_computed_blocks(
+            self, request_id: str,
+            new_computed_blocks: list[KVCacheBlock]) -> None:
+        # We do not cache blocks for cross-attention to be shared between
+        # requests, so `new_computed_blocks` should always be empty.
+        assert len(new_computed_blocks) == 0
+
+    def cache_blocks(self, request: Request, num_tokens: int) -> None:
+        # We do not cache blocks for cross-attention to be shared between
+        # requests, so this method is not relevant.
+        raise ValueError("Should not be called as prefix caching is disabled.")
+
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> int:
+        # Cross-attention blocks contain request-specific encoder states
+        # and are not shared between different requests.
+        return 0
+
+    @classmethod
+    def find_longest_cache_hit(
+        cls,
+        block_hashes: list[BlockHash],
+        max_length: int,
+        kv_cache_group_ids: list[int],
+        block_pool: BlockPool,
+        kv_cache_spec: KVCacheSpec,
+        use_eagle: bool,
+    ) -> tuple[list[KVCacheBlock], ...]:
+        assert isinstance(kv_cache_spec, CrossAttentionSpec), (
+            "CrossAttentionManager can only be used for cross-attention groups"
+        )
+        # Cross-attention does not benefit from prefix caching since:
+        # 1. Encoder states are unique per request (different audio/image
+        #    inputs)
+        # 2. Encoder states are computed once per request, not incrementally
+        # 3. No reusable prefix exists between different multimodal inputs
+        # Prefix-cache lookups are therefore never performed for this manager.
+        raise NotImplementedError(
+            "CrossAttentionManager does not support caching")
+
+    def remove_skipped_blocks(self, request_id: str,
+                              num_computed_tokens: int) -> None:
+        # Cross-attention blocks represent encoder states which are needed
+        # for the entire decoding process, so no blocks should be skipped.
+        pass
+
+
 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
     SlidingWindowSpec: SlidingWindowManager,
     ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
     MambaSpec: MambaManager,
+    CrossAttentionSpec: CrossAttentionManager,
 }
 
 
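
Illustration (not part of the commit): the new manager is wired in solely through spec_manager_map; get_manager_for_kv_cache_spec resolves the manager class from the concrete spec type. A toy sketch of that dispatch, with hypothetical stub classes:

    class Spec: ...
    class CrossSpec(Spec): ...
    class Manager: ...
    class CrossManager(Manager): ...

    spec_manager_map = {CrossSpec: CrossManager}

    def manager_for(spec: Spec) -> Manager:
        # Registering CrossAttentionSpec in the map above is all the
        # coordinator needs to start routing cross-attention groups.
        return spec_manager_map[type(spec)]()

    assert isinstance(manager_for(CrossSpec()), CrossManager)
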
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
@@ -11,6 +11,7 @@ from typing_extensions import Self
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.utils import cdiv, get_dtype_size
 
 logger = init_logger(__name__)
@@ -211,6 +212,20 @@ class EncoderOnlyAttentionSpec(AttentionSpec):
         return 0
 
 
+@dataclass(frozen=True)
+class CrossAttentionSpec(AttentionSpec):
+    """
+    KV cache spec for cross-attention layers in encoder-decoder models.
+    """
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        # For cross-attention, we need to cache encoder states.
+        # Get encoder length (e.g., 1500 for Whisper).
+        max_encoder_len = MULTIMODAL_REGISTRY.\
+            get_encdec_max_encoder_len(vllm_config.model_config)
+        return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
+
+
 @dataclass
 class KVCacheTensor:
     """
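
Illustration (not part of the commit): a worked instance of max_memory_usage_bytes under assumed Whisper-large-like dimensions (20 KV heads, head size 64, fp16, 16-token blocks), with page_size_bytes following the usual K-plus-V page layout of an AttentionSpec:

    max_encoder_len = 1500        # padded encoder positions (Whisper)
    block_size = 16               # tokens per KV cache block
    num_kv_heads, head_size = 20, 64
    bytes_per_elem = 2            # fp16
    # One block's page holds K and V for block_size tokens of one layer.
    page_size_bytes = 2 * block_size * num_kv_heads * head_size * bytes_per_elem
    num_blocks = -(-max_encoder_len // block_size)           # cdiv -> 94
    max_memory_usage_bytes = num_blocks * page_size_bytes    # 7,700,480 B
    assert (page_size_bytes, num_blocks) == (81920, 94)      # ~7.3 MiB/layer
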