[Core] Support fully transparent sleep mode (#11743)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2025-01-22 14:39:32 +08:00
Committed by: GitHub
Parent: 4004f144f3
Commit: 68ad4e3a8d
14 changed files with 877 additions and 40 deletions


@@ -9,12 +9,14 @@
 import torch.nn as nn
 
 import vllm.envs as envs
 from vllm.config import ParallelConfig, VllmConfig
+from vllm.device_allocator.cumem import CuMemAllocator
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.utils import GiB_bytes
 from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
@@ -77,6 +79,23 @@ class Worker:
         else:
             self.profiler = None
 
+    def sleep(self, level: int = 1) -> None:
+        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
+        allocator = CuMemAllocator.get_instance()
+        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
+        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
+        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
+        used_bytes = total - free_bytes_after_sleep
+        assert freed_bytes >= 0, "Memory usage increased after sleeping."
+        logger.info(
+            "Sleep mode freed %.2f GiB memory, "
+            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
+            used_bytes / GiB_bytes)
+
+    def wake_up(self) -> None:
+        allocator = CuMemAllocator.get_instance()
+        allocator.wake_up()
+
     def init_device(self):
         if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until
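
The hunk above is the whole worker-side surface of sleep mode: sleep() releases GPU memory through the allocator and logs how much was freed, while wake_up() re-creates it. For readers unfamiliar with the allocator API, the following standalone sketch (not part of the commit) illustrates the tag semantics the diff relies on; the tensor sizes are arbitrary and the comments reflect the level-1 behavior visible in sleep() above:

import torch

from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

# Allocations made inside the pool are tracked under the given tag.
with allocator.use_memory_pool(tag="weights"):
    weights = torch.randn(1024, 1024, device="cuda")

with allocator.use_memory_pool(tag="kv_cache"):
    kv_cache = torch.zeros(1024, 1024, device="cuda")

# Tags named in offload_tags are copied to host memory before their
# device memory is released; other tracked allocations (here the
# "kv_cache" tensor) are discarded outright, matching level 1 above.
allocator.sleep(offload_tags=("weights", ))

# wake_up() re-allocates device memory at the same virtual addresses,
# so existing tensor objects stay usable without reloading the model;
# offloaded tags get their contents copied back.
allocator.wake_up()

assert weights.device.type == "cuda"  # weights are live again after wake-up
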
@@ -110,7 +129,17 @@ class Worker:
         self.model_runner = GPUModelRunner(self.vllm_config, self.device)
 
     def load_model(self) -> None:
-        self.model_runner.load_model()
+        if self.vllm_config.model_config.enable_sleep_mode:
+            allocator = CuMemAllocator.get_instance()
+            assert allocator.get_current_usage() == 0, (
+                "Sleep mode can only be "
+                "used for one instance per process.")
+            context = allocator.use_memory_pool(tag="weights")
+        else:
+            from contextlib import nullcontext
+            context = nullcontext()
+        with context:
+            self.model_runner.load_model()
 
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
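
load_model() here and initialize_cache() in the next hunk gate the same enable-checked pattern around two different tags. A hypothetical refactoring, shown only to make the pattern explicit (maybe_memory_pool is an illustrative name, not something this commit adds):

from contextlib import nullcontext

from vllm.device_allocator.cumem import CuMemAllocator


def maybe_memory_pool(vllm_config, tag: str):
    """Route allocations into the CuMemAllocator pool under `tag` when
    sleep mode is enabled; otherwise return a no-op context manager."""
    if vllm_config.model_config.enable_sleep_mode:
        return CuMemAllocator.get_instance().use_memory_pool(tag=tag)
    return nullcontext()

With such a helper, both call sites would collapse to a single with statement, e.g. with maybe_memory_pool(self.vllm_config, "weights"): self.model_runner.load_model().
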
@@ -167,7 +196,14 @@ class Worker:
     def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
-        self.model_runner.initialize_kv_cache(kv_cache_config)
+        if self.vllm_config.model_config.enable_sleep_mode:
+            allocator = CuMemAllocator.get_instance()
+            context = allocator.use_memory_pool(tag="kv_cache")
+        else:
+            from contextlib import nullcontext
+            context = nullcontext()
+        with context:
+            self.model_runner.initialize_kv_cache(kv_cache_config)
 
     def compile_or_warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
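
End to end, the worker hooks above are driven by engine plumbing in the rest of this commit's 14 changed files. A sketch of how the feature might be exercised from the user side, assuming the LLM entrypoint exposes enable_sleep_mode, sleep(), and wake_up() (model name illustrative):

from vllm import LLM

# enable_sleep_mode routes weight and KV-cache allocations through
# CuMemAllocator so they can be released later without tearing down
# the engine process.
llm = LLM(model="facebook/opt-125m", enable_sleep_mode=True)
print(llm.generate("The capital of France is")[0].outputs[0].text)

# Level 1: offload weights to CPU and discard the KV cache, freeing
# GPU memory for another workload (e.g., the training phase of an
# RLHF loop).
llm.sleep(level=1)

# Restore weights and rebuild the KV-cache pool; because virtual
# addresses are preserved, generation resumes transparently.
llm.wake_up()
print(llm.generate("The capital of Germany is")[0].outputs[0].text)

Per the sleep() hunk above, any level other than 1 passes an empty offload_tags tuple, i.e. the weights are discarded rather than offloaded to CPU and must be reloaded before the instance is used again.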