[V1] Move more control of kv cache initialization from model_executor to EngineCore (#11960)
Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
@@ -1,13 +1,20 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
|
||||
Union, overload)
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List,
|
||||
Optional, TypeVar, Union, overload)
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.utils import get_mp_context, kill_process_tree
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -134,3 +141,48 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
|
||||
socket_file = ipc_socket.replace("ipc://", "")
|
||||
if os and os.path.exists(socket_file):
|
||||
os.remove(socket_file)
|
||||
|
||||
|
||||
def bind_kv_cache(
|
||||
kv_caches: Dict[str, torch.Tensor],
|
||||
forward_context: Dict[str, "Attention"],
|
||||
runner_kv_caches: List[torch.Tensor],
|
||||
) -> None:
|
||||
"""
|
||||
Bind the allocated KV cache to both ModelRunner and forward context so
|
||||
that the KV cache can be used in the forward pass.
|
||||
|
||||
This function:
|
||||
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
|
||||
kv_caches.
|
||||
2) Associates each attention layer in the `forward_context` with its
|
||||
corresponding KV cache in kv_caches.
|
||||
|
||||
Args:
|
||||
kv_caches: The allocated kv_caches with layer names as keys.
|
||||
forward_context: The global forward context containing all Attention
|
||||
layers with layer names as keys.
|
||||
runner_kv_caches: The kv_cache declared by ModelRunner.
|
||||
"""
|
||||
# Bind kv_caches to ModelRunner
|
||||
assert len(runner_kv_caches) == 0
|
||||
|
||||
# Convert kv_caches dict to a list of tensors in the order of layer_index.
|
||||
index2name = defaultdict(list)
|
||||
for layer_name in kv_caches:
|
||||
index2name[extract_layer_index(layer_name)].append(layer_name)
|
||||
|
||||
for layer_index in sorted(index2name.keys()):
|
||||
layer_names = index2name[layer_index]
|
||||
if len(layer_names) > 1:
|
||||
# One typical case is encoder-decoder model, e.g., bart.
|
||||
# The cross attention and self attention in the same decoder layer
|
||||
# has different layer_name but the same layer_index.
|
||||
raise NotImplementedError
|
||||
layer_name = layer_names[0]
|
||||
runner_kv_caches.append(kv_caches[layer_name])
|
||||
|
||||
# Bind kv_caches to forward context
|
||||
for layer_name, kv_cache in kv_caches.items():
|
||||
# NOTE: Use list because of v0 PP virtual engine.
|
||||
forward_context[layer_name].kv_cache = [kv_cache]
|
||||
|
||||
Reference in New Issue
Block a user