[Feat][v1] Simple yet General CPU KV Cache Offloading (#37160)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Yifan Qiao
2026-03-31 17:58:37 -07:00
committed by GitHub
parent 31a719bcd3
commit 91e4521f9f
15 changed files with 2964 additions and 3 deletions


@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for SimpleCPUOffloadConnector with real models."""
import time
import pytest
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVTransferConfig
from vllm.platforms import current_platform
if not current_platform.is_cuda():
pytest.skip("Requires CUDA", allow_module_level=True)
# Small models for default CI / local runs (accuracy only).
SMALL_MODELS = [
"meta-llama/Llama-3.2-1B-Instruct",
"google/gemma-3-1b-it",
]
# Large models for optional perf runs only (slow to load and execute).
PERF_MODELS = [
"meta-llama/Llama-3.1-8B",
"openai/gpt-oss-20b",
]
def _make_llm(model: str, lazy: bool, cpu_bytes_to_use: int) -> LLM:
kv_transfer_config = KVTransferConfig(
kv_connector="SimpleCPUOffloadConnector",
kv_role="kv_both",
kv_connector_extra_config={
"cpu_bytes_to_use": cpu_bytes_to_use,
"lazy_offload": lazy,
},
)
return LLM(
model=model,
kv_cache_memory_bytes=40 << 30, # 40 GiB
disable_hybrid_kv_cache_manager=False,
enable_prefix_caching=True,
kv_transfer_config=kv_transfer_config,
)
def _flush_gpu_cache(llm: LLM, sampling_params: SamplingParams, seed: int = 0):
"""Generate enough filler requests to allocate the entire GPU KV cache.
This pushes all prior blocks through the free queue so that the lazy
cursor offloads them to CPU before they are evicted.
"""
cache_config = llm.llm_engine.vllm_config.cache_config
num_gpu_blocks = cache_config.num_gpu_blocks
block_size = cache_config.block_size
# Use 1.5x GPU capacity to give the lazy cursor enough scheduling steps
# to walk past all target blocks near the tail of the free queue.
total_tokens_needed = int(num_gpu_blocks * block_size * 1.5)
# Use token-id prompts so each filler is unique (no prefix sharing).
# Split into multiple requests to stay under max_model_len.
max_tokens_per_req = 4096
num_fillers = (total_tokens_needed + max_tokens_per_req - 1) // max_tokens_per_req
batch_size = 10
for i in range(0, num_fillers, batch_size):
batch_end = min(i + batch_size, num_fillers)
filler_prompts = []
for j in range(i, batch_end):
ids = [seed * num_fillers + j + 1] * max_tokens_per_req
filler_prompts.append(TokensPrompt(prompt_token_ids=ids))
llm.generate(filler_prompts, sampling_params, use_tqdm=False)
def _accuracy_test(llm: LLM, lazy: bool = False):
"""Verify that CPU-loaded KV produces correct output."""
sampling_params = SamplingParams(max_tokens=1, temperature=0)
prompt = "hi " * 2000 + "Let's count to ten. One, two, three, "
# Cold run — populate GPU cache and trigger CPU offload
cold_output = llm.generate(prompt, sampling_params, use_tqdm=False)[0]
# CPU hit runs
test_count = 10
success_count = 0
expected = cold_output.outputs[0].text
for i in range(test_count):
if lazy:
_flush_gpu_cache(llm, sampling_params, seed=i)
time.sleep(2) # let engine core drain pending transfers
# Reset GPU prefix cache so next run must load from CPU
if not llm.reset_prefix_cache():
print(f"GPU prefix cache reset failed for iteration {i}")
output = llm.generate(prompt, sampling_params, use_tqdm=False)[0]
if output.outputs[0].text == expected:
success_count += 1
assert success_count >= 0.5 * test_count, (
f"Accuracy too low: {success_count}/{test_count} matched '{expected}'"
)
def _latency_test(llm: LLM, lazy: bool = False):
"""Verify CPU cache hit is faster than cold compute."""
sampling_params = SamplingParams(max_tokens=1, seed=42)
prompt_token_ids = [0] * 10001
num_times_cpu_better = 0
num_tests = 10
for i in range(num_tests):
prompt_token_ids[0] = i
prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
# Cold
time.sleep(2) # let engine core drain pending transfers
if not llm.reset_prefix_cache():
print(f"GPU prefix cache reset failed for iteration {i}")
start = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cold_time = time.time() - start
if lazy:
_flush_gpu_cache(llm, sampling_params, seed=i)
else:
# Eager mode: GPU hit ensures store completion is processed.
llm.generate(prompts, sampling_params, use_tqdm=False)
time.sleep(2) # let engine core drain pending transfers
if not llm.reset_prefix_cache():
print(f"GPU prefix cache reset failed for iteration {i}")
# CPU hit
start = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cpu_time = time.time() - start
if cpu_time < cold_time:
num_times_cpu_better += 1
assert num_times_cpu_better >= 0.8 * num_tests, (
f"CPU hit only faster {num_times_cpu_better}/{num_tests} times"
)
@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", SMALL_MODELS)
def test_simple_cpu_offload_accuracy(model: str):
"""Store to CPU, reset GPU, load from CPU; verify output matches baseline."""
llm = _make_llm(model, False, 1 << 30) # 1 GiB
try:
_accuracy_test(llm, lazy=False)
finally:
del llm
@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", PERF_MODELS)
def test_simple_cpu_offload_perf_latency(model: str):
"""CPU KV hit should beat cold prefill on long context (large models only)."""
llm = _make_llm(model, False, 10 << 30) # 10 GiB
try:
_latency_test(llm, lazy=False)
finally:
del llm
@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", SMALL_MODELS)
def test_simple_cpu_offload_accuracy_lazy(model: str):
"""Lazy mode: flush GPU cache to trigger CPU offload, then verify hit."""
# CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
llm = _make_llm(model, True, 80 << 30) # 80 GiB
try:
_accuracy_test(llm, lazy=True)
finally:
del llm
@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", PERF_MODELS)
def test_simple_cpu_offload_perf_latency_lazy(model: str):
"""Lazy mode: CPU KV hit should beat cold prefill (large models only)."""
# CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
llm = _make_llm(model, True, 80 << 30) # 80 GiB
try:
_latency_test(llm, lazy=True)
finally:
del llm

File diff suppressed because it is too large.


@@ -660,7 +660,11 @@ class VllmConfig:
)
if kv_offloading_backend == "native":
if envs.VLLM_USE_SIMPLE_KV_OFFLOAD:
config_connector = "SimpleCPUOffloadConnector"
else:
config_connector = "OffloadingConnector"
self.kv_transfer_config.kv_connector = config_connector
self.kv_transfer_config.kv_connector_extra_config.update(
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
)
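For reference, the same connector can also be selected explicitly through KVTransferConfig rather than via the environment-variable path above, as the integration test does; a minimal sketch (model name and capacity values are illustrative):

from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enable_prefix_caching=True,  # SimpleCPUOffloadConnector requires prefix caching
    kv_transfer_config=KVTransferConfig(
        kv_connector="SimpleCPUOffloadConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "cpu_bytes_to_use": 8 << 30,  # 8 GiB total across all ranks
            "lazy_offload": False,        # eager, per-request offload
        },
    ),
)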


@@ -202,6 +202,7 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
"DecodeBenchConnector",
)
KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector",
@@ -213,3 +214,9 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector",
"FlexKVConnectorV1",
)
KVConnectorFactory.register_connector(
"SimpleCPUOffloadConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector",
"SimpleCPUOffloadConnector",
)


@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""SimpleCPUOffloadConnector: minimal CPU KV cache offloading."""
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.simple_kv_offload.manager import (
SimpleCPUOffloadScheduler,
)
from vllm.v1.simple_kv_offload.metadata import (
SimpleCPUOffloadMetadata,
)
from vllm.v1.simple_kv_offload.worker import (
SimpleCPUOffloadWorker,
)
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
# Default CPU capacity: 8 GB
DEFAULT_CPU_CAPACITY_BYTES = 8 * (1024**3)
class SimpleCPUOffloadConnector(KVConnectorBase_V1, SupportsHMA):
"""CPU KV cache offloading with custom kernel transfers and BlockPool LRU."""
def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig | None" = None,
):
super().__init__(vllm_config, role, kv_cache_config)
enable_prefix_caching = vllm_config.cache_config.enable_prefix_caching
extra_config = self._kv_transfer_config.kv_connector_extra_config or {}
cpu_capacity_bytes = int(
extra_config.get("cpu_bytes_to_use", DEFAULT_CPU_CAPACITY_BYTES)
)
# cpu_bytes_to_use is a server-wide total (kept for compatibility);
# cpu_bytes_to_use_per_rank, if set, overrides the derived per-rank capacity.
world_size = vllm_config.parallel_config.world_size
cpu_capacity_per_rank = cpu_capacity_bytes // world_size
if "cpu_bytes_to_use_per_rank" in extra_config:
explicit = int(extra_config["cpu_bytes_to_use_per_rank"])
if explicit != cpu_capacity_per_rank:
logger.warning(
"cpu_bytes_to_use_per_rank (%.2f GB) != "
"cpu_bytes_to_use/world_size (%.2f GB). Using per-rank value.",
explicit / (1024**3),
cpu_capacity_per_rank / (1024**3),
)
cpu_capacity_per_rank = explicit
lazy_offload = bool(extra_config.get("lazy_offload", False))
self.scheduler_manager: SimpleCPUOffloadScheduler | None = None
self.worker_handler: SimpleCPUOffloadWorker | None = None
if not enable_prefix_caching:
logger.warning(
"Detected prefix caching disabled, disabling CPU offload "
"since it requires prefix caching."
)
return
logger.info(
"SimpleCPUOffloadConnector: role=%s, "
"per_rank=%.2f GB, world_size=%d, mode=%s",
role.name,
cpu_capacity_per_rank / (1024**3),
world_size,
"lazy" if lazy_offload else "eager",
)
if role == KVConnectorRole.SCHEDULER:
self.scheduler_manager = SimpleCPUOffloadScheduler(
vllm_config,
kv_cache_config,
cpu_capacity_per_rank,
lazy_offload=lazy_offload,
)
elif role == KVConnectorRole.WORKER:
self.worker_handler = SimpleCPUOffloadWorker(
vllm_config, kv_cache_config, cpu_capacity_per_rank
)
# --- Worker-side methods ---
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
if self.worker_handler is not None:
self.worker_handler.register_kv_caches(kv_caches)
def bind_connector_metadata(
self,
connector_metadata: KVConnectorMetadata,
) -> None:
super().bind_connector_metadata(connector_metadata)
if self.worker_handler is not None:
assert isinstance(connector_metadata, SimpleCPUOffloadMetadata)
self.worker_handler.bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None:
super().clear_connector_metadata()
if self.worker_handler is not None:
self.worker_handler.clear_connector_metadata()
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata) -> None:
if self.worker_handler is not None:
assert isinstance(kv_connector_metadata, SimpleCPUOffloadMetadata)
self.worker_handler.handle_preemptions(kv_connector_metadata)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
pass # Loads are launched in get_finished(), after model execution has been launched
def wait_for_layer_load(self, layer_name: str) -> None:
pass # Loads are always asynchronous and deferred to get_finished()
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
pass # Saves are always asynchronous and deferred to get_finished()
def wait_for_save(self) -> None:
pass # All stores are driven by get_finished(); no wait is needed
def get_finished(
self,
finished_req_ids: set[str],
) -> tuple[set[str] | None, set[str] | None]:
if self.worker_handler is not None:
return self.worker_handler.get_finished(finished_req_ids)
return None, None
def build_connector_worker_meta(self):
if self.worker_handler is not None:
return self.worker_handler.build_connector_worker_meta()
return None
# --- Scheduler-side methods ---
# NOTE: New API only for SimpleCPUOffloadConnector.
def bind_gpu_block_pool(self, gpu_block_pool: "BlockPool") -> None:
if self.scheduler_manager is not None:
self.scheduler_manager.bind_gpu_block_pool(gpu_block_pool)
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
if self.scheduler_manager is not None:
return self.scheduler_manager.get_num_new_matched_tokens(
request, num_computed_tokens
)
return 0, False
def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
) -> None:
if self.scheduler_manager is not None:
self.scheduler_manager.update_state_after_alloc(
request, blocks, num_external_tokens
)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
if self.scheduler_manager is not None:
return self.scheduler_manager.build_connector_meta(scheduler_output)
return SimpleCPUOffloadMetadata()
def update_connector_output(
self,
connector_output: KVConnectorOutput,
) -> None:
if self.scheduler_manager is not None:
self.scheduler_manager.update_connector_output(connector_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
if self.scheduler_manager is not None:
return self.scheduler_manager.request_finished(request, block_ids)
return False, None
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
if self.scheduler_manager is not None:
return self.scheduler_manager.request_finished_all_groups(
request, block_ids
)
return False, None
# NOTE: New API only for SimpleCPUOffloadConnector.
def has_pending_transfers(self) -> bool:
if self.scheduler_manager is not None:
return self.scheduler_manager.has_pending_stores()
return False
def take_events(self) -> Iterable[KVCacheEvent]:
if self.scheduler_manager is not None:
return self.scheduler_manager.take_events()
return []
def reset_cache(self) -> bool | None:
raise NotImplementedError(
"SimpleCPUOffloadConnector does not support reset_cache(). "
"reset_prefix_cache() requires synchronizing all pending "
"CPU offload transfers before clearing GPU prefix cache blocks, "
"which is not yet implemented."
)


@@ -1674,6 +1674,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_XPU_ENABLE_XPU_GRAPH": lambda: bool(
int(os.getenv("VLLM_XPU_ENABLE_XPU_GRAPH", "0"))
),
# Enable simple KV offload.
"VLLM_USE_SIMPLE_KV_OFFLOAD": lambda: bool(
int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0"))
),
}


@@ -234,6 +234,13 @@ class Scheduler(SchedulerInterface):
hash_block_size=self.block_size,
metrics_collector=self.kv_metrics_collector,
)
# Bind GPU block pool to the KV connector. This must happen after
# kv_cache_manager is constructed so block_pool is available.
if self.connector is not None and hasattr(
self.connector, "bind_gpu_block_pool"
):
self.connector.bind_gpu_block_pool(self.kv_cache_manager.block_pool)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
self.scheduler_reserve_full_isl = (


@@ -281,8 +281,16 @@ class PromptTokenStats:
self.computed += prompt_len - num_cached_tokens
self.external_kv_transfer += num_external_computed_tokens
# FIXME(yifan): local_cache_hit can go negative after preemption.
# num_cached_tokens is a one-time snapshot from first scheduling and
# is never reset on preemption, while num_external_computed_tokens is
# overwritten on re-scheduling. If CPU offload finds more tokens on
# the second pass than the original total, the subtraction underflows.
# A fundamental fix is to track the first-time num_external_computed_tokens
# as a separate metric rather than reusing num_external_computed_tokens
# for the metric directly.
self.local_cache_hit += max(
0, (num_cached_tokens + recomputed - num_external_computed_tokens)
)
self.cached_tokens += num_cached_tokens
self.recomputed_tokens += recomputed
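To make the underflow described in the FIXME concrete, a hypothetical set of counts (illustrative numbers only):

num_cached_tokens = 128              # snapshot from the first scheduling pass
recomputed = 0
num_external_computed_tokens = 256   # larger CPU-offload hit found on re-scheduling
delta = num_cached_tokens + recomputed - num_external_computed_tokens  # -128
local_cache_hit_increment = max(0, delta)  # clamped to 0 instead of going negative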


@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""DMA copy backend for GPU<->CPU block transfers."""
from __future__ import annotations
import queue
import threading
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.simple_kv_offload.cuda_mem_ops import (
BatchMemcpyParams,
build_params,
copy_blocks,
)
logger = init_logger(__name__)
class DmaCopyBackend:
"""cuMemcpyBatchAsync copy backend (background thread)."""
def __init__(self) -> None:
self._store_params: BatchMemcpyParams | None = None
self._load_params: BatchMemcpyParams | None = None
self._load_stream: torch.cuda.Stream | None = None
self._store_stream: torch.cuda.Stream | None = None
self._queue: queue.SimpleQueue | None = None
self._thread: threading.Thread | None = None
self._shutdown: bool = False
def init(
self,
gpu_caches: dict[str, torch.Tensor],
cpu_caches: dict[str, torch.Tensor],
device: torch.device,
load_stream: torch.cuda.Stream,
store_stream: torch.cuda.Stream,
) -> None:
self._load_stream = load_stream
self._store_stream = store_stream
self._store_params = build_params(gpu_caches, cpu_caches, store_stream)
self._load_params = build_params(cpu_caches, gpu_caches, load_stream)
self._queue = queue.SimpleQueue()
self._thread = threading.Thread(
target=self._copy_loop,
args=(self._queue, device, load_stream, store_stream),
daemon=True,
)
self._thread.start()
def launch_copy(
self,
src_blocks: list[int],
dst_blocks: list[int],
is_store: bool,
event_idx: int,
events_list: list[tuple[int, torch.Event]],
) -> None:
params = self._store_params if is_store else self._load_params
assert params is not None and self._queue is not None
self._queue.put(
(src_blocks, dst_blocks, params, is_store, event_idx, events_list)
)
def shutdown(self) -> None:
if self._shutdown:
return
self._shutdown = True
if self._queue is not None:
self._queue.put(None)
if self._thread is not None:
self._thread.join(timeout=5.0)
@staticmethod
def _copy_loop(
q: queue.SimpleQueue,
device: torch.device,
load_stream: torch.cuda.Stream,
store_stream: torch.cuda.Stream,
) -> None:
current_platform.set_device(device)
while True:
item = q.get()
if item is None:
return
src_blocks, dst_blocks, params, is_store, event_idx, events_list = item
copy_blocks(src_blocks, dst_blocks, params)
stream = store_stream if is_store else load_stream
event = torch.Event()
event.record(stream)
events_list.append((event_idx, event))


@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Low-level CUDA memory helpers: pinning and batch DMA transfers."""
import ctypes
from typing import Any, NamedTuple
import numpy as np
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
def pin_tensor(tensor: torch.Tensor) -> None:
"""Pin a CPU tensor via cudaHostRegister.
This bypasses PyTorch's CUDACachingHostAllocator which rounds
every ``pin_memory=True`` allocation up to the next power of 2
(e.g. 100 GB becomes 128 GB).
"""
err = torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.nbytes, 0)
if err.value != 0:
raise RuntimeError(f"cudaHostRegister failed: {err}")
class _CUmemLocation(ctypes.Structure):
_fields_ = [("type", ctypes.c_uint), ("id", ctypes.c_int)]
class _CUmemcpyAttributes(ctypes.Structure):
_fields_ = [
("srcAccessOrder", ctypes.c_uint),
("srcLocHint", _CUmemLocation),
("dstLocHint", _CUmemLocation),
("flags", ctypes.c_uint),
]
_BATCH_MEMCPY_FUNC_TYPE = ctypes.CFUNCTYPE(
ctypes.c_uint, # CUresult
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_size_t,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_size_t,
ctypes.c_void_p,
ctypes.c_void_p,
)
# Resolved lazily on first use.
_batch_memcpy_fn: Any = None
def _resolve_batch_memcpy():
"""Resolve cuMemcpyBatchAsync via cuGetProcAddress (one-time)."""
from cuda.bindings import driver as drv
err, ptr, _ = drv.cuGetProcAddress(b"cuMemcpyBatchAsync", 12080, 0)
if err != drv.CUresult.CUDA_SUCCESS:
raise RuntimeError(f"cuGetProcAddress(cuMemcpyBatchAsync) failed: {err}")
return _BATCH_MEMCPY_FUNC_TYPE(ptr)
class BatchMemcpyParams(NamedTuple):
src_bases: np.ndarray # [num_layers] uint64 — data_ptr per layer
dst_bases: np.ndarray # [num_layers] uint64
bpb: np.ndarray # [num_layers] uint64 — bytes per block
num_layers: int
attrs: _CUmemcpyAttributes
attrs_idx: ctypes.c_size_t
# NOTE: cuMemcpyBatchAsync_v2() removed the fail_idx field, but we use
# cuMemcpyBatchAsync() with fail_idx for backward compatibility.
fail_idx: ctypes.c_size_t
stream_handle: int # raw cudaStream_t / CUstream
def build_params(
src_caches: dict[str, torch.Tensor],
dst_caches: dict[str, torch.Tensor],
stream: torch.cuda.Stream,
) -> BatchMemcpyParams:
global _batch_memcpy_fn
if _batch_memcpy_fn is None:
_batch_memcpy_fn = _resolve_batch_memcpy()
assert list(src_caches.keys()) == list(dst_caches.keys())
src_tensors = list(src_caches.values())
dst_tensors = list(dst_caches.values())
src_bases, dst_bases, bpb = [], [], []
for s, d in zip(src_tensors, dst_tensors):
s_bpb = s.stride(0) * s.element_size()
assert s_bpb == d.stride(0) * d.element_size()
src_bases.append(s.data_ptr())
dst_bases.append(d.data_ptr())
bpb.append(s_bpb)
# Refer to https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f for details. # noqa: E501
attrs = _CUmemcpyAttributes(srcAccessOrder=3) # ANY
return BatchMemcpyParams(
src_bases=np.array(src_bases, dtype=np.uint64),
dst_bases=np.array(dst_bases, dtype=np.uint64),
bpb=np.array(bpb, dtype=np.uint64),
num_layers=len(src_tensors),
attrs=attrs,
attrs_idx=ctypes.c_size_t(0),
fail_idx=ctypes.c_size_t(0),
stream_handle=stream.cuda_stream,
)
def copy_blocks(
src_block_ids: list[int],
dst_block_ids: list[int],
params: BatchMemcpyParams,
) -> None:
"""Copy blocks via cuMemcpyBatchAsync."""
n = len(src_block_ids)
if n == 0:
return
src_ids = np.array(src_block_ids, dtype=np.uint64)
dst_ids = np.array(dst_block_ids, dtype=np.uint64)
src_all = (
params.src_bases[:, None] + src_ids[None, :] * params.bpb[:, None]
).ravel()
dst_all = (
params.dst_bases[:, None] + dst_ids[None, :] * params.bpb[:, None]
).ravel()
sz_all = np.repeat(params.bpb, n)
total = n * params.num_layers
err = _batch_memcpy_fn(
dst_all.ctypes.data,
src_all.ctypes.data,
sz_all.ctypes.data,
total,
ctypes.addressof(params.attrs),
ctypes.byref(params.attrs_idx),
1,
ctypes.byref(params.fail_idx),
params.stream_handle,
)
if err != 0:
raise RuntimeError(
f"cuMemcpyBatchAsync failed: err={err} failIdx={params.fail_idx.value}"
)
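For intuition, a small worked example (hypothetical addresses and sizes) of the layer-major address expansion that copy_blocks performs before handing flat arrays to cuMemcpyBatchAsync:

import numpy as np

src_bases = np.array([0x1000, 0x9000], dtype=np.uint64)  # per-layer base pointers
bpb = np.array([0x100, 0x200], dtype=np.uint64)           # per-layer bytes per block
src_ids = np.array([5, 7, 2], dtype=np.uint64)            # block IDs to copy

# One entry per (layer, block) pair, layer-major, matching copy_blocks().
src_all = (src_bases[:, None] + src_ids[None, :] * bpb[:, None]).ravel()
# -> [0x1500, 0x1700, 0x1200, 0x9A00, 0x9E00, 0x9400]
sz_all = np.repeat(bpb, len(src_ids))
# -> [0x100, 0x100, 0x100, 0x200, 0x200, 0x200]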


@@ -0,0 +1,739 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Scheduler-side manager for SimpleCPUOffloadConnector."""
import contextlib
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_coordinator import (
KVCacheCoordinator,
get_kv_cache_coordinator,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
MambaSpec,
SlidingWindowSpec,
)
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.simple_kv_offload.metadata import (
SimpleCPUOffloadMetadata,
SimpleCPUOffloadWorkerMetadata,
)
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class TransferMeta:
gpu_block_ids: list[int]
cpu_block_ids: list[int]
@dataclass
class LoadRequestState:
request: "Request"
transfer_meta: TransferMeta
load_event: int | None = None
finished: bool = False
# NOTE: This per-request state is only used in eager mode.
@dataclass
class StoreRequestState:
request: "Request"
# Accumulated block IDs from scheduler_output via yield_req_data.
block_ids: tuple[list[int], ...]
# Per-group cursors tracking how many blocks have been stored/skipped.
num_stored_blocks: list[int]
store_events: set[int] = field(default_factory=set)
finished: bool = False
class SimpleCPUOffloadScheduler:
"""Scheduler-side manager for CPU offloading."""
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_config: "KVCacheConfig | None",
cpu_capacity_bytes: int,
lazy_offload: bool = False,
):
self.vllm_config = vllm_config
self.kv_cache_config = kv_cache_config
self.enable_kv_cache_events = (
vllm_config.kv_events_config is not None
and vllm_config.kv_events_config.enable_kv_cache_events
)
# NOTE: We use the same block size for both GPU and CPU.
self.block_size = vllm_config.cache_config.block_size
# Derive a CPU KVCacheConfig from the GPU config and build a coordinator
assert kv_cache_config is not None
self.cpu_kv_cache_config = self._derive_cpu_config(
kv_cache_config, cpu_capacity_bytes
)
self.num_cpu_blocks = self.cpu_kv_cache_config.num_blocks
# Find the full attention kv group for prefix cache matching.
self.fa_gidx = -1
for g_idx, g in enumerate(self.cpu_kv_cache_config.kv_cache_groups):
if isinstance(g.kv_cache_spec, FullAttentionSpec):
self.fa_gidx = g_idx
break
assert 0 <= self.fa_gidx < len(self.cpu_kv_cache_config.kv_cache_groups)
logger.info(
"SimpleCPUOffloadScheduler: Allocating %d CPU blocks (%.2f GB, mode=%s)",
self.num_cpu_blocks,
cpu_capacity_bytes / (1024**3),
"lazy" if lazy_offload else "eager",
)
# TODO (yifan): maybe need to enable kv_cache_events and metrics_collector here.
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
assert dcp_world_size == 1 and pcp_world_size == 1
self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator(
kv_cache_config=self.cpu_kv_cache_config,
max_model_len=vllm_config.model_config.max_model_len,
use_eagle=False,
enable_caching=True,
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=self.block_size,
)
self.cpu_block_pool: BlockPool = self.cpu_coordinator.block_pool
# GPU block pool reference - bound after scheduler builds kv_cache_manager
self._gpu_block_pool: BlockPool | None = None
# Load metadata
self._reqs_to_load: dict[str, LoadRequestState] = {}
# Inverse map: load_event_idx -> req_ids. Keyed by load_event_idx because
# the worker reports completions by event index, not request id.
self._load_event_to_reqs: dict[int, list[str]] = {}
# Store metadata
self._lazy_mode = lazy_offload
# Lazy mode: use a cursor to track the last scanned block in the GPU free queue.
self._cursor: KVCacheBlock | None = None
if self._lazy_mode:
self._target_free = self._estimate_lazy_target_blocks(
kv_cache_config,
vllm_config.scheduler_config.max_num_batched_tokens,
)
else:
self._target_free = 0
self._store_event_to_blocks: dict[int, TransferMeta] = {}
# Eager mode only
self._reqs_to_store: dict[str, StoreRequestState] = {}
self._store_event_to_reqs: dict[int, list[str]] = {}
# Event counters
self._load_event_counter: int = 0
self._store_event_counter: int = 0
# For TP/PP: track partial store completions across steps.
# Events must be reported by all world_size workers before being considered complete.
self._expected_worker_count = vllm_config.parallel_config.world_size
self._store_event_pending_counts: dict[int, int] = {}
@staticmethod
def _derive_cpu_config(
gpu_config: "KVCacheConfig", cpu_capacity_bytes: int
) -> "KVCacheConfig":
"""Derive a CPU KVCacheConfig from the GPU config.
Same kv_cache_groups, num_blocks scaled by CPU/GPU memory ratio."""
# Import here to avoid potential circular imports
from vllm.v1.kv_cache_interface import KVCacheConfig as KVCacheConfigCls
from vllm.v1.kv_cache_interface import KVCacheTensor
assert len(gpu_config.kv_cache_tensors) > 0
gpu_total_bytes = sum(t.size for t in gpu_config.kv_cache_tensors)
num_gpu_blocks = gpu_config.num_blocks
num_cpu_blocks = max(1, num_gpu_blocks * cpu_capacity_bytes // gpu_total_bytes)
# Create CPU kv_cache_tensors mirroring GPU by scaling size proportionally.
cpu_tensors = [
KVCacheTensor(
size=t.size // num_gpu_blocks * num_cpu_blocks,
shared_by=list(t.shared_by),
)
for t in gpu_config.kv_cache_tensors
]
return KVCacheConfigCls(
num_blocks=num_cpu_blocks,
kv_cache_tensors=cpu_tensors,
kv_cache_groups=gpu_config.kv_cache_groups,
)
@staticmethod
def _estimate_lazy_target_blocks(
kv_cache_config: "KVCacheConfig", max_num_batched_tokens: int
) -> int:
"""GPU blocks to keep available (free/offloaded) per step in lazy mode."""
WATERMARK_RATIO = 1.0 # Reserve extra headroom to avoid running out of GPU blocks
target = 0
for g in kv_cache_config.kv_cache_groups:
spec = g.kv_cache_spec
if isinstance(spec, MambaSpec):
target += 2
elif isinstance(spec, SlidingWindowSpec):
target += cdiv(spec.sliding_window, spec.block_size) + 1
else:
target += cdiv(max_num_batched_tokens, spec.block_size)
return int(target * (1 + WATERMARK_RATIO))
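As a rough worked example, assuming a single full-attention group with illustrative sizes:

from vllm.utils.math_utils import cdiv

block_size, max_num_batched_tokens = 16, 8192
per_step = cdiv(max_num_batched_tokens, block_size)   # 512 blocks consumed per step
target_free = int(per_step * (1 + 1.0))               # WATERMARK_RATIO = 1.0 -> 1024 blocks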
def bind_gpu_block_pool(self, gpu_block_pool: BlockPool) -> None:
"""Bind GPU block pool so that we can touch blocks during stores.
Called by Scheduler after kv_cache_manager is ready."""
self._gpu_block_pool = gpu_block_pool
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
"""Return (num_new_tokens, is_async) from consecutive CPU cache hits."""
skipped = num_computed_tokens // self.block_size
remaining_hashes = request.block_hashes[skipped:]
if not remaining_hashes:
return 0, False
# Must recompute at least the last token, matching the logic in
# kv_cache_manager.get_computed_blocks().
max_hit_len = request.num_tokens - 1 - num_computed_tokens
if max_hit_len <= 0:
return 0, False
_, hit_length = self.cpu_coordinator.find_longest_cache_hit(
remaining_hashes, max_hit_len
)
if hit_length > 0:
return hit_length, True
return 0, False
# TODO(yifan): this API now only matches the suffix part of the prefix cache. A more
# general API should scan blocks in both the GPU and CPU block pools in a single pass.
def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
) -> None:
req_id = request.request_id
block_ids_by_group = blocks.get_block_ids()
num_groups = len(block_ids_by_group)
# Store tracking (eager mode only). Register the request;
# block IDs are accumulated from scheduler_output in
# _prepare_eager_store_specs via yield_req_data.
if not self._lazy_mode and req_id not in self._reqs_to_store:
self._reqs_to_store[req_id] = StoreRequestState(
request=request,
block_ids=tuple([] for _ in range(num_groups)),
num_stored_blocks=[0] * num_groups,
)
if num_external_tokens == 0:
return
num_blocks_to_load = num_external_tokens // self.block_size
assert num_blocks_to_load > 0
skipped = sum(blk.block_hash is not None for blk in blocks.blocks[self.fa_gidx])
num_computed_tokens = skipped * self.block_size
hashes_to_load = request.block_hashes[skipped : skipped + num_blocks_to_load]
# Find CPU cached blocks across all groups.
max_hit_len = len(hashes_to_load) * self.block_size
cpu_hit_blocks, hit_length = self.cpu_coordinator.find_longest_cache_hit(
hashes_to_load, max_hit_len
)
assert hit_length == num_external_tokens, (
f"Expected {num_external_tokens} hit tokens, got {hit_length}"
)
# Build transfer pairs across all groups.
total_computed_tokens = num_computed_tokens + num_external_tokens
kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups
gpu_block_ids: list[int] = []
cpu_block_ids: list[int] = []
cpu_blocks_to_touch: list[KVCacheBlock] = []
for g in range(num_groups):
cpu_blocks_g = cpu_hit_blocks[g]
n_ext_g = len(cpu_blocks_g)
if n_ext_g == 0:
continue
# Number of blocks in the computed range for this group.
g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
n_computed_g = cdiv(total_computed_tokens, g_block_size)
# Back-trace: ext blocks sit at the tail of the computed range.
gpu_ext_start = n_computed_g - n_ext_g
group_gpu_ids = block_ids_by_group[g]
for i, cpu_blk in enumerate(cpu_blocks_g):
# Skip null blocks (e.g. sliding window or mamba padding).
if cpu_blk.is_null:
continue
gpu_block_ids.append(group_gpu_ids[gpu_ext_start + i])
cpu_block_ids.append(cpu_blk.block_id)
cpu_blocks_to_touch.append(cpu_blk)
# Touch CPU blocks to prevent eviction during async load.
self.cpu_block_pool.touch(cpu_blocks_to_touch)
# Touch GPU blocks to prevent freeing during async load
assert self._gpu_block_pool is not None
self._gpu_block_pool.touch(
[self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
)
assert self._reqs_to_load.get(req_id) is None
self._reqs_to_load[req_id] = LoadRequestState(
request=request, transfer_meta=TransferMeta(gpu_block_ids, cpu_block_ids)
)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> SimpleCPUOffloadMetadata:
# --- Stores ---
store_event = -1
store_gpu, store_cpu, store_req_ids = self.prepare_store_specs(scheduler_output)
if store_gpu:
store_event = self._store_event_counter
self._store_event_counter += 1
self._store_event_to_blocks[store_event] = TransferMeta(
store_gpu, store_cpu
)
if store_req_ids: # For eager mode only, track req->blocks mapping
self._store_event_to_reqs[store_event] = store_req_ids
for req_id in store_req_ids:
store_state = self._reqs_to_store.get(req_id)
if store_state is not None:
store_state.store_events.add(store_event)
# --- Loads ---
load_event = -1
load_gpu: list[int] = []
load_cpu: list[int] = []
load_req_ids: list[str] = []
for req_id, load_state in self._reqs_to_load.items():
if load_state.load_event is not None:
continue
assert load_state.transfer_meta is not None
load_gpu.extend(load_state.transfer_meta.gpu_block_ids)
load_cpu.extend(load_state.transfer_meta.cpu_block_ids)
load_req_ids.append(req_id)
if load_req_ids:
load_event = self._load_event_counter
self._load_event_counter += 1
for req_id in load_req_ids:
self._reqs_to_load[req_id].load_event = load_event
self._load_event_to_reqs[load_event] = load_req_ids
result = SimpleCPUOffloadMetadata(
load_event=load_event,
load_gpu_blocks=load_gpu,
load_cpu_blocks=load_cpu,
load_event_to_reqs=self._load_event_to_reqs,
store_event=store_event,
store_gpu_blocks=store_gpu,
store_cpu_blocks=store_cpu,
need_flush=bool(scheduler_output.preempted_req_ids),
)
return result
def prepare_store_specs(
self, scheduler_output: SchedulerOutput
) -> tuple[list[int], list[int], list[str]]:
"""Prepare store specs for the store event."""
if self._lazy_mode:
return self._prepare_lazy_store_specs()
else:
return self._prepare_eager_store_specs(scheduler_output)
def _prepare_lazy_store_specs(
self,
) -> tuple[list[int], list[int], list[str]]:
"""Single-pass cursor walk: offload cached GPU blocks near eviction.
Walks the GPU free queue from the cursor, counting blocks that are
free-or-offloaded (safe for the allocator to evict). Stops when
target_free blocks are covered or CPU capacity is reached.
"""
gpu_pool = self._gpu_block_pool
if gpu_pool is None or self._target_free <= 0:
return [], [], []
free_queue = gpu_pool.free_block_queue
cpu_pool = self.cpu_block_pool
num_cpu_free = cpu_pool.get_num_free_blocks()
# Validate cursor: stale if block was removed from free queue.
if self._cursor is not None and self._cursor.ref_cnt > 0:
self._cursor = None
# Determine start node.
if self._cursor is None:
node = free_queue.fake_free_list_head.next_free_block
else:
node = self._cursor.next_free_block
tail = free_queue.fake_free_list_tail
gpu_ids: list[int] = []
block_hashes: list[bytes] = []
covered = 0
last_visited = self._cursor
while (
node is not None
and node is not tail
and covered < self._target_free
and len(gpu_ids) < num_cpu_free
):
last_visited = node
bhash = node.block_hash
if (
bhash is not None
and not node.is_null
and cpu_pool.cached_block_hash_to_block.get_one_block(bhash) is None
):
gpu_ids.append(node.block_id)
block_hashes.append(bhash)
covered += 1
node = node.next_free_block
self._cursor = last_visited
# Batch-allocate CPU blocks and stamp hashes.
if gpu_ids:
cpu_blocks = cpu_pool.get_new_blocks(len(gpu_ids))
cpu_ids = [blk.block_id for blk in cpu_blocks]
for cpu_blk, bhash in zip(cpu_blocks, block_hashes): # type: ignore[assignment]
cpu_blk._block_hash = bhash # type: ignore[assignment]
# Touch GPU blocks to prevent eviction during async copy.
gpu_pool.touch([gpu_pool.blocks[bid] for bid in gpu_ids])
else:
cpu_ids = []
return gpu_ids, cpu_ids, []
def _prepare_eager_store_specs(
self, scheduler_output: SchedulerOutput
) -> tuple[list[int], list[int], list[str]]:
"""Identify newly computed blocks to offload from scheduler requests.
Only considers blocks whose KV data has been **confirmed computed** by
the GPU. This means blocks from the current step are NOT stored until the
next step. If a request finishes in the same step as its last full block,
that block may be missed. (TODO: flush on finish.)
Returns:
(gpu_block_ids, cpu_block_ids, req_ids) for the store event.
"""
merged_gpu_block_ids: list[int] = []
merged_cpu_block_ids: list[int] = []
req_ids: list[str] = []
gpu_block_pool = self._gpu_block_pool
if gpu_block_pool is None:
return [], [], []
cpu_block_pool = self.cpu_block_pool
num_free = cpu_block_pool.get_num_free_blocks()
kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups
num_groups = len(kv_cache_groups)
gpu_blocks_this_step: set[int] = set()
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
state = self._reqs_to_store.get(req_id)
if state is None or state.finished:
continue
# Accumulate new block IDs.
if preempted:
state.block_ids = tuple([] for _ in range(num_groups))
state.num_stored_blocks = [0] * num_groups
if new_block_id_groups:
for g in range(min(num_groups, len(new_block_id_groups))):
if new_block_id_groups[g] is not None:
state.block_ids[g].extend(new_block_id_groups[g])
num_new_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
if num_new_tokens == 0:
continue
block_ids_by_group = state.block_ids
if not block_ids_by_group:
continue
# --- Phase 1: Scan blocks, classify as cached vs to-store ---
gpu_block_ids: list[int] = []
block_hashes_to_store: list[bytes] = []
advanced_per_group: list[int] = [0] * num_groups
out_of_space = False
# Confirmed tokens: KV data written and visible to all streams.
req = state.request
confirmed_tokens = req.num_computed_tokens - req.num_output_placeholders
for g in range(num_groups):
# FIXME (yifan): handle CPU cache eviction, where
# num_stored_blocks can be stale and omit evicted blocks in
# the middle of the request.
already_stored_g = state.num_stored_blocks[g]
group_gpu_ids = block_ids_by_group[g]
# Cap to blocks with confirmed KV data.
g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
ready_blocks_g = confirmed_tokens // g_block_size
scannable = group_gpu_ids[already_stored_g:ready_blocks_g]
for gpu_block_id in scannable:
gpu_block = gpu_block_pool.blocks[gpu_block_id]
if gpu_block.is_null:
advanced_per_group[g] += 1
continue
bhash_with_group = gpu_block.block_hash
if bhash_with_group is None:
break
# Check if this group's data is already scheduled for store
# in this step or already cached in CPU.
if (
gpu_block_id in gpu_blocks_this_step
or cpu_block_pool.cached_block_hash_to_block.get_one_block(
bhash_with_group
)
is not None
):
advanced_per_group[g] += 1
continue
if num_free <= 0:
out_of_space = True
break
num_free -= 1
gpu_block_ids.append(gpu_block_id)
block_hashes_to_store.append(bhash_with_group)
advanced_per_group[g] += 1
if out_of_space:
break
# --- Phase 2: Batch allocate CPU blocks and stamp hashes ---
n_to_alloc = len(gpu_block_ids)
if n_to_alloc > 0:
cpu_blocks_alloc = cpu_block_pool.get_new_blocks(n_to_alloc)
cpu_block_ids = [blk.block_id for blk in cpu_blocks_alloc]
for cpu_blk, bhash in zip(cpu_blocks_alloc, block_hashes_to_store):
cpu_blk._block_hash = bhash # type: ignore[assignment]
else:
cpu_block_ids = []
if cpu_block_ids:
req_ids.append(req_id)
merged_gpu_block_ids.extend(gpu_block_ids)
merged_cpu_block_ids.extend(cpu_block_ids)
gpu_blocks_this_step.update(gpu_block_ids)
# Touch GPU blocks to prevent freeing during async copy
gpu_block_pool.touch(
[gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
)
logger.debug(
"Request %s: Scheduling store of %d blocks to CPU (%d groups)",
req_id,
len(cpu_block_ids),
num_groups,
)
# Advance per-group cursors (includes cached hits + newly stored)
for g in range(num_groups):
state.num_stored_blocks[g] += advanced_per_group[g]
return merged_gpu_block_ids, merged_cpu_block_ids, req_ids
def update_connector_output(self, connector_output: KVConnectorOutput) -> None:
"""Handle async transfer completions from worker.
Load completions arrive via finished_recving (real req_ids).
Store completions arrive via kv_connector_worker_meta as
per-event worker counts. We accumulate across steps and process
a store event only when all workers have reported completion.
"""
# --- Load completions ---
for req_id in list(connector_output.finished_recving or []):
self._cleanup_load_request(req_id)
# --- Store completions ---
meta = connector_output.kv_connector_worker_meta
if not isinstance(meta, SimpleCPUOffloadWorkerMetadata):
return
for event_idx, count in meta.completed_store_events.items():
total = self._store_event_pending_counts.get(event_idx, 0) + count
if total >= self._expected_worker_count:
self._store_event_pending_counts.pop(event_idx, None)
self._process_store_event(event_idx)
else:
self._store_event_pending_counts[event_idx] = total
def _process_store_event(self, event_idx: int) -> None:
"""Process a fully-completed store event."""
transfer = self._store_event_to_blocks.pop(event_idx)
self._process_store_completion(transfer.gpu_block_ids, transfer.cpu_block_ids)
logger.debug(
"Store event %d completed: cached %d blocks to CPU",
event_idx,
len(transfer.cpu_block_ids),
)
# Eager only: update per-req state
if not self._lazy_mode:
for req_id in self._store_event_to_reqs.pop(event_idx, []):
state = self._reqs_to_store.get(req_id)
if state is None:
continue
state.store_events.discard(event_idx)
if state.finished and not state.store_events:
self._cleanup_store_request(req_id)
def _process_store_completion(
self, gpu_block_ids: list[int], cpu_block_ids: list[int]
) -> None:
"""Cache CPU blocks per-group and release GPU refs.
Block hashes were stamped on CPU blocks at allocation time (in
``_prepare_*_store_specs``). Here we just register them in the
cache map so they become discoverable by the load path.
"""
assert len(cpu_block_ids) == len(gpu_block_ids)
cpu_blocks = [self.cpu_block_pool.blocks[bid] for bid in cpu_block_ids]
for cpu_block in cpu_blocks:
bhash = cpu_block.block_hash
assert bhash is not None
self.cpu_block_pool.cached_block_hash_to_block.insert(bhash, cpu_block)
# Drop the CPU and GPU blocks' ref counts so they become evictable prefix-cache blocks
self.cpu_block_pool.free_blocks(cpu_blocks)
assert self._gpu_block_pool is not None
self._gpu_block_pool.free_blocks(
self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids
)
def has_pending_stores(self) -> bool:
"""Return True if there are in-flight store transfers."""
return bool(self._store_event_to_blocks)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""Always returns (False, None). GPU blocks are protected by ref_cnt,
so the scheduler can free blocks immediately."""
req_id = request.request_id
# Handle load: defer cleanup if load is in-flight
load_state = self._reqs_to_load.get(req_id)
if load_state is not None:
if load_state.load_event is not None:
load_state.finished = True # Defer: load in-flight
else:
self._cleanup_load_request(req_id)
# Handle store (eager mode only): defer cleanup if stores in-flight
if not self._lazy_mode:
store_state = self._reqs_to_store.get(req_id)
if store_state is not None:
if store_state.store_events:
store_state.finished = True # Defer: stores in-flight
else:
self._cleanup_store_request(req_id)
return False, None
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
return self.request_finished(request, block_ids=[])
def _cleanup_load_request(self, req_id: str) -> None:
"""Release all load resources for a request.
Shared between request_finished() and update_connector_output() paths.
Removes the request from _reqs_to_load, cleans up event mappings,
and frees CPU/GPU touch refs.
"""
state = self._reqs_to_load.pop(req_id, None)
if state is None:
return
# Remove from load event mapping (only this req, not whole event)
if state.load_event is not None:
reqs = self._load_event_to_reqs.get(state.load_event)
if reqs is not None:
with contextlib.suppress(ValueError):
reqs.remove(req_id)
if not reqs:
self._load_event_to_reqs.pop(state.load_event, None)
if state.transfer_meta is not None:
# Free CPU touch refs
self.cpu_block_pool.free_blocks(
self.cpu_block_pool.blocks[bid]
for bid in state.transfer_meta.cpu_block_ids
)
# Free GPU touch refs
assert self._gpu_block_pool is not None
self._gpu_block_pool.free_blocks(
self._gpu_block_pool.blocks[bid]
for bid in state.transfer_meta.gpu_block_ids
)
def _cleanup_store_request(self, req_id: str) -> None:
"""Release store metadata for a request.
Metadata-only cleanup; no blocks are freed here. Store completion handles
block caching and GPU ref freeing via _process_store_completion().
"""
state = self._reqs_to_store.pop(req_id, None)
if state is None:
return
for event_idx in list(state.store_events):
if (reqs := self._store_event_to_reqs.get(event_idx)) is not None:
with contextlib.suppress(ValueError):
reqs.remove(req_id)
if not reqs:
self._store_event_to_reqs.pop(event_idx, None)
state.store_events.clear()
def take_events(self) -> Iterable[KVCacheEvent]:
return self.cpu_block_pool.take_events()


@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Metadata for SimpleCPUOffloadConnector."""
from dataclasses import dataclass, field
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorWorkerMetadata,
)
INVALID_JOB_ID = -1
@dataclass
class SimpleCPUOffloadMetadata(KVConnectorMetadata):
"""
Metadata passed from scheduler to worker for CPU offload operations.
The worker receives flat block lists keyed by a monotonic event_idx.
Job->req_id translation is handled by the scheduler-side manager
(via inverse maps), so the worker never knows about request identities.
"""
# Load event per step. INVALID_JOB_ID means no blocks to load this step.
load_event: int = INVALID_JOB_ID
load_gpu_blocks: list[int] = field(default_factory=list)
load_cpu_blocks: list[int] = field(default_factory=list)
# Reverse map: load_event->req_ids, for tracking requests with finished load events
load_event_to_reqs: dict[int, list[str]] = field(default_factory=dict)
# Store event per step. INVALID_JOB_ID means no blocks to store this step.
store_event: int = INVALID_JOB_ID
store_gpu_blocks: list[int] = field(default_factory=list)
store_cpu_blocks: list[int] = field(default_factory=list)
# Whether any requests were preempted this step, requiring pending transfers to be flushed.
need_flush: bool = False
@dataclass
class SimpleCPUOffloadWorkerMetadata(KVConnectorWorkerMetadata):
"""Worker -> Scheduler metadata for completed store events.
Each worker reports {event_idx: 1} for newly completed stores.
``aggregate()`` sums counts across workers within a step.
The scheduler-side manager accumulates across steps and processes
a store completion only when count reaches ``world_size``.
"""
completed_store_events: dict[int, int]
def aggregate(
self, other: "KVConnectorWorkerMetadata"
) -> "KVConnectorWorkerMetadata":
assert isinstance(other, SimpleCPUOffloadWorkerMetadata)
merged = dict(self.completed_store_events)
for k, v in other.completed_store_events.items():
merged[k] = merged.get(k, 0) + v
return SimpleCPUOffloadWorkerMetadata(completed_store_events=merged)
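For example, with a hypothetical tensor-parallel size of two, one step of aggregation could look like this:

rank0 = SimpleCPUOffloadWorkerMetadata(completed_store_events={3: 1, 4: 1})
rank1 = SimpleCPUOffloadWorkerMetadata(completed_store_events={3: 1})
merged = rank0.aggregate(rank1)
assert merged.completed_store_events == {3: 2, 4: 1}
# The scheduler-side manager processes event 3 now (count == world_size == 2)
# and keeps event 4 pending until the remaining worker reports it.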


@@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Worker-side handler for SimpleCPUOffloadConnector."""
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.simple_kv_offload.copy_backend import DmaCopyBackend
from vllm.v1.simple_kv_offload.cuda_mem_ops import pin_tensor
from vllm.v1.simple_kv_offload.metadata import (
SimpleCPUOffloadMetadata,
SimpleCPUOffloadWorkerMetadata,
)
if TYPE_CHECKING:
from vllm.v1.kv_cache_interface import KVCacheConfig
logger = init_logger(__name__)
class SimpleCPUOffloadWorker:
"""Worker-side handler for CPU offloading transfers."""
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_config: "KVCacheConfig | None",
cpu_capacity_bytes: int,
):
self.vllm_config = vllm_config
self.kv_cache_config = kv_cache_config
self.cpu_capacity_bytes = cpu_capacity_bytes
self.gpu_kv_caches: dict[str, torch.Tensor] | None = None
self.cpu_kv_caches: dict[str, torch.Tensor] | None = None
self.device: torch.device | None = None
self.num_cpu_blocks: int = 0
# CUDA streams for the async transfers
self.load_stream: torch.cuda.Stream | None = None
self.store_stream: torch.cuda.Stream | None = None
self._backend = DmaCopyBackend()
# Ordered (event_idx, Event) pairs, appended by the copy backend's thread.
self._load_events: list[tuple[int, torch.Event]] = []
self._store_events: list[tuple[int, torch.Event]] = []
# High-water marks: highest event_idx completed per stream.
# When the event list is empty, the hwm covers all prior events.
self._load_hwm: int = -1
self._store_hwm: int = -1
# Metadata for the current step
self._connector_metadata: SimpleCPUOffloadMetadata | None = None
# Pending event index sets, populated in bind_connector_metadata
self._pending_load_event_indices: set[int] = set()
self._pending_store_event_indices: set[int] = set()
# Completed store events to report via build_connector_worker_meta
self._completed_store_events: dict[int, int] = {}
def register_kv_caches(
self,
kv_caches: dict[str, torch.Tensor],
) -> None:
"""Register GPU KV caches and allocate pinned CPU tensors.
The worker will infer the underlying raw storage from the kv_caches.
Args:
kv_caches: Per-layer GPU KV caches. Values are either a single
tensor (attention layers) or a list of tensors (Mamba layers
in hybrid models). All values are included for offloading
by resolving to their underlying raw storage.
"""
if not kv_caches:
logger.warning("No KV caches to offload.")
return
# Resolve each entry to a representative tensor for storage
# deduplication. For attention layers the value is already a tensor;
# for Mamba layers it is a list of tensors that all share the same
# underlying raw storage, so we take the first one.
def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
assert isinstance(v, torch.Tensor | list)
return v if isinstance(v, torch.Tensor) else v[0]
any_tensor = _repr_tensor(next(iter(kv_caches.values())))
self.device = any_tensor.device
assert self.kv_cache_config is not None
num_blocks = self.kv_cache_config.num_blocks
# Deduplicate: multiple layers may share the same backing storage.
seen_ptrs: dict[int, tuple[str, torch.Tensor]] = {}
for name, value in kv_caches.items():
tensor = _repr_tensor(value)
ptr = tensor.untyped_storage().data_ptr()
if ptr not in seen_ptrs:
seen_ptrs[ptr] = (name, tensor)
# Build [num_blocks, block_bytes] int8 views from each unique
# storage so that stride(0) gives block_bytes for the copy op.
#
# The physical layout varies across attention backends:
# FlashAttn/ROCm: (2, num_blocks, ...) -> K/V outermost, 2 segments
# FlashInfer/MLA: (num_blocks, ...) -> blocks outermost, 1 segment
# We derive page_size_bytes = storage.nbytes() // num_blocks, then
# classify dims: any dim whose byte-stride exceeds page_size_bytes
# must be an outer segment dim (e.g. the K/V dim of size 2). A less
# hacky way would be to extend the interface to expose the layout explicitly.
unique_gpu_caches: dict[str, torch.Tensor] = {}
for name, tensor in seen_ptrs.values():
storage = tensor.untyped_storage()
raw = torch.empty(0, dtype=torch.int8, device=self.device).set_(
storage, 0, (storage.nbytes(),)
)
el = tensor.element_size()
page_size_bytes = storage.nbytes() // num_blocks
outer_dims = [
d for d in range(tensor.ndim) if tensor.stride(d) * el > page_size_bytes
]
if not outer_dims:
unique_gpu_caches[name] = raw.view(num_blocks, -1)
else:
seg_stride = tensor.stride(outer_dims[0]) * el
for idx in range(tensor.shape[outer_dims[0]]):
offset = idx * seg_stride
chunk = raw[offset : offset + seg_stride]
unique_gpu_caches[f"{name}.{idx}"] = chunk.view(num_blocks, -1)
# Compute per-tensor bytes_per_block. Tensors may have different
# page_size_bytes (e.g., UniformTypeKVCacheSpecs with varying head_size).
per_tensor_bpb = [
t.stride(0) * t.element_size() for t in unique_gpu_caches.values()
]
total_bytes_per_block = sum(per_tensor_bpb)
self.num_cpu_blocks = max(1, self.cpu_capacity_bytes // total_bytes_per_block)
logger.info(
"SimpleCPUOffloadWorker: %d unique GPU KV tensors, "
"allocating %d CPU blocks (%.2f GB)",
len(unique_gpu_caches),
self.num_cpu_blocks,
(self.num_cpu_blocks * total_bytes_per_block) / (1024**3),
)
pin_memory = is_pin_memory_available()
if not pin_memory:
logger.warning(
"Pinned memory not available. CPU offload performance may be degraded."
)
self.gpu_kv_caches = unique_gpu_caches
self.cpu_kv_caches = {}
for name, gpu_tensor in unique_gpu_caches.items():
cpu_shape = (self.num_cpu_blocks,) + gpu_tensor.shape[1:]
# Allocate non-pinned first, then pin via cudaHostRegister to
# bypass PyTorch's CUDACachingHostAllocator which rounds up to
# the next power of 2 (e.g. 100 GB -> 128 GB).
tensor = torch.zeros(cpu_shape, dtype=gpu_tensor.dtype, device="cpu")
if pin_memory:
pin_tensor(tensor)
self.cpu_kv_caches[name] = tensor
# Use lowest priority so KV cache I/O yields to compute streams.
low_pri, _ = torch.cuda.Stream.priority_range()
self.load_stream = torch.cuda.Stream(priority=low_pri)
self.store_stream = torch.cuda.Stream(priority=low_pri)
# Initialize copy backend with caches and streams.
self._backend.init(
self.gpu_kv_caches,
self.cpu_kv_caches,
self.device,
self.load_stream,
self.store_stream,
)
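To make the outer-dim classification concrete, a sketch with a hypothetical FlashAttention-style layout of shape (2, num_blocks, block_size, num_heads, head_dim):

import torch

num_blocks, block_size, num_heads, head_dim = 8, 16, 4, 32
kv = torch.zeros(2, num_blocks, block_size, num_heads, head_dim, dtype=torch.float16)

el = kv.element_size()
page_size_bytes = kv.untyped_storage().nbytes() // num_blocks       # 8192 bytes
outer_dims = [d for d in range(kv.ndim) if kv.stride(d) * el > page_size_bytes]
assert outer_dims == [0]  # only the K/V dim of size 2 is an outer segment dim

raw = torch.empty(0, dtype=torch.int8).set_(
    kv.untyped_storage(), 0, (kv.untyped_storage().nbytes(),)
)
seg_stride = kv.stride(0) * el
segments = [
    raw[i * seg_stride : (i + 1) * seg_stride].view(num_blocks, -1)
    for i in range(kv.shape[0])
]
assert segments[0].stride(0) == page_size_bytes // 2  # bytes per block per segment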
def bind_connector_metadata(self, metadata: SimpleCPUOffloadMetadata) -> None:
self._connector_metadata = metadata
if metadata.load_event >= 0:
self._pending_load_event_indices.add(metadata.load_event)
if metadata.store_event >= 0:
self._pending_store_event_indices.add(metadata.store_event)
def clear_connector_metadata(self) -> None:
self._connector_metadata = None
def start_load_kv(self) -> None:
# NOTE: we defer launching both load and store to get_finished(),
# which runs after model execution. This hides the CPU-side
# block copy op overhead (~5ms) behind GPU compute.
pass
def wait_for_save(self) -> None:
pass
def get_finished(
self,
finished_req_ids: set[str],
) -> tuple[set[str] | None, set[str] | None]:
"""Submit transfers and report completed events to the scheduler.
Called after model execution. The manager only schedules stores for
blocks whose KV data is confirmed computed, so we launch both loads
and stores immediately — no deferral or cross-stream sync needed.
Returns:
tuple of (finished_sending, finished_recving).
- finished_sending: always None (stores use worker metadata).
- finished_recving: req_ids whose loads have completed.
"""
# (1) Submit transfers
metadata = self._connector_metadata
if metadata is not None:
# Launch loads (CPU->GPU).
if metadata.load_cpu_blocks:
self._backend.launch_copy(
metadata.load_cpu_blocks,
metadata.load_gpu_blocks,
is_store=False,
event_idx=metadata.load_event,
events_list=self._load_events,
)
# Launch stores (GPU->CPU).
if metadata.store_gpu_blocks:
self._backend.launch_copy(
metadata.store_gpu_blocks,
metadata.store_cpu_blocks,
is_store=True,
event_idx=metadata.store_event,
events_list=self._store_events,
)
# (2) Track completed transfer events
finished_recving: set[str] = set()
if self._pending_load_event_indices:
load_wm = self._poll_stream_events(is_store=False)
for j in [j for j in self._pending_load_event_indices if j <= load_wm]:
self._pending_load_event_indices.discard(j)
req_ids = (
metadata.load_event_to_reqs.get(j) if metadata is not None else None
)
if req_ids:
finished_recving.update(req_ids)
if self._pending_store_event_indices:
store_wm = self._poll_stream_events(is_store=True)
for j in [j for j in self._pending_store_event_indices if j <= store_wm]:
self._pending_store_event_indices.discard(j)
self._completed_store_events[j] = 1
return None, finished_recving or None
def build_connector_worker_meta(self) -> SimpleCPUOffloadWorkerMetadata | None:
"""Return completed store events since the last call."""
if not self._completed_store_events:
return None
meta = SimpleCPUOffloadWorkerMetadata(
completed_store_events=self._completed_store_events,
)
self._completed_store_events = {}
return meta
def handle_preemptions(
self, kv_connector_metadata: SimpleCPUOffloadMetadata
) -> None:
"""Sync all in-flight transfers before preempted blocks are reused."""
if not kv_connector_metadata.need_flush:
return
self._flush_and_sync_all()
def _flush_and_sync_all(self) -> None:
"""Synchronize all in-flight transfer events."""
for event_idx, event in self._load_events:
event.synchronize()
self._load_hwm = event_idx
self._load_events.clear()
for event_idx, event in self._store_events:
event.synchronize()
self._store_hwm = event_idx
self._store_events.clear()
def _poll_stream_events(self, is_store: bool) -> int:
"""Non-blocking poll for completed events and return the high-water mark."""
events = self._store_events if is_store else self._load_events
hwm = self._store_hwm if is_store else self._load_hwm
while events:
event_idx, event = events[0]
if not event.query():
break
hwm = event_idx
events.pop(0)
if is_store:
self._store_hwm = hwm
else:
self._load_hwm = hwm
return hwm