[V1] Move more control of kv cache initialization from model_executor to EngineCore (#11960)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Chen Zhang
2025-01-17 15:39:35 +08:00
committed by GitHub
parent 8027a72461
commit 69d765f5a5
12 changed files with 514 additions and 103 deletions

View File

@@ -1,13 +1,20 @@
import multiprocessing
import os
import weakref
from collections import defaultdict
from collections.abc import Sequence
from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
Union, overload)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List,
Optional, TypeVar, Union, overload)
import torch
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.utils import get_mp_context, kill_process_tree
if TYPE_CHECKING:
from vllm.attention.layer import Attention
logger = init_logger(__name__)
T = TypeVar("T")
@@ -134,3 +141,48 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
socket_file = ipc_socket.replace("ipc://", "")
if os and os.path.exists(socket_file):
os.remove(socket_file)
def bind_kv_cache(
kv_caches: Dict[str, torch.Tensor],
forward_context: Dict[str, "Attention"],
runner_kv_caches: List[torch.Tensor],
) -> None:
"""
Bind the allocated KV cache to both ModelRunner and forward context so
that the KV cache can be used in the forward pass.
This function:
1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
kv_caches.
2) Associates each attention layer in the `forward_context` with its
corresponding KV cache in kv_caches.
Args:
kv_caches: The allocated kv_caches with layer names as keys.
forward_context: The global forward context containing all Attention
layers with layer names as keys.
runner_kv_caches: The kv_cache declared by ModelRunner.
"""
# Bind kv_caches to ModelRunner
assert len(runner_kv_caches) == 0
# Convert kv_caches dict to a list of tensors in the order of layer_index.
index2name = defaultdict(list)
for layer_name in kv_caches:
index2name[extract_layer_index(layer_name)].append(layer_name)
for layer_index in sorted(index2name.keys()):
layer_names = index2name[layer_index]
if len(layer_names) > 1:
# One typical case is encoder-decoder model, e.g., bart.
# The cross attention and self attention in the same decoder layer
# has different layer_name but the same layer_index.
raise NotImplementedError
layer_name = layer_names[0]
runner_kv_caches.append(kv_caches[layer_name])
# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]