[V1] Support cross-layer KV sharing (#18212)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
@@ -59,8 +59,8 @@ from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

-from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
-                    scatter_mm_placeholders)
+from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
+                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)

 if TYPE_CHECKING:
     import xgrammar as xgr
@@ -276,6 +276,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                         pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()

+        # Layer pairings for cross-layer KV sharing.
+        # If an Attention layer `layer_name` is in the keys of this dict, it
+        # means this layer will perform attention using the keys and values
+        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+        self.shared_kv_cache_layers: dict[str, str] = {}
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         """
         Update the order of requests in the batch based on the attention
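As a rough illustration of the mapping this dict holds (the layer names below are hypothetical, not taken from any real model), a model in which two later decoder layers reuse the KV cache written by an earlier layer would be recorded as:

# Hypothetical layer names; the real keys are the model's Attention layer names.
shared_kv_cache_layers: dict[str, str] = {
    "decoder.layers.3.attn": "decoder.layers.2.attn",
    "decoder.layers.4.attn": "decoder.layers.2.attn",
}
# Layers 3 and 4 then attend over the keys/values stored by layer 2's KV cache.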
@@ -2097,6 +2103,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 # KV cache specs.
                 raise ValueError("Unknown KV cache spec type.")

+        # Set up `kv_cache_config` and `kv_caches` for models
+        # with cross-layer KV sharing.
+        if self.shared_kv_cache_layers:
+            initialize_kv_cache_for_kv_sharing(
+                self.shared_kv_cache_layers,
+                kv_cache_config.kv_cache_groups,
+                kv_caches,
+            )
+
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # validate all draft model layers belong to the same kv cache
@@ -2125,6 +2140,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         for layer_name, attn_module in layers.items():
+            if (kv_tgt_layer :=
+                    attn_module.kv_sharing_target_layer_name) is not None:
+                # The layer doesn't need its own KV cache and will use that of
+                # the target layer. We skip creating a KVCacheSpec for it, so
+                # that KV cache management logic will act as if this layer does
+                # not exist, and doesn't allocate KV cache for the layer. This
+                # enables the memory saving of cross-layer KV sharing, allowing
+                # a given amount of memory to accommodate longer context lengths
+                # or more requests to be processed simultaneously.
+                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                continue
+
             # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
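A minimal runnable sketch of how this branch fills `shared_kv_cache_layers`; the `SimpleNamespace` stand-ins and layer names are assumptions for illustration, modeling only the `kv_sharing_target_layer_name` attribute that the real Attention modules expose:

from types import SimpleNamespace

# Stand-ins for Attention modules; only the attribute read above is modeled.
layers = {
    "layers.0.attn": SimpleNamespace(kv_sharing_target_layer_name=None),
    "layers.1.attn": SimpleNamespace(kv_sharing_target_layer_name="layers.0.attn"),
}

shared_kv_cache_layers: dict[str, str] = {}
for layer_name, attn_module in layers.items():
    if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
        # No KVCacheSpec is created for this layer; it is only recorded here.
        shared_kv_cache_layers[layer_name] = kv_tgt_layer
        continue
    # ... a KVCacheSpec would be built for "layers.0.attn" at this point ...

assert shared_kv_cache_layers == {"layers.1.attn": "layers.0.attn"}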
@@ -44,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

-from .utils import sanity_check_mm_encoder_outputs
+from .utils import (initialize_kv_cache_for_kv_sharing,
+                    sanity_check_mm_encoder_outputs)

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -238,6 +239,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.num_reqs_paddings = _get_req_paddings(
             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

+        # Layer pairings for cross-layer KV sharing.
+        # If an Attention layer `layer_name` is in the keys of this dict, it
+        # means this layer will perform attention using the keys and values
+        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+        self.shared_kv_cache_layers: dict[str, str] = {}
+
         # tensors for structured decoding
         self.grammar_bitmask_cpu = torch.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
@@ -455,6 +462,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         block_size = self.vllm_config.cache_config.block_size
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         for layer_name, attn_module in layers.items():
+            if (kv_tgt_layer :=
+                    attn_module.kv_sharing_target_layer_name) is not None:
+                # The layer doesn't need its own KV cache and will use that of
+                # the target layer. We skip creating a KVCacheSpec for it, so
+                # that KV cache management logic will act as if this layer does
+                # not exist, and doesn't allocate KV cache for the layer. This
+                # enables the memory saving of cross-layer KV sharing, allowing
+                # a given amount of memory to accommodate longer context lengths
+                # or more requests to be processed simultaneously.
+                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                continue
+
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
@@ -1376,6 +1395,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             else:
                 raise NotImplementedError

+        # Set up `kv_cache_config` and `kv_caches` for models
+        # with cross-layer KV sharing.
+        if self.shared_kv_cache_layers:
+            initialize_kv_cache_for_kv_sharing(
+                self.shared_kv_cache_layers,
+                kv_cache_config.kv_cache_groups,
+                kv_caches,
+            )
+
         bind_kv_cache(
             kv_caches,
             self.vllm_config.compilation_config.static_forward_context,
@@ -4,6 +4,8 @@ from typing import Optional
 
 import torch
 
+from vllm.v1.kv_cache_interface import KVCacheGroupSpec
+
 
 def sanity_check_mm_encoder_outputs(
     mm_embeddings: object,
@@ -73,3 +75,37 @@ def gather_mm_placeholders(
         return placeholders

     return placeholders[is_embed]
+
+
+def initialize_kv_cache_for_kv_sharing(
+    shared_kv_cache_layers: dict[str, str],
+    kv_cache_groups: list[KVCacheGroupSpec],
+    kv_caches: dict[str, torch.Tensor],
+) -> None:
+    """
+    Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
+    for layers that do not allocate their own KV cache, based on the mapping in
+    `shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
+    group, which is needed to ensure that attention metadata is assigned later.
+
+    Args:
+        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
+            If an Attention layer `layer_name` is in the keys of this dict, it
+            means this layer will perform attention using the keys and values
+            from the KV cache of `shared_kv_cache_layers[layer_name]`.
+        kv_cache_groups: The KV cache groups of the model.
+        kv_caches: The allocated kv_caches with layer names as keys.
+            Note that layers in shared_kv_cache_layers.keys() are not
+            originally included, as it only contains layers that have their
+            own KV cache allocation.
+    """
+    # Record index of KV cache group for each layer that allocates a KV cache.
+    layer_to_kv_cache_group_idx: dict[str, int] = {}
+    for i, kv_cache_group in enumerate(kv_cache_groups):
+        for layer_name in kv_cache_group.layer_names:
+            layer_to_kv_cache_group_idx[layer_name] = i
+
+    for layer_name, target_layer_name in shared_kv_cache_layers.items():
+        kv_caches[layer_name] = kv_caches[target_layer_name]
+        group_idx = layer_to_kv_cache_group_idx[target_layer_name]
+        kv_cache_groups[group_idx].layer_names.append(layer_name)
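A self-contained usage sketch of `initialize_kv_cache_for_kv_sharing`, assuming the function above is in scope; `FakeKVCacheGroup` is a stand-in for `KVCacheGroupSpec` (which likewise exposes `layer_names`) so the example runs outside vLLM, and the tensor shape is a toy placeholder:

from dataclasses import dataclass, field

import torch


@dataclass
class FakeKVCacheGroup:
    # Stand-in for vllm.v1.kv_cache_interface.KVCacheGroupSpec.
    layer_names: list[str] = field(default_factory=list)


kv_caches = {"layers.0.attn": torch.zeros(2, 16, 8)}  # toy KV cache allocation
kv_cache_groups = [FakeKVCacheGroup(layer_names=["layers.0.attn"])]
shared_kv_cache_layers = {"layers.1.attn": "layers.0.attn"}

initialize_kv_cache_for_kv_sharing(
    shared_kv_cache_layers, kv_cache_groups, kv_caches)

# The sharing layer aliases the target layer's allocation (no extra memory)
# and joins the same KV cache group, so it is assigned attention metadata later.
assert kv_caches["layers.1.attn"] is kv_caches["layers.0.attn"]
assert kv_cache_groups[0].layer_names == ["layers.0.attn", "layers.1.attn"]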