[V1] Move more control of kv cache initialization from model_executor to EngineCore (#11960)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Chen Zhang
2025-01-17 15:39:35 +08:00
committed by GitHub
parent 8027a72461
commit 69d765f5a5
12 changed files with 514 additions and 103 deletions

View File

@@ -7,6 +7,8 @@ import torch
import torch.distributed
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
@@ -16,14 +18,16 @@ from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, bind_kv_cache, cdiv,
is_pin_memory_available)
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
@@ -856,15 +860,71 @@ class GPUModelRunner:
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
def initialize_kv_cache(self, num_blocks: int) -> None:
assert len(self.kv_caches) == 0
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
for _ in range(self.num_attn_layers):
self.kv_caches.append(
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
kv_caches: Dict[str, torch.Tensor] = {}
for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size)
dtype = layer_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
else:
raise NotImplementedError
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
[self.kv_caches])
self.kv_caches)
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
forward_ctx = self.vllm_config.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: KVCacheSpec = {}
for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec

View File

@@ -1,7 +1,7 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed
@@ -16,6 +16,7 @@ from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@@ -112,20 +113,18 @@ class Worker:
self.model_runner.load_model()
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
Then, it calculate the free memory that can be used for KV cache in
bytes.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
@@ -161,33 +160,14 @@ class Worker:
total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
cache_block_size = _get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
return num_gpu_blocks, 0
return int(available_kv_cache_memory)
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks."""
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec()
max_seq_len = self.cache_config.block_size * num_gpu_blocks
max_model_len = self.model_config.max_model_len
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
self.model_runner.initialize_kv_cache(num_gpu_blocks)
def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
self.model_runner.initialize_kv_cache(kv_cache_config)
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager: