diff --git a/tests/v1/simple_kv_offload/__init__.py b/tests/v1/simple_kv_offload/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/v1/simple_kv_offload/test_integration.py b/tests/v1/simple_kv_offload/test_integration.py new file mode 100644 index 000000000..29399516b --- /dev/null +++ b/tests/v1/simple_kv_offload/test_integration.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Integration tests for SimpleCPUOffloadConnector with real models.""" + +import time + +import pytest + +from vllm import LLM, SamplingParams, TokensPrompt +from vllm.config import KVTransferConfig +from vllm.platforms import current_platform + +if not current_platform.is_cuda(): + pytest.skip("Requires CUDA", allow_module_level=True) + +# Small models for default CI / local runs (accuracy only). +SMALL_MODELS = [ + "meta-llama/Llama-3.2-1B-Instruct", + "google/gemma-3-1b-it", +] + +# Large models for optional perf runs only (slow to load and execute). +PERF_MODELS = [ + "meta-llama/Llama-3.1-8B", + "openai/gpt-oss-20b", +] + + +def _make_llm(model: str, lazy: bool, cpu_bytes_to_use: int) -> LLM: + kv_transfer_config = KVTransferConfig( + kv_connector="SimpleCPUOffloadConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "cpu_bytes_to_use": cpu_bytes_to_use, + "lazy_offload": lazy, + }, + ) + return LLM( + model=model, + kv_cache_memory_bytes=40 << 30, # 40 GiB + disable_hybrid_kv_cache_manager=False, + enable_prefix_caching=True, + kv_transfer_config=kv_transfer_config, + ) + + +def _flush_gpu_cache(llm: LLM, sampling_params: SamplingParams, seed: int = 0): + """Generate enough filler requests to allocate the entire GPU KV cache. + + This pushes all prior blocks through the free queue so that the lazy + cursor offloads them to CPU before they are evicted. 
+
+    """
+    cache_config = llm.llm_engine.vllm_config.cache_config
+    num_gpu_blocks = cache_config.num_gpu_blocks
+    block_size = cache_config.block_size
+    # Use 1.5x GPU capacity to give the lazy cursor enough scheduling steps
+    # to walk past all target blocks near the tail of the free queue.
+    total_tokens_needed = int(num_gpu_blocks * block_size * 1.5)
+
+    # Use token-id prompts so each filler is unique (no prefix sharing).
+    # Split into multiple requests to stay under max_model_len.
+    max_tokens_per_req = 4096
+    num_fillers = (total_tokens_needed + max_tokens_per_req - 1) // max_tokens_per_req
+    batch_size = 10
+    for i in range(0, num_fillers, batch_size):
+        batch_end = min(i + batch_size, num_fillers)
+        filler_prompts = []
+        for j in range(i, batch_end):
+            ids = [seed * num_fillers + j + 1] * max_tokens_per_req
+            filler_prompts.append(TokensPrompt(prompt_token_ids=ids))
+        llm.generate(filler_prompts, sampling_params, use_tqdm=False)
+
+
+def _accuracy_test(llm: LLM, lazy: bool = False):
+    """Verify that CPU-loaded KV produces correct output."""
+    sampling_params = SamplingParams(max_tokens=1, temperature=0)
+    prompt = "hi " * 2000 + "Let's count to ten. 
One, two, three, " + + # Cold run — populate GPU cache and trigger CPU offload + cold_output = llm.generate(prompt, sampling_params, use_tqdm=False)[0] + + # CPU hit runs + test_count = 10 + success_count = 0 + expected = cold_output.outputs[0].text + for i in range(test_count): + if lazy: + _flush_gpu_cache(llm, sampling_params, seed=i) + time.sleep(2) # let engine core drain pending transfers + + # Reset GPU prefix cache so next run must load from CPU + if not llm.reset_prefix_cache(): + print(f"GPU prefix cache reset failed for iteration {i}") + + output = llm.generate(prompt, sampling_params, use_tqdm=False)[0] + if output.outputs[0].text == expected: + success_count += 1 + + assert success_count >= 0.5 * test_count, ( + f"Accuracy too low: {success_count}/{test_count} matched '{expected}'" + ) + + +def _latency_test(llm: LLM, lazy: bool = False): + """Verify CPU cache hit is faster than cold compute.""" + sampling_params = SamplingParams(max_tokens=1, seed=42) + prompt_token_ids = [0] * 10001 + + num_times_cpu_better = 0 + num_tests = 10 + for i in range(num_tests): + prompt_token_ids[0] = i + prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)] + + # Cold + time.sleep(2) # let engine core drain pending transfers + if not llm.reset_prefix_cache(): + print(f"GPU prefix cache reset failed for iteration {i}") + start = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + cold_time = time.time() - start + + if lazy: + _flush_gpu_cache(llm, sampling_params, seed=i) + else: + # Eager mode: GPU hit ensures store completion is processed. 
+ llm.generate(prompts, sampling_params, use_tqdm=False) + + time.sleep(2) # let engine core drain pending transfers + if not llm.reset_prefix_cache(): + print(f"GPU prefix cache reset failed for iteration {i}") + + # CPU hit + start = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + cpu_time = time.time() - start + + if cpu_time < cold_time: + num_times_cpu_better += 1 + + assert num_times_cpu_better >= 0.8 * num_tests, ( + f"CPU hit only faster {num_times_cpu_better}/{num_tests} times" + ) + + +@pytest.mark.optional +@pytest.mark.slow_test +@pytest.mark.parametrize("model", SMALL_MODELS) +def test_simple_cpu_offload_accuracy(model: str): + """Store to CPU, reset GPU, load from CPU; verify output matches baseline.""" + llm = _make_llm(model, False, 1 << 30) # 1GB + try: + _accuracy_test(llm, lazy=False) + finally: + del llm + + +@pytest.mark.optional +@pytest.mark.slow_test +@pytest.mark.parametrize("model", PERF_MODELS) +def test_simple_cpu_offload_perf_latency(model: str): + """CPU KV hit should beat cold prefill on long context (large models only).""" + llm = _make_llm(model, False, 10 << 30) # 10GB + try: + _latency_test(llm, lazy=False) + finally: + del llm + + +@pytest.mark.optional +@pytest.mark.slow_test +@pytest.mark.parametrize("model", SMALL_MODELS) +def test_simple_cpu_offload_accuracy_lazy(model: str): + """Lazy mode: flush GPU cache to trigger CPU offload, then verify hit.""" + # CPU must be larger than GPU KV cache to avoid evicting offloaded blocks. + llm = _make_llm(model, True, 80 << 30) # 80GB + try: + _accuracy_test(llm, lazy=True) + finally: + del llm + + +@pytest.mark.optional +@pytest.mark.slow_test +@pytest.mark.parametrize("model", PERF_MODELS) +def test_simple_cpu_offload_perf_latency_lazy(model: str): + """Lazy mode: CPU KV hit should beat cold prefill (large models only).""" + # CPU must be larger than GPU KV cache to avoid evicting offloaded blocks. 
+ llm = _make_llm(model, True, 80 << 30) # 80GB + try: + _latency_test(llm, lazy=True) + finally: + del llm diff --git a/tests/v1/simple_kv_offload/test_scheduler.py b/tests/v1/simple_kv_offload/test_scheduler.py new file mode 100644 index 000000000..132f52fe3 --- /dev/null +++ b/tests/v1/simple_kv_offload/test_scheduler.py @@ -0,0 +1,1137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for SimpleCPUOffloadScheduler.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from vllm import SamplingParams +from vllm.config import ( + CacheConfig, + DeviceConfig, + KVTransferConfig, + ModelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.utils.hashing import sha256 +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import ( + get_request_block_hasher, + init_none_hash, + make_block_hash_with_group_id, +) +from vllm.v1.core.sched.output import ( + CachedRequestData, + NewRequestData, + SchedulerOutput, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheTensor, +) +from vllm.v1.outputs import KVConnectorOutput +from vllm.v1.request import Request +from vllm.v1.simple_kv_offload.manager import SimpleCPUOffloadScheduler +from vllm.v1.simple_kv_offload.metadata import SimpleCPUOffloadWorkerMetadata + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +BLOCK_SIZE = 16 +HEAD_SIZE = 16 +NUM_KV_HEADS = 1 +DTYPE = torch.float16 +# bytes per block per tensor: +# block_size * num_kv_heads * head_size * 2 (K+V) * element_size +_BYTES_PER_BLOCK = BLOCK_SIZE * NUM_KV_HEADS * HEAD_SIZE * 2 * DTYPE.itemsize + +# Ensure none_hash is initialized once +init_none_hash(sha256) + + +# 
--------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_kv_cache_config( + num_blocks: int, + num_groups: int = 1, +) -> KVCacheConfig: + """Build a KVCacheConfig with non-empty kv_cache_tensors.""" + groups = [] + tensors = [] + for g in range(num_groups): + layer_names = [f"layer_{g}"] + groups.append( + KVCacheGroupSpec( + layer_names, + FullAttentionSpec( + block_size=BLOCK_SIZE, + num_kv_heads=NUM_KV_HEADS, + head_size=HEAD_SIZE, + dtype=DTYPE, + ), + ) + ) + tensors.append( + KVCacheTensor( + size=_BYTES_PER_BLOCK * num_blocks, + shared_by=layer_names, + ) + ) + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=tensors, + kv_cache_groups=groups, + ) + + +def _make_vllm_config(block_size: int = BLOCK_SIZE) -> VllmConfig: + """Minimal VllmConfig for scheduler tests (no GPU).""" + model_config = ModelConfig( + model="facebook/opt-125m", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + scheduler_config = SchedulerConfig( + max_num_seqs=16, + max_num_batched_tokens=64, + max_model_len=10000, + enable_chunked_prefill=True, + is_encoder_decoder=False, + ) + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + enable_prefix_caching=True, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="SimpleCPUOffloadConnector", + kv_role="kv_both", + ) + return VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu"), + ) + + +@dataclass +class SchedulerFixture: + """Bundle returned by make_scheduler for convenient access.""" + + scheduler: SimpleCPUOffloadScheduler + gpu_block_pool: BlockPool + vllm_config: VllmConfig + kv_cache_config: KVCacheConfig + num_groups: int = 1 + + +def make_scheduler( + num_cpu_blocks: int = 8, + num_gpu_blocks: int = 16, + 
num_groups: int = 1, + lazy: bool = False, +) -> SchedulerFixture: + """Build a SimpleCPUOffloadScheduler with small block pools.""" + kv_cache_config = _make_kv_cache_config(num_gpu_blocks, num_groups) + vllm_config = _make_vllm_config() + cpu_capacity_bytes = _BYTES_PER_BLOCK * num_cpu_blocks * num_groups + + sched = SimpleCPUOffloadScheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + cpu_capacity_bytes=cpu_capacity_bytes, + lazy_offload=lazy, + ) + + # Build a real GPU block pool and bind it + gpu_block_pool = BlockPool( + num_gpu_blocks=num_gpu_blocks, + enable_caching=True, + hash_block_size=BLOCK_SIZE, + ) + sched.bind_gpu_block_pool(gpu_block_pool) + + return SchedulerFixture( + scheduler=sched, + gpu_block_pool=gpu_block_pool, + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + num_groups=num_groups, + ) + + +_req_counter = 0 + + +def make_request( + num_blocks: int = 2, + request_id: str | None = None, +) -> Request: + """Create a Request with deterministic block hashes.""" + global _req_counter + _req_counter += 1 + if request_id is None: + request_id = f"req-{_req_counter}" + + num_tokens = num_blocks * BLOCK_SIZE + start = _req_counter * 10000 + prompt_token_ids = list(range(start, start + num_tokens)) + sampling_params = SamplingParams(max_tokens=1) + + req = Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + pooling_params=None, + mm_features=None, + block_hasher=get_request_block_hasher(BLOCK_SIZE, sha256), + ) + return req + + +def make_scheduler_output( + req_id_to_num_tokens: dict[str, int], + *, + new_reqs: dict[str, tuple[list[int], ...]] | None = None, + cached_req_new_blocks: dict[str, tuple[list[int], ...] | None] | None = None, +) -> SchedulerOutput: + """Build a minimal SchedulerOutput with num_scheduled_tokens. + + Args: + new_reqs: For first-time requests, maps req_id -> block_ids tuple. 
+ These are placed into ``scheduled_new_reqs`` as ``NewRequestData``. + cached_req_new_blocks: For returning (cached) requests, maps + req_id -> new_block_ids (incremental) or None. + These are placed into ``scheduled_cached_reqs``. + """ + scheduled_new_reqs: list[NewRequestData] = [] + if new_reqs: + for req_id, block_ids in new_reqs.items(): + scheduled_new_reqs.append( + NewRequestData( + req_id=req_id, + prompt_token_ids=None, + mm_features=[], + sampling_params=None, + pooling_params=None, + block_ids=block_ids, + num_computed_tokens=0, + lora_request=None, + ) + ) + + if cached_req_new_blocks: + cached_req_ids = list(cached_req_new_blocks.keys()) + cached_new_block_ids = [cached_req_new_blocks[rid] for rid in cached_req_ids] + cached_reqs = CachedRequestData( + req_ids=cached_req_ids, + resumed_req_ids=set(), + new_token_ids=[[] for _ in cached_req_ids], + all_token_ids={}, + new_block_ids=cached_new_block_ids, + num_computed_tokens=[0] * len(cached_req_ids), + num_output_tokens=[0] * len(cached_req_ids), + ) + else: + cached_reqs = CachedRequestData.make_empty() + + return SchedulerOutput( + scheduled_new_reqs=scheduled_new_reqs, + scheduled_cached_reqs=cached_reqs, + num_scheduled_tokens=req_id_to_num_tokens, + total_num_scheduled_tokens=sum(req_id_to_num_tokens.values()), + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[], + preempted_req_ids=set(), + finished_req_ids=set(), + free_encoder_mm_hashes=[], + ) + + +def simulate_store_completion( + scheduler: SimpleCPUOffloadScheduler, + event_idx: int, +) -> None: + """Simulate worker reporting a store event completion.""" + output = KVConnectorOutput( + finished_recving=set(), + kv_connector_worker_meta=SimpleCPUOffloadWorkerMetadata( + completed_store_events={event_idx: scheduler._expected_worker_count}, + ), + ) + scheduler.update_connector_output(output) + + +def simulate_load_completion( + scheduler: SimpleCPUOffloadScheduler, + req_ids: set[str], +) -> None: 
+ """Simulate worker reporting load completions for requests.""" + output = KVConnectorOutput( + finished_sending=set(), + finished_recving=req_ids, + ) + scheduler.update_connector_output(output) + + +def get_cpu_free_blocks(scheduler: SimpleCPUOffloadScheduler) -> int: + """Return number of free CPU blocks.""" + return scheduler.cpu_block_pool.get_num_free_blocks() + + +def _allocate_gpu_blocks( + gpu_block_pool: BlockPool, + request: Request, + num_blocks: int, + group_id: int = 0, +) -> list: + """Allocate GPU blocks, cache them with hashes, return block list. + + Mimics what KVCacheManager does: allocate blocks from pool, then + register them in the prefix cache via cache_full_blocks so that + re-allocation properly evicts stale hashes. + """ + blocks = gpu_block_pool.get_new_blocks(num_blocks) + num_full = min(num_blocks, len(request.block_hashes)) + if num_full > 0: + gpu_block_pool.cache_full_blocks( + request=request, + blocks=blocks, + num_cached_blocks=0, + num_full_blocks=num_full, + block_size=BLOCK_SIZE, + kv_cache_group_id=group_id, + ) + return blocks + + +def _alloc_and_register( + fix: SchedulerFixture, + request: Request, + num_blocks: int, + *, + confirmed: bool = True, + group_id: int = 0, +) -> KVCacheBlocks: + """Allocate GPU blocks and return KVCacheBlocks. + + Block IDs are no longer registered in a mock KVCacheManager; instead + tests pass them through ``make_scheduler_output`` so that + ``yield_req_data`` can pick them up. + + If ``confirmed`` is True, advance ``request.num_computed_tokens`` to simulate + the scheduler's ``_update_after_schedule`` from a prior step. 
+ """ + gpu_blocks = _allocate_gpu_blocks( + fix.gpu_block_pool, request, num_blocks, group_id=group_id + ) + kv_blocks = KVCacheBlocks(blocks=(gpu_blocks,)) + if confirmed: + request.num_computed_tokens = num_blocks * BLOCK_SIZE + return kv_blocks + + +# --------------------------------------------------------------------------- +# Test 1a: Eager store-and-load roundtrip +# --------------------------------------------------------------------------- +def test_eager_store_and_load_roundtrip() -> None: + """Eager mode: store blocks on compute, complete store, verify cache hit.""" + fix = make_scheduler(num_cpu_blocks=8, num_gpu_blocks=16, lazy=False) + sched = fix.scheduler + + num_blocks = 2 + req = make_request(num_blocks=num_blocks) + + kv_blocks = _alloc_and_register(fix, req, num_blocks) + sched.update_state_after_alloc(req, kv_blocks, num_external_tokens=0) + block_ids = kv_blocks.get_block_ids() + sched_out = make_scheduler_output( + {req.request_id: num_blocks * BLOCK_SIZE}, + new_reqs={req.request_id: block_ids}, + ) + + meta = sched.build_connector_meta(sched_out) + assert meta.store_event >= 0, "Expected a store event to be scheduled" + assert len(meta.store_gpu_blocks) > 0 + assert len(meta.store_cpu_blocks) == len(meta.store_gpu_blocks) + simulate_store_completion(sched, meta.store_event) + + # New request with same tokens should get CPU cache hit + req2 = Request( + request_id="req-eager-load", + prompt_token_ids=req.prompt_token_ids, + sampling_params=req.sampling_params, + pooling_params=None, + mm_features=None, + block_hasher=req._block_hasher, + ) + hit_tokens, is_async = sched.get_num_new_matched_tokens(req2, num_computed_tokens=0) + assert hit_tokens == num_blocks * BLOCK_SIZE + assert is_async is True + + gpu_blocks2 = fix.gpu_block_pool.get_new_blocks(num_blocks) + kv_blocks2 = KVCacheBlocks(blocks=(gpu_blocks2,)) + sched.update_state_after_alloc(req2, kv_blocks2, num_external_tokens=hit_tokens) + + block_ids2 = kv_blocks2.get_block_ids() + 
sched_out2 = make_scheduler_output( + {req2.request_id: 1}, + new_reqs={req2.request_id: block_ids2}, + ) + meta2 = sched.build_connector_meta(sched_out2) + assert meta2.load_event >= 0, "Expected a load event to be assigned" + assert len(meta2.load_gpu_blocks) > 0 + assert len(meta2.load_cpu_blocks) == len(meta2.load_gpu_blocks) + + +# --------------------------------------------------------------------------- +# Test 1b: Lazy store-and-load roundtrip +# --------------------------------------------------------------------------- +def _flush_old_blocks_to_lru_head( + gpu_pool: BlockPool, + num_filler_blocks: int, +) -> list: + """Allocate filler blocks so that previously-freed (hashed) blocks migrate + to the LRU head of the free queue. Returns the filler blocks (caller must + free them later to restore pool capacity). + + In a real engine the same thing happens naturally: after one request + finishes and frees its blocks, subsequent requests allocate from the LRU + head, consuming the unhashed blocks and leaving the old hashed blocks at + the front of the queue. + """ + fillers = gpu_pool.get_new_blocks(num_filler_blocks) + return fillers + + +def test_lazy_store_and_load_roundtrip() -> None: + """Lazy mode: schedule a request, finish it so its hashed blocks are freed, + then schedule new requests so the old blocks migrate to the LRU head. + The lazy scanner offloads them to CPU. Re-scheduling the old request + triggers a CPU cache hit + load. + + GPU pool: 8 blocks (7 usable). _target_free = ceil(64/16) = 4. + """ + fix = make_scheduler(num_cpu_blocks=8, num_gpu_blocks=8, lazy=True) + sched = fix.scheduler + gpu_pool = fix.gpu_block_pool + + num_blocks = 2 + + # --- Step 1: Schedule req_old, compute, and finish --- + req_old = make_request(num_blocks=num_blocks) + gpu_blocks_old = _allocate_gpu_blocks(gpu_pool, req_old, num_blocks, group_id=0) + gpu_pool.free_blocks(gpu_blocks_old) + + # Allocate filler blocks so req_old's hashed blocks move to LRU head. 
+ # 7 usable - 2 (req_old freed) = 5 other free blocks to consume. + fillers = _flush_old_blocks_to_lru_head(gpu_pool, num_filler_blocks=5) + + # --- Step 2: Lazy scanner should offload req_old's blocks --- + sched_out = make_scheduler_output({}) + meta = sched.build_connector_meta(sched_out) + assert meta.store_event >= 0, "Expected lazy store to offload old blocks" + assert len(meta.store_gpu_blocks) == num_blocks + simulate_store_completion(sched, meta.store_event) + + # Free fillers to restore pool capacity. + gpu_pool.free_blocks(fillers) + + # --- Step 3: Re-schedule req_old — should get CPU cache hit --- + req_old2 = Request( + request_id="req-old-reload", + prompt_token_ids=req_old.prompt_token_ids, + sampling_params=req_old.sampling_params, + pooling_params=None, + mm_features=None, + block_hasher=req_old._block_hasher, + ) + hit_tokens, is_async = sched.get_num_new_matched_tokens( + req_old2, num_computed_tokens=0 + ) + assert hit_tokens == num_blocks * BLOCK_SIZE, ( + f"Expected {num_blocks * BLOCK_SIZE} hit tokens, got {hit_tokens}" + ) + assert is_async is True + + # Allocate fresh GPU blocks for the load. 
+ gpu_blocks_load = gpu_pool.get_new_blocks(num_blocks) + kv_blocks_load = KVCacheBlocks(blocks=(gpu_blocks_load,)) + sched.update_state_after_alloc( + req_old2, kv_blocks_load, num_external_tokens=hit_tokens + ) + + sched_out2 = make_scheduler_output({req_old2.request_id: 1}) + meta2 = sched.build_connector_meta(sched_out2) + assert meta2.load_event >= 0, "Expected a load event to be assigned" + assert len(meta2.load_gpu_blocks) > 0 + + +# --------------------------------------------------------------------------- +# Test 2a: Eager duplicate store is skipped +# --------------------------------------------------------------------------- +def test_eager_duplicate_store_skipped() -> None: + """Eager: storing the same block hashes twice should not allocate new CPU blocks.""" + fix = make_scheduler(num_cpu_blocks=8, num_gpu_blocks=16, lazy=False) + sched = fix.scheduler + + num_blocks = 2 + req = make_request(num_blocks=num_blocks) + + kv_blocks = _alloc_and_register(fix, req, num_blocks) + sched.update_state_after_alloc(req, kv_blocks, num_external_tokens=0) + block_ids = kv_blocks.get_block_ids() + sched_out = make_scheduler_output( + {req.request_id: num_blocks * BLOCK_SIZE}, + new_reqs={req.request_id: block_ids}, + ) + + meta1 = sched.build_connector_meta(sched_out) + assert meta1.store_event >= 0 + simulate_store_completion(sched, meta1.store_event) + cpu_free_after_first = get_cpu_free_blocks(sched) + + # Second request with identical hashes — should skip store + req2 = Request( + request_id="req-dup-eager", + prompt_token_ids=req.prompt_token_ids, + sampling_params=req.sampling_params, + pooling_params=None, + mm_features=None, + block_hasher=req._block_hasher, + ) + kv_blocks2 = _alloc_and_register(fix, req2, num_blocks) + sched.update_state_after_alloc(req2, kv_blocks2, num_external_tokens=0) + block_ids2 = kv_blocks2.get_block_ids() + sched_out2 = make_scheduler_output( + {req2.request_id: num_blocks * BLOCK_SIZE}, + new_reqs={req2.request_id: block_ids2}, + 
) + + meta2 = sched.build_connector_meta(sched_out2) + if meta2.store_event >= 0: + assert len(meta2.store_cpu_blocks) == 0, ( + "Expected no new CPU blocks for duplicate hashes" + ) + assert get_cpu_free_blocks(sched) == cpu_free_after_first + + +# --------------------------------------------------------------------------- +# Test 2b: Lazy duplicate store is skipped +# --------------------------------------------------------------------------- +def test_lazy_duplicate_store_skipped() -> None: + """Lazy: blocks already offloaded to CPU should not be offloaded again. + + Same pattern as the lazy roundtrip: flush old blocks to LRU head, offload, + then repeat with the same hashes and verify no new CPU allocation. + """ + fix = make_scheduler(num_cpu_blocks=8, num_gpu_blocks=8, lazy=True) + sched = fix.scheduler + gpu_pool = fix.gpu_block_pool + + num_blocks = 2 + req = make_request(num_blocks=num_blocks) + + # Schedule + finish → hashed blocks in free queue + gpu_blocks = _allocate_gpu_blocks(gpu_pool, req, num_blocks, group_id=0) + gpu_pool.free_blocks(gpu_blocks) + + # Flush old blocks to LRU head, then trigger lazy offload. + fillers = _flush_old_blocks_to_lru_head(gpu_pool, num_filler_blocks=5) + meta1 = sched.build_connector_meta(make_scheduler_output({})) + assert meta1.store_event >= 0 + simulate_store_completion(sched, meta1.store_event) + gpu_pool.free_blocks(fillers) + cpu_free_after_first = get_cpu_free_blocks(sched) + + # Allocate blocks with the same hashes and free them again. + # The scanner should see they are already in CPU cache and skip them. + req2 = Request( + request_id="req-dup-lazy", + prompt_token_ids=req.prompt_token_ids, + sampling_params=req.sampling_params, + pooling_params=None, + mm_features=None, + block_hasher=req._block_hasher, + ) + gpu_blocks2 = _allocate_gpu_blocks(gpu_pool, req2, num_blocks, group_id=0) + gpu_pool.free_blocks(gpu_blocks2) + + # Flush again so the hashed blocks are at LRU head for the scanner. 
+ fillers2 = _flush_old_blocks_to_lru_head(gpu_pool, num_filler_blocks=5) + meta2 = sched.build_connector_meta(make_scheduler_output({})) + gpu_pool.free_blocks(fillers2) + + # Either no store event, or zero new CPU blocks (already cached). + if meta2.store_event >= 0: + assert len(meta2.store_cpu_blocks) == 0, ( + "Expected no new CPU blocks for duplicate hashes" + ) + assert get_cpu_free_blocks(sched) == cpu_free_after_first + + +# --------------------------------------------------------------------------- +# Test 3: LRU eviction order +# --------------------------------------------------------------------------- +def test_lru_eviction_order() -> None: + """With limited CPU space, oldest blocks should be evicted first. + + CPU block pool: num_cpu_blocks=5 -> 4 free usable blocks (1 taken by null_block). + After storing 4 blocks (2 req_a + 2 req_b), all free slots are occupied by + cached blocks (ref_cnt=0, in hash map). When 2 more are stored (req_c), + 2 LRU blocks from req_a get evicted from the cache to make room. 
+ """ + # 5 total = 4 usable (null_block takes 1), filling exactly with 4 blocks + fix = make_scheduler(num_cpu_blocks=5, num_gpu_blocks=16, lazy=False) + sched = fix.scheduler + + # Fill CPU with 4 blocks: 2 requests x 2 blocks (in LRU insertion order) + req_a = make_request(num_blocks=2) + req_b = make_request(num_blocks=2) + + kv_a = _alloc_and_register(fix, req_a, 2) + kv_b = _alloc_and_register(fix, req_b, 2) + sched.update_state_after_alloc(req_a, kv_a, num_external_tokens=0) + sched.update_state_after_alloc(req_b, kv_b, num_external_tokens=0) + + ids_a = kv_a.get_block_ids() + ids_b = kv_b.get_block_ids() + sched_out = make_scheduler_output( + { + req_a.request_id: 2 * BLOCK_SIZE, + req_b.request_id: 2 * BLOCK_SIZE, + }, + new_reqs={ + req_a.request_id: ids_a, + req_b.request_id: ids_b, + }, + ) + meta = sched.build_connector_meta(sched_out) + assert meta.store_event >= 0 + simulate_store_completion(sched, meta.store_event) + + # Verify all 4 blocks are cached in CPU hash map + for i, bhash in enumerate(req_a.block_hashes[:2]): + bhash_with_group = make_block_hash_with_group_id(bhash, 0) + assert ( + sched.cpu_block_pool.cached_block_hash_to_block.get_one_block( + bhash_with_group + ) + is not None + ), f"req_a block {i} should be cached after store" + for i, bhash in enumerate(req_b.block_hashes[:2]): + bhash_with_group = make_block_hash_with_group_id(bhash, 0) + assert ( + sched.cpu_block_pool.cached_block_hash_to_block.get_one_block( + bhash_with_group + ) + is not None + ), f"req_b block {i} should be cached after store" + + # Store 2 more blocks from a new request - must evict 2 LRU blocks (req_a) + req_c = make_request(num_blocks=2) + kv_c = _alloc_and_register(fix, req_c, 2) + sched.update_state_after_alloc(req_c, kv_c, num_external_tokens=0) + + ids_c = kv_c.get_block_ids() + sched_out2 = make_scheduler_output( + {req_c.request_id: 2 * BLOCK_SIZE}, + new_reqs={req_c.request_id: ids_c}, + ) + meta2 = sched.build_connector_meta(sched_out2) + assert 
meta2.store_event >= 0 + simulate_store_completion(sched, meta2.store_event) + + # req_a hashes should be evicted from CPU (they were LRU) + for i, bhash in enumerate(req_a.block_hashes[:2]): + bhash_with_group = make_block_hash_with_group_id(bhash, 0) + cache_map = sched.cpu_block_pool.cached_block_hash_to_block + cached = cache_map.get_one_block(bhash_with_group) + assert cached is None, f"req_a block {i} should have been evicted" + + # req_b and req_c hashes should be present + for i, bhash in enumerate(req_b.block_hashes[:2]): + bhash_with_group = make_block_hash_with_group_id(bhash, 0) + cache_map = sched.cpu_block_pool.cached_block_hash_to_block + cached = cache_map.get_one_block(bhash_with_group) + assert cached is not None, f"req_b block {i} should still be cached" + + for i, bhash in enumerate(req_c.block_hashes[:2]): + bhash_with_group = make_block_hash_with_group_id(bhash, 0) + cache_map = sched.cpu_block_pool.cached_block_hash_to_block + cached = cache_map.get_one_block(bhash_with_group) + assert cached is not None, f"req_c block {i} should still be cached" + + +# --------------------------------------------------------------------------- +# Test 4: Touched blocks survive eviction +# --------------------------------------------------------------------------- +def test_touched_blocks_survive_eviction() -> None: + """Touching CPU blocks updates their LRU position, protecting them from eviction.""" + # 5 total = 4 usable (null_block takes 1) + fix = make_scheduler(num_cpu_blocks=5, num_gpu_blocks=16, lazy=False) + sched = fix.scheduler + + # Fill CPU with 4 blocks (req_a: 2, req_b: 2) in LRU order + req_a = make_request(num_blocks=2) + req_b = make_request(num_blocks=2) + + kv_a = _alloc_and_register(fix, req_a, 2) + kv_b = _alloc_and_register(fix, req_b, 2) + sched.update_state_after_alloc(req_a, kv_a, num_external_tokens=0) + sched.update_state_after_alloc(req_b, kv_b, num_external_tokens=0) + + ids_a = kv_a.get_block_ids() + ids_b = kv_b.get_block_ids() 
+ sched_out = make_scheduler_output( + { + req_a.request_id: 2 * BLOCK_SIZE, + req_b.request_id: 2 * BLOCK_SIZE, + }, + new_reqs={ + req_a.request_id: ids_a, + req_b.request_id: ids_b, + }, + ) + meta = sched.build_connector_meta(sched_out) + simulate_store_completion(sched, meta.store_event) + + # Touch req_a's CPU blocks to make them most-recently-used + cpu_pool = sched.cpu_block_pool + for bhash in req_a.block_hashes[:2]: + bhash_with_group = make_block_hash_with_group_id(bhash, 0) + cached_blk = cpu_pool.cached_block_hash_to_block.get_one_block(bhash_with_group) + assert cached_blk is not None + cpu_pool.touch([cached_blk]) + # Undo touch to return ref_cnt to 0 + # (so it's a free candidate but at MRU position) + cpu_pool.free_blocks([cached_blk]) + + # Now store 2 more blocks; req_b (LRU front) should be evicted, not req_a + req_c = make_request(num_blocks=2) + kv_c = _alloc_and_register(fix, req_c, 2) + sched.update_state_after_alloc(req_c, kv_c, num_external_tokens=0) + + ids_c = kv_c.get_block_ids() + sched_out2 = make_scheduler_output( + {req_c.request_id: 2 * BLOCK_SIZE}, + new_reqs={req_c.request_id: ids_c}, + ) + meta2 = sched.build_connector_meta(sched_out2) + simulate_store_completion(sched, meta2.store_event) + + # req_b should be evicted (LRU), req_a and req_c should survive + for i, bhash in enumerate(req_b.block_hashes[:2]): + bhash_with_group = make_block_hash_with_group_id(bhash, 0) + cached = cpu_pool.cached_block_hash_to_block.get_one_block(bhash_with_group) + assert cached is None, f"req_b block {i} should have been evicted (it was LRU)" + + for i, bhash in enumerate(req_a.block_hashes[:2]): + bhash_with_group = make_block_hash_with_group_id(bhash, 0) + cached = cpu_pool.cached_block_hash_to_block.get_one_block(bhash_with_group) + assert cached is not None, f"req_a block {i} should survive (was touched/MRU)" + + +# --------------------------------------------------------------------------- +# Test 5: Preemption no CPU block leak +# 
--------------------------------------------------------------------------- +def test_preemption_no_cpu_block_leak() -> None: + """request_finished during in-flight load defers cleanup; + completes after load done.""" + fix = make_scheduler(num_cpu_blocks=8, num_gpu_blocks=16, lazy=False) + sched = fix.scheduler + + num_blocks = 2 + + # First: store blocks to CPU + req = make_request(num_blocks=num_blocks) + kv_blocks = _alloc_and_register(fix, req, num_blocks) + sched.update_state_after_alloc(req, kv_blocks, num_external_tokens=0) + block_ids = kv_blocks.get_block_ids() + sched_out = make_scheduler_output( + {req.request_id: num_blocks * BLOCK_SIZE}, + new_reqs={req.request_id: block_ids}, + ) + meta = sched.build_connector_meta(sched_out) + simulate_store_completion(sched, meta.store_event) + + # Create new request with same tokens, check hit + req2 = Request( + request_id="req-preempt-load", + prompt_token_ids=req.prompt_token_ids, + sampling_params=req.sampling_params, + pooling_params=None, + mm_features=None, + block_hasher=req._block_hasher, + ) + hit_tokens, is_async = sched.get_num_new_matched_tokens(req2, num_computed_tokens=0) + assert hit_tokens > 0 + + gpu_blocks2 = fix.gpu_block_pool.get_new_blocks(num_blocks) + kv_blocks2 = KVCacheBlocks(blocks=(gpu_blocks2,)) + sched.update_state_after_alloc(req2, kv_blocks2, num_external_tokens=hit_tokens) + + # Assign load_event via build_connector_meta + block_ids2 = kv_blocks2.get_block_ids() + sched_out2 = make_scheduler_output( + {req2.request_id: 1}, + new_reqs={req2.request_id: block_ids2}, + ) + meta2 = sched.build_connector_meta(sched_out2) + assert meta2.load_event >= 0 + + # Request finishes BEFORE load completes -> deferred + sched.request_finished(req2, block_ids=[]) + assert req2.request_id in sched._reqs_to_load + assert sched._reqs_to_load[req2.request_id].finished is True + + # Now simulate load completion -> cleanup fires + simulate_load_completion(sched, {req2.request_id}) + assert 
req2.request_id not in sched._reqs_to_load + + +# --------------------------------------------------------------------------- +# Test 6: Eager store preemption cleanup +# --------------------------------------------------------------------------- +def test_eager_store_preemption_cleanup() -> None: + """In eager mode, finishing a request during in-flight store defers cleanup.""" + fix = make_scheduler(num_cpu_blocks=8, num_gpu_blocks=16, lazy=False) + sched = fix.scheduler + + num_blocks = 2 + req = make_request(num_blocks=num_blocks) + kv_blocks = _alloc_and_register(fix, req, num_blocks) + sched.update_state_after_alloc(req, kv_blocks, num_external_tokens=0) + + block_ids = kv_blocks.get_block_ids() + sched_out = make_scheduler_output( + {req.request_id: num_blocks * BLOCK_SIZE}, + new_reqs={req.request_id: block_ids}, + ) + meta = sched.build_connector_meta(sched_out) + store_event = meta.store_event + assert store_event >= 0 + + # The request gets store_events populated + assert req.request_id in sched._reqs_to_store + store_state = sched._reqs_to_store[req.request_id] + assert store_event in store_state.store_events + + # Finish request while store still in-flight -> deferred + sched.request_finished(req, block_ids=[]) + assert req.request_id in sched._reqs_to_store + assert sched._reqs_to_store[req.request_id].finished is True + + # Simulate store completion -> deferred cleanup fires + simulate_store_completion(sched, store_event) + assert req.request_id not in sched._reqs_to_store + + +# --------------------------------------------------------------------------- +# Test 7: In-flight finish deferred cleanup (load variant) +# --------------------------------------------------------------------------- +def test_inflight_finish_deferred_cleanup() -> None: + """Store, then start a load, request_finished defers, + load completion fires cleanup.""" + fix = make_scheduler(num_cpu_blocks=8, num_gpu_blocks=16, lazy=False) + sched = fix.scheduler + + num_blocks = 2 + + 
# Store + req = make_request(num_blocks=num_blocks) + kv_blocks = _alloc_and_register(fix, req, num_blocks) + sched.update_state_after_alloc(req, kv_blocks, num_external_tokens=0) + block_ids = kv_blocks.get_block_ids() + sched_out = make_scheduler_output( + {req.request_id: num_blocks * BLOCK_SIZE}, + new_reqs={req.request_id: block_ids}, + ) + meta = sched.build_connector_meta(sched_out) + simulate_store_completion(sched, meta.store_event) + + # Load + req2 = Request( + request_id="req-inflight-load", + prompt_token_ids=req.prompt_token_ids, + sampling_params=req.sampling_params, + pooling_params=None, + mm_features=None, + block_hasher=req._block_hasher, + ) + hit_tokens, _ = sched.get_num_new_matched_tokens(req2, num_computed_tokens=0) + assert hit_tokens > 0 + + gpu_blocks2 = fix.gpu_block_pool.get_new_blocks(num_blocks) + kv_blocks2 = KVCacheBlocks(blocks=(gpu_blocks2,)) + sched.update_state_after_alloc(req2, kv_blocks2, num_external_tokens=hit_tokens) + + block_ids2 = kv_blocks2.get_block_ids() + sched_out2 = make_scheduler_output( + {req2.request_id: 1}, + new_reqs={req2.request_id: block_ids2}, + ) + meta2 = sched.build_connector_meta(sched_out2) + assert meta2.load_event >= 0 + + # Finish before load completes + sched.request_finished(req2, block_ids=[]) + assert req2.request_id in sched._reqs_to_load + + # Simulate load completion -> request removed + simulate_load_completion(sched, {req2.request_id}) + assert req2.request_id not in sched._reqs_to_load + + +# --------------------------------------------------------------------------- +# Test 8: Null GPU blocks are skipped in store and load transfer pairs +# --------------------------------------------------------------------------- +def test_multi_group_null_blocks_skipped() -> None: + """Null GPU blocks (no block_hash) must not appear in store or load pairs. + + In eager store mode, _prepare_eager_store_specs skips blocks whose + block_hash is None (null blocks have no hash). 
We verify this by mixing + real hashed blocks with unhashed (null-like) blocks in a single group and + checking that only real blocks appear in the store list. + """ + fix = make_scheduler(num_cpu_blocks=8, num_gpu_blocks=16, num_groups=1, lazy=False) + sched = fix.scheduler + gpu_pool = fix.gpu_block_pool + + num_blocks = 2 + req = make_request(num_blocks=num_blocks) + + # Allocate real blocks (with hashes) and use the null_block as a placeholder + gpu_blocks = _allocate_gpu_blocks(gpu_pool, req, num_blocks, group_id=0) + null_block = gpu_pool.null_block + + # Mix: [real_block, null_block] — null_block has no hash, should be skipped + mixed_blocks = [gpu_blocks[0], null_block] + kv_blocks = KVCacheBlocks(blocks=(mixed_blocks,)) + req.num_computed_tokens = num_blocks * BLOCK_SIZE + sched.update_state_after_alloc(req, kv_blocks, num_external_tokens=0) + + block_ids = kv_blocks.get_block_ids() + sched_out = make_scheduler_output( + {req.request_id: num_blocks * BLOCK_SIZE}, + new_reqs={req.request_id: block_ids}, + ) + meta = sched.build_connector_meta(sched_out) + + # Null block's ID should NOT appear in store_gpu_blocks + null_block_id = null_block.block_id + assert null_block_id not in meta.store_gpu_blocks, ( + f"Null block id {null_block_id} should not appear in store transfer pairs" + ) + + # Only real block should be scheduled for store + assert len(meta.store_gpu_blocks) == 1 + assert gpu_blocks[0].block_id in meta.store_gpu_blocks + + # Complete the store + assert meta.store_event >= 0 + simulate_store_completion(sched, meta.store_event) + + # Create matching request and get load hit + req2 = Request( + request_id="req-null-load", + prompt_token_ids=req.prompt_token_ids, + sampling_params=req.sampling_params, + pooling_params=None, + mm_features=None, + block_hasher=req._block_hasher, + ) + hit_tokens, is_async = sched.get_num_new_matched_tokens(req2, num_computed_tokens=0) + # Only 1 block was stored (the real one) + assert hit_tokens == BLOCK_SIZE + assert 
is_async is True + + # Allocate new GPU blocks for the load + gpu_blocks2 = gpu_pool.get_new_blocks(1) + kv_blocks2 = KVCacheBlocks(blocks=([gpu_blocks2[0], null_block],)) + sched.update_state_after_alloc(req2, kv_blocks2, num_external_tokens=hit_tokens) + + sched_out2 = make_scheduler_output({req2.request_id: 1}) + meta2 = sched.build_connector_meta(sched_out2) + + # Null block's ID should NOT appear in load_gpu_blocks + assert null_block_id not in meta2.load_gpu_blocks, ( + f"Null block id {null_block_id} should not appear in load transfer pairs" + ) + + +# --------------------------------------------------------------------------- +# Test 9: Chunked prefill accumulates block_ids across steps +# --------------------------------------------------------------------------- +def test_chunked_prefill_reads_live_block_ids() -> None: + """With chunked prefill, block IDs accumulate across scheduler steps. + _prepare_eager_store_specs reads block IDs from scheduler_output via + yield_req_data, so the store should reflect the updated (larger) block + list, not a stale snapshot.""" + fix = make_scheduler(num_cpu_blocks=8, num_gpu_blocks=16, lazy=False) + sched = fix.scheduler + + num_blocks = 4 + req = make_request(num_blocks=num_blocks) + + # First chunk: allocate 2 blocks + kv_blocks_first = _alloc_and_register(fix, req, 2) + sched.update_state_after_alloc(req, kv_blocks_first, num_external_tokens=0) + + assert req.request_id in sched._reqs_to_store + # Should still be exactly 1 entry in _reqs_to_store + assert list(sched._reqs_to_store.keys()).count(req.request_id) == 1 + + # Build connector meta with 2 blocks — stores the first 2 + ids_first = kv_blocks_first.get_block_ids() + sched_out1 = make_scheduler_output( + {req.request_id: 2 * BLOCK_SIZE}, + new_reqs={req.request_id: ids_first}, + ) + meta1 = sched.build_connector_meta(sched_out1) + assert meta1.store_event >= 0 + assert len(meta1.store_gpu_blocks) == 2 + simulate_store_completion(sched, meta1.store_event) + + # 
Second chunk: allocate 4 blocks total (2 new ones) + kv_blocks_second = _alloc_and_register(fix, req, num_blocks) + # update_state_after_alloc is idempotent for store registration + sched.update_state_after_alloc(req, kv_blocks_second, num_external_tokens=0) + + # Still exactly 1 entry + assert list(sched._reqs_to_store.keys()).count(req.request_id) == 1 + + # The second chunk's NEW block IDs (positions 2,3) are passed as + # cached_req_new_blocks. The full block_ids include both old and new, + # but yield_req_data only appends the new_block_ids for cached reqs. + ids_second_full = kv_blocks_second.get_block_ids() + # New blocks are those beyond the first chunk + new_block_ids = tuple(ids_second_full[g][2:] for g in range(len(ids_second_full))) + sched_out2 = make_scheduler_output( + {req.request_id: 2 * BLOCK_SIZE}, + cached_req_new_blocks={req.request_id: new_block_ids}, + ) + meta2 = sched.build_connector_meta(sched_out2) + assert meta2.store_event >= 0 + # Only the 2 NEW blocks should be stored (first 2 already done) + assert len(meta2.store_gpu_blocks) == 2 + + +# --------------------------------------------------------------------------- +# Test 10: Partial GPU prefix hit + CPU load + new compute blocks +# --------------------------------------------------------------------------- +def test_partial_gpu_prefix_plus_cpu_load() -> None: + """When GPU has a prefix cache hit for the first N blocks, CPU has a + hit for the next M blocks, and there are P new blocks needing fresh + compute, the block layout is: + + | comp (N) | ext_comp (M) | new (P) | + + External blocks sit in the middle — not at the beginning or end. + The load path must target hashes at positions [N, N+M). + + Request: 6 blocks (0..5). + - Store all 6 to CPU. + - New request: GPU prefix cache hits blocks 0,1 (hashed). + CPU hits blocks 2,3. Blocks 4,5 are new (need compute). + - update_state_after_alloc receives 6 GPU blocks: + [0,1] hashed (comp), [2,3] unhashed (ext_comp), [4,5] unhashed (new). 
+ - Load must target hash positions 2,3. + """ + fix = make_scheduler(num_cpu_blocks=8, num_gpu_blocks=16, lazy=False) + sched = fix.scheduler + gpu_pool = fix.gpu_block_pool + + num_blocks = 6 + req = make_request(num_blocks=num_blocks) + + # Store all 6 blocks to CPU via eager store. + kv_blocks = _alloc_and_register(fix, req, num_blocks) + sched.update_state_after_alloc(req, kv_blocks, num_external_tokens=0) + block_ids = kv_blocks.get_block_ids() + sched_out = make_scheduler_output( + {req.request_id: num_blocks * BLOCK_SIZE}, + new_reqs={req.request_id: block_ids}, + ) + meta = sched.build_connector_meta(sched_out) + assert meta.store_event >= 0 + simulate_store_completion(sched, meta.store_event) + + # New request with same tokens — but only partial GPU prefix hit. + req2 = Request( + request_id="req-partial-gpu", + prompt_token_ids=req.prompt_token_ids, + sampling_params=req.sampling_params, + pooling_params=None, + mm_features=None, + block_hasher=req._block_hasher, + ) + + # GPU prefix cache hits the first 2 blocks. + gpu_local_computed = 2 * BLOCK_SIZE + hit_tokens, is_async = sched.get_num_new_matched_tokens( + req2, num_computed_tokens=gpu_local_computed + ) + # CPU should hit blocks 2,3 (not 4,5 — those are beyond the CPU range). + num_cpu_hit_blocks = 2 + # Actually CPU has all 6 stored; it returns hits starting from position 2. + # The number of CPU hit blocks = min(remaining request blocks, CPU cached). + # Here remaining = 6 - 2 = 4 blocks are in CPU, so hit = 4 * BLOCK_SIZE. + num_cpu_hit_blocks = 4 + assert hit_tokens == num_cpu_hit_blocks * BLOCK_SIZE, ( + f"Expected {num_cpu_hit_blocks * BLOCK_SIZE} CPU hit tokens, got {hit_tokens}" + ) + assert is_async is True + + # Simulate what the real scheduler does: only accept 2 of the 4 CPU hit + # blocks as external (e.g. due to budget constraints), leaving 2 new + # blocks for fresh compute. 
+ num_ext_blocks = 2 + num_new_blocks = 2 + external_tokens = num_ext_blocks * BLOCK_SIZE + + # Build block list matching real layout: | comp(2) | ext_comp(2) | new(2) | + # comp: GPU prefix cache hit — blocks with hashes + gpu_comp = _allocate_gpu_blocks(gpu_pool, req2, 2, group_id=0) + # ext_comp + new: freshly allocated, no hashes + gpu_ext_and_new = gpu_pool.get_new_blocks(num_ext_blocks + num_new_blocks) + all_gpu_blocks = gpu_comp + gpu_ext_and_new + kv_blocks2 = KVCacheBlocks(blocks=(all_gpu_blocks,)) + + # Critical call: with 2 hashed comp blocks and 2 external tokens worth + # of blocks, the manager must derive skipped=2 and load hashes [2,3]. + sched.update_state_after_alloc( + req2, kv_blocks2, num_external_tokens=external_tokens + ) + + block_ids2 = kv_blocks2.get_block_ids() + sched_out2 = make_scheduler_output( + {req2.request_id: num_new_blocks * BLOCK_SIZE}, + new_reqs={req2.request_id: block_ids2}, + ) + meta2 = sched.build_connector_meta(sched_out2) + assert meta2.load_event >= 0, "Expected a load event for partial GPU + CPU hit" + assert len(meta2.load_gpu_blocks) == num_ext_blocks + assert len(meta2.load_cpu_blocks) == num_ext_blocks + + # Verify the load targets the ext_comp GPU blocks (positions 2,3), + # not the comp blocks (0,1) or new blocks (4,5). 
+ ext_block_ids = [b.block_id for b in gpu_ext_and_new[:num_ext_blocks]] + for bid in meta2.load_gpu_blocks: + assert bid in ext_block_ids, ( + f"Load GPU block {bid} should be an ext_comp block, not a comp or new block" + ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b6be7f10b..bc1ce80fc 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -657,7 +657,11 @@ class VllmConfig: ) if kv_offloading_backend == "native": - self.kv_transfer_config.kv_connector = "OffloadingConnector" + if envs.VLLM_USE_SIMPLE_KV_OFFLOAD: + config_connector = "SimpleCPUOffloadConnector" + else: + config_connector = "OffloadingConnector" + self.kv_transfer_config.kv_connector = config_connector self.kv_transfer_config.kv_connector_extra_config.update( {"cpu_bytes_to_use": kv_offloading_size * (1 << 30)} ) diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index b677c5885..c88b32284 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -202,6 +202,7 @@ KVConnectorFactory.register_connector( "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector", "DecodeBenchConnector", ) + KVConnectorFactory.register_connector( "MooncakeConnector", "vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector", @@ -213,3 +214,9 @@ KVConnectorFactory.register_connector( "vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector", "FlexKVConnectorV1", ) + +KVConnectorFactory.register_connector( + "SimpleCPUOffloadConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector", + "SimpleCPUOffloadConnector", +) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/simple_cpu_offload_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/simple_cpu_offload_connector.py new file mode 100644 index 000000000..6475b941b --- /dev/null +++ 
b/vllm/distributed/kv_transfer/kv_connector/v1/simple_cpu_offload_connector.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SimpleCPUOffloadConnector: minimal CPU KV cache offloading.""" + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_events import KVCacheEvent +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, + SupportsHMA, +) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import KVConnectorOutput +from vllm.v1.simple_kv_offload.manager import ( + SimpleCPUOffloadScheduler, +) +from vllm.v1.simple_kv_offload.metadata import ( + SimpleCPUOffloadMetadata, +) +from vllm.v1.simple_kv_offload.worker import ( + SimpleCPUOffloadWorker, +) + +if TYPE_CHECKING: + from vllm.forward_context import ForwardContext + from vllm.v1.attention.backend import AttentionMetadata + from vllm.v1.core.block_pool import BlockPool + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +logger = init_logger(__name__) + +# Default CPU capacity: 8 GB +DEFAULT_CPU_CAPACITY_BYTES = 8 * (1024**3) + + +class SimpleCPUOffloadConnector(KVConnectorBase_V1, SupportsHMA): + """CPU KV cache offloading with custom kernel transfers and BlockPool LRU.""" + + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig | None" = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + + enable_prefix_caching = vllm_config.cache_config.enable_prefix_caching + extra_config = self._kv_transfer_config.kv_connector_extra_config or {} + + cpu_capacity_bytes = int( + extra_config.get("cpu_bytes_to_use", 
DEFAULT_CPU_CAPACITY_BYTES) + ) + # cpu_bytes_to_use is server-wide for compatibility; + # cpu_bytes_to_use_per_rank overrides for per-rank capacity. + world_size = vllm_config.parallel_config.world_size + cpu_capacity_per_rank = cpu_capacity_bytes // world_size + if "cpu_bytes_to_use_per_rank" in extra_config: + explicit = int(extra_config["cpu_bytes_to_use_per_rank"]) + if explicit != cpu_capacity_per_rank: + logger.warning( + "cpu_bytes_to_use_per_rank (%.2f GB) != " + "cpu_bytes_to_use/world_size (%.2f GB). Using per-rank value.", + explicit / (1024**3), + cpu_capacity_per_rank / (1024**3), + ) + cpu_capacity_per_rank = explicit + + lazy_offload = bool(extra_config.get("lazy_offload", False)) + + self.scheduler_manager: SimpleCPUOffloadScheduler | None = None + self.worker_handler: SimpleCPUOffloadWorker | None = None + + if not enable_prefix_caching: + logger.warning( + "Detected prefix caching disabled, disabling CPU offload " + "since it requires prefix caching." + ) + return + + logger.info( + "SimpleCPUOffloadConnector: role=%s, " + "per_rank=%.2f GB, world_size=%d, mode=%s", + role.name, + cpu_capacity_per_rank / (1024**3), + world_size, + "lazy" if lazy_offload else "eager", + ) + + if role == KVConnectorRole.SCHEDULER: + self.scheduler_manager = SimpleCPUOffloadScheduler( + vllm_config, + kv_cache_config, + cpu_capacity_per_rank, + lazy_offload=lazy_offload, + ) + elif role == KVConnectorRole.WORKER: + self.worker_handler = SimpleCPUOffloadWorker( + vllm_config, kv_cache_config, cpu_capacity_per_rank + ) + + # --- Worker-side methods --- + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None: + if self.worker_handler is not None: + self.worker_handler.register_kv_caches(kv_caches) + + def bind_connector_metadata( + self, + connector_metadata: KVConnectorMetadata, + ) -> None: + super().bind_connector_metadata(connector_metadata) + if self.worker_handler is not None: + assert isinstance(connector_metadata, SimpleCPUOffloadMetadata) 
+ self.worker_handler.bind_connector_metadata(connector_metadata) + + def clear_connector_metadata(self) -> None: + super().clear_connector_metadata() + if self.worker_handler is not None: + self.worker_handler.clear_connector_metadata() + + def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata) -> None: + if self.worker_handler is not None: + assert isinstance(kv_connector_metadata, SimpleCPUOffloadMetadata) + self.worker_handler.handle_preemptions(kv_connector_metadata) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: + pass # Launch loads ops in get_finished() after launching model execution + + def wait_for_layer_load(self, layer_name: str) -> None: + pass # Always load asynchronously and deferred to get_finished() + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: + pass # Always save asynchronously and deferred to get_finished() + + def wait_for_save(self) -> None: + pass # All stores are driven by get_finished() and no wait needed + + def get_finished( + self, + finished_req_ids: set[str], + ) -> tuple[set[str] | None, set[str] | None]: + if self.worker_handler is not None: + return self.worker_handler.get_finished(finished_req_ids) + return None, None + + def build_connector_worker_meta(self): + if self.worker_handler is not None: + return self.worker_handler.build_connector_worker_meta() + return None + + # --- Scheduler-side methods --- + + # NOTE: New API only for SimpleCPUOffloadConnector. 
+ def bind_gpu_block_pool(self, gpu_block_pool: "BlockPool") -> None: + if self.scheduler_manager is not None: + self.scheduler_manager.bind_gpu_block_pool(gpu_block_pool) + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + if self.scheduler_manager is not None: + return self.scheduler_manager.get_num_new_matched_tokens( + request, num_computed_tokens + ) + return 0, False + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + ) -> None: + if self.scheduler_manager is not None: + self.scheduler_manager.update_state_after_alloc( + request, blocks, num_external_tokens + ) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + if self.scheduler_manager is not None: + return self.scheduler_manager.build_connector_meta(scheduler_output) + return SimpleCPUOffloadMetadata() + + def update_connector_output( + self, + connector_output: KVConnectorOutput, + ) -> None: + if self.scheduler_manager is not None: + self.scheduler_manager.update_connector_output(connector_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + if self.scheduler_manager is not None: + return self.scheduler_manager.request_finished(request, block_ids) + return False, None + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + if self.scheduler_manager is not None: + return self.scheduler_manager.request_finished_all_groups( + request, block_ids + ) + return False, None + + # NOTE: New API only for SimpleCPUOffloadConnector. 
+ def has_pending_transfers(self) -> bool: + if self.scheduler_manager is not None: + return self.scheduler_manager.has_pending_stores() + return False + + def take_events(self) -> Iterable[KVCacheEvent]: + if self.scheduler_manager is not None: + return self.scheduler_manager.take_events() + return [] + + def reset_cache(self) -> bool | None: + raise NotImplementedError( + "SimpleCPUOffloadConnector does not support reset_cache(). " + "reset_prefix_cache() requires synchronizing all pending " + "CPU offload transfers before clearing GPU prefix cache blocks, " + "which is not yet implemented." + ) diff --git a/vllm/envs.py b/vllm/envs.py index 62c4609d1..6362373dd 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1662,6 +1662,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_XPU_ENABLE_XPU_GRAPH": lambda: bool( int(os.getenv("VLLM_XPU_ENABLE_XPU_GRAPH", "0")) ), + # Enable simple KV offload. + "VLLM_USE_SIMPLE_KV_OFFLOAD": lambda: bool( + int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0")) + ), } diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c28a5d18a..fe524ccac 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -234,6 +234,13 @@ class Scheduler(SchedulerInterface): hash_block_size=self.block_size, metrics_collector=self.kv_metrics_collector, ) + # Bind GPU block pool to the KV connector. This must happen after + # kv_cache_manager is constructed so block_pool is available. 
+ if self.connector is not None and hasattr( + self.connector, "bind_gpu_block_pool" + ): + self.connector.bind_gpu_block_pool(self.kv_cache_manager.block_pool) + self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER self.scheduler_reserve_full_isl = ( diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 4a1e8b6f3..45f002e01 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -281,8 +281,16 @@ class PromptTokenStats: self.computed += prompt_len - num_cached_tokens self.external_kv_transfer += num_external_computed_tokens - self.local_cache_hit += ( - num_cached_tokens + recomputed - num_external_computed_tokens + # FIXME(yifan): local_cache_hit can go negative after preemption. + # num_cached_tokens is a one-time snapshot from first scheduling and + # is never reset on preemption, while num_external_computed_tokens is + # overwritten on re-scheduling. If CPU offload finds more tokens on + # the second pass than the original total, the subtraction underflows. + # A fundamental fix is to track the first-time num_external_computed_tokens + # as a separate metric rather than reusing num_external_computed_tokens + # for metric directly. 
+ self.local_cache_hit += max( + 0, (num_cached_tokens + recomputed - num_external_computed_tokens) ) self.cached_tokens += num_cached_tokens self.recomputed_tokens += recomputed diff --git a/vllm/v1/simple_kv_offload/__init__.py b/vllm/v1/simple_kv_offload/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/v1/simple_kv_offload/copy_backend.py b/vllm/v1/simple_kv_offload/copy_backend.py new file mode 100644 index 000000000..114f26973 --- /dev/null +++ b/vllm/v1/simple_kv_offload/copy_backend.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""DMA copy backend for GPU<->CPU block transfers.""" + +from __future__ import annotations + +import queue +import threading + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.v1.simple_kv_offload.cuda_mem_ops import ( + BatchMemcpyParams, + build_params, + copy_blocks, +) + +logger = init_logger(__name__) + + +class DmaCopyBackend: + """cuMemcpyBatchAsync copy backend (background thread).""" + + def __init__(self) -> None: + self._store_params: BatchMemcpyParams | None = None + self._load_params: BatchMemcpyParams | None = None + self._load_stream: torch.cuda.Stream | None = None + self._store_stream: torch.cuda.Stream | None = None + self._queue: queue.SimpleQueue | None = None + self._thread: threading.Thread | None = None + self._shutdown: bool = False + + def init( + self, + gpu_caches: dict[str, torch.Tensor], + cpu_caches: dict[str, torch.Tensor], + device: torch.device, + load_stream: torch.cuda.Stream, + store_stream: torch.cuda.Stream, + ) -> None: + self._load_stream = load_stream + self._store_stream = store_stream + + self._store_params = build_params(gpu_caches, cpu_caches, store_stream) + self._load_params = build_params(cpu_caches, gpu_caches, load_stream) + + self._queue = queue.SimpleQueue() + self._thread = threading.Thread( + 
target=self._copy_loop, + args=(self._queue, device, load_stream, store_stream), + daemon=True, + ) + self._thread.start() + + def launch_copy( + self, + src_blocks: list[int], + dst_blocks: list[int], + is_store: bool, + event_idx: int, + events_list: list[tuple[int, torch.Event]], + ) -> None: + params = self._store_params if is_store else self._load_params + assert params is not None and self._queue is not None + self._queue.put( + (src_blocks, dst_blocks, params, is_store, event_idx, events_list) + ) + + def shutdown(self) -> None: + if self._shutdown: + return + self._shutdown = True + if self._queue is not None: + self._queue.put(None) + if self._thread is not None: + self._thread.join(timeout=5.0) + + @staticmethod + def _copy_loop( + q: queue.SimpleQueue, + device: torch.device, + load_stream: torch.cuda.Stream, + store_stream: torch.cuda.Stream, + ) -> None: + current_platform.set_device(device) + while True: + item = q.get() + if item is None: + return + src_blocks, dst_blocks, params, is_store, event_idx, events_list = item + copy_blocks(src_blocks, dst_blocks, params) + stream = store_stream if is_store else load_stream + event = torch.Event() + event.record(stream) + events_list.append((event_idx, event)) diff --git a/vllm/v1/simple_kv_offload/cuda_mem_ops.py b/vllm/v1/simple_kv_offload/cuda_mem_ops.py new file mode 100644 index 000000000..03338421c --- /dev/null +++ b/vllm/v1/simple_kv_offload/cuda_mem_ops.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Low-level CUDA memory helpers: pinning and batch DMA transfers.""" + +import ctypes +from typing import Any, NamedTuple + +import numpy as np +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def pin_tensor(tensor: torch.Tensor) -> None: + """Pin a CPU tensor via cudaHostRegister. 
+ + This bypasses PyTorch's CUDACachingHostAllocator which rounds + every ``pin_memory=True`` allocation up to the next power of 2 + (e.g. 100 GB becomes 128 GB). + """ + err = torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.nbytes, 0) + if err.value != 0: + raise RuntimeError(f"cudaHostRegister failed: {err}") + + +class _CUmemLocation(ctypes.Structure): + _fields_ = [("type", ctypes.c_uint), ("id", ctypes.c_int)] + + +class _CUmemcpyAttributes(ctypes.Structure): + _fields_ = [ + ("srcAccessOrder", ctypes.c_uint), + ("srcLocHint", _CUmemLocation), + ("dstLocHint", _CUmemLocation), + ("flags", ctypes.c_uint), + ] + + +_BATCH_MEMCPY_FUNC_TYPE = ctypes.CFUNCTYPE( + ctypes.c_uint, # CUresult + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_void_p, + ctypes.c_void_p, +) + +# Resolved lazily on first use. +_batch_memcpy_fn: Any = None + + +def _resolve_batch_memcpy(): + """Resolve cuMemcpyBatchAsync via cuGetProcAddress (one-time).""" + from cuda.bindings import driver as drv + + err, ptr, _ = drv.cuGetProcAddress(b"cuMemcpyBatchAsync", 12080, 0) + if err != drv.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"cuGetProcAddress(cuMemcpyBatchAsync) failed: {err}") + return _BATCH_MEMCPY_FUNC_TYPE(ptr) + + +class BatchMemcpyParams(NamedTuple): + src_bases: np.ndarray # [num_layers] uint64 — data_ptr per layer + dst_bases: np.ndarray # [num_layers] uint64 + bpb: np.ndarray # [num_layers] uint64 — bytes per block + num_layers: int + attrs: _CUmemcpyAttributes + attrs_idx: ctypes.c_size_t + # NOTE: cuMemcpyBatchAsync_v2() removed fail_idx field, but we use + # cuMemcpyBatchAsync() with fail_idx for backward compatibility + fail_idx: ctypes.c_size_t + stream_handle: int # raw cudaStream_t / CUstream + + +def build_params( + src_caches: dict[str, torch.Tensor], + dst_caches: dict[str, torch.Tensor], + stream: torch.cuda.Stream, +) -> BatchMemcpyParams: + global 
_batch_memcpy_fn + if _batch_memcpy_fn is None: + _batch_memcpy_fn = _resolve_batch_memcpy() + + assert list(src_caches.keys()) == list(dst_caches.keys()) + src_tensors = list(src_caches.values()) + dst_tensors = list(dst_caches.values()) + + src_bases, dst_bases, bpb = [], [], [] + for s, d in zip(src_tensors, dst_tensors): + s_bpb = s.stride(0) * s.element_size() + assert s_bpb == d.stride(0) * d.element_size() + src_bases.append(s.data_ptr()) + dst_bases.append(d.data_ptr()) + bpb.append(s_bpb) + + # Refer to https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f for details. # noqa: E501 + attrs = _CUmemcpyAttributes(srcAccessOrder=3) # ANY + + return BatchMemcpyParams( + src_bases=np.array(src_bases, dtype=np.uint64), + dst_bases=np.array(dst_bases, dtype=np.uint64), + bpb=np.array(bpb, dtype=np.uint64), + num_layers=len(src_tensors), + attrs=attrs, + attrs_idx=ctypes.c_size_t(0), + fail_idx=ctypes.c_size_t(0), + stream_handle=stream.cuda_stream, + ) + + +def copy_blocks( + src_block_ids: list[int], + dst_block_ids: list[int], + params: BatchMemcpyParams, +) -> None: + """Copy blocks via cuMemcpyBatchAsync.""" + n = len(src_block_ids) + if n == 0: + return + + src_ids = np.array(src_block_ids, dtype=np.uint64) + dst_ids = np.array(dst_block_ids, dtype=np.uint64) + + src_all = ( + params.src_bases[:, None] + src_ids[None, :] * params.bpb[:, None] + ).ravel() + dst_all = ( + params.dst_bases[:, None] + dst_ids[None, :] * params.bpb[:, None] + ).ravel() + sz_all = np.repeat(params.bpb, n) + + total = n * params.num_layers + err = _batch_memcpy_fn( + dst_all.ctypes.data, + src_all.ctypes.data, + sz_all.ctypes.data, + total, + ctypes.addressof(params.attrs), + ctypes.byref(params.attrs_idx), + 1, + ctypes.byref(params.fail_idx), + params.stream_handle, + ) + if err != 0: + raise RuntimeError( + f"cuMemcpyBatchAsync failed: err={err} failIdx={params.fail_idx.value}" + ) diff --git 
@dataclass
class TransferMeta:
    """One GPU<->CPU transfer as two parallel block-id lists.

    ``gpu_block_ids[i]`` pairs with ``cpu_block_ids[i]``; direction is
    decided by the code that consumes the pair (load vs. store).
    """

    gpu_block_ids: list[int]
    cpu_block_ids: list[int]


@dataclass
class LoadRequestState:
    """Per-request tracking for an in-flight CPU->GPU load."""

    request: "Request"
    transfer_meta: TransferMeta
    # Event index assigned once the load is scheduled; None until then.
    load_event: int | None = None
    # Set when the request finished while its load was still in flight,
    # deferring cleanup to the load-completion path.
    finished: bool = False
class SimpleCPUOffloadScheduler:
    """Scheduler-side manager for CPU offloading."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_config: "KVCacheConfig | None",
        cpu_capacity_bytes: int,
        lazy_offload: bool = False,
    ):
        """Build the CPU-side cache coordinator and all transfer bookkeeping.

        Args:
            vllm_config: Global engine configuration.
            kv_cache_config: GPU KV cache layout; must not be None.
            cpu_capacity_bytes: CPU memory budget for offloaded blocks.
            lazy_offload: If True, offload blocks from the GPU free queue on
                demand (cursor walk) instead of eagerly per request.
        """
        self.vllm_config = vllm_config
        self.kv_cache_config = kv_cache_config
        self.enable_kv_cache_events = (
            vllm_config.kv_events_config is not None
            and vllm_config.kv_events_config.enable_kv_cache_events
        )
        # NOTE: We use the same block size for both GPU and CPU.
        self.block_size = vllm_config.cache_config.block_size
        # Derive a CPU KVCacheConfig from the GPU config and build a coordinator
        assert kv_cache_config is not None
        self.cpu_kv_cache_config = self._derive_cpu_config(
            kv_cache_config, cpu_capacity_bytes
        )
        self.num_cpu_blocks = self.cpu_kv_cache_config.num_blocks
        # Find the full attention kv group for prefix cache matching.
        self.fa_gidx = -1
        for g_idx, g in enumerate(self.cpu_kv_cache_config.kv_cache_groups):
            if isinstance(g.kv_cache_spec, FullAttentionSpec):
                self.fa_gidx = g_idx
                break
        assert 0 <= self.fa_gidx < len(self.cpu_kv_cache_config.kv_cache_groups)

        logger.info(
            "SimpleCPUOffloadScheduler: Allocating %d CPU blocks (%.2f GB, mode=%s)",
            self.num_cpu_blocks,
            cpu_capacity_bytes / (1024**3),
            "lazy" if lazy_offload else "eager",
        )

        # TODO (yifan): maybe need to enable kv_cache_events and metrics_collector here.
        dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
        pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
        assert dcp_world_size == 1 and pcp_world_size == 1
        self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator(
            kv_cache_config=self.cpu_kv_cache_config,
            max_model_len=vllm_config.model_config.max_model_len,
            use_eagle=False,
            enable_caching=True,
            enable_kv_cache_events=self.enable_kv_cache_events,
            dcp_world_size=dcp_world_size,
            pcp_world_size=pcp_world_size,
            hash_block_size=self.block_size,
        )
        self.cpu_block_pool: BlockPool = self.cpu_coordinator.block_pool

        # GPU block pool reference - bound after scheduler builds kv_cache_manager
        self._gpu_block_pool: BlockPool | None = None

        # Load metadata
        self._reqs_to_load: dict[str, LoadRequestState] = {}
        # Inverse map: load_event_idx -> req_ids. Keyed by load_event_idx because
        # the worker reports completions by event index, not request id.
        self._load_event_to_reqs: dict[int, list[str]] = {}

        # Store metadata
        self._lazy_mode = lazy_offload
        # Lazy mode: cursor tracking the last scanned block in the GPU free
        # queue. FIX: the annotation must be a string — KVCacheBlock is
        # imported only under TYPE_CHECKING, and annotations on attribute
        # targets (unlike simple local names) are evaluated at runtime
        # (PEP 526), so the unquoted form raised NameError here.
        self._cursor: "KVCacheBlock | None" = None
        if self._lazy_mode:
            self._target_free = self._estimate_lazy_target_blocks(
                kv_cache_config,
                vllm_config.scheduler_config.max_num_batched_tokens,
            )
        else:
            self._target_free = 0
        self._store_event_to_blocks: dict[int, TransferMeta] = {}
        # Eager mode only
        self._reqs_to_store: dict[str, StoreRequestState] = {}
        self._store_event_to_reqs: dict[int, list[str]] = {}

        # Event counters
        self._load_event_counter: int = 0
        self._store_event_counter: int = 0

        # For TP/PP: track partial store completions across steps.
        # Events must be reported by all world_size workers before considered complete.
        self._expected_worker_count = vllm_config.parallel_config.world_size
        self._store_event_pending_counts: dict[int, int] = {}

    @staticmethod
    def _derive_cpu_config(
        gpu_config: "KVCacheConfig", cpu_capacity_bytes: int
    ) -> "KVCacheConfig":
        """Derive a CPU KVCacheConfig from the GPU config.
        Same kv_cache_groups, num_blocks scaled by CPU/GPU memory ratio."""
        # Import here to avoid potential circular imports
        from vllm.v1.kv_cache_interface import KVCacheConfig as KVCacheConfigCls
        from vllm.v1.kv_cache_interface import KVCacheTensor

        assert len(gpu_config.kv_cache_tensors) > 0

        gpu_total_bytes = sum(t.size for t in gpu_config.kv_cache_tensors)
        num_gpu_blocks = gpu_config.num_blocks
        num_cpu_blocks = max(1, num_gpu_blocks * cpu_capacity_bytes // gpu_total_bytes)
        # Create CPU kv_cache_tensors mirroring GPU by scaling size proportionally.
        # Divide before multiplying to keep the per-block size exact.
        cpu_tensors = [
            KVCacheTensor(
                size=t.size // num_gpu_blocks * num_cpu_blocks,
                shared_by=list(t.shared_by),
            )
            for t in gpu_config.kv_cache_tensors
        ]

        return KVCacheConfigCls(
            num_blocks=num_cpu_blocks,
            kv_cache_tensors=cpu_tensors,
            kv_cache_groups=gpu_config.kv_cache_groups,
        )

    @staticmethod
    def _estimate_lazy_target_blocks(
        kv_cache_config: "KVCacheConfig", max_num_batched_tokens: int
    ) -> int:
        """GPU blocks to keep available (free/offloaded) per step in lazy mode."""
        WATERMARK_RATIO = 1.0  # Reserve larger space to avoid running out of GPU blocks
        target = 0
        for g in kv_cache_config.kv_cache_groups:
            spec = g.kv_cache_spec
            if isinstance(spec, MambaSpec):
                # Mamba state is constant-size; two blocks suffice.
                target += 2
            elif isinstance(spec, SlidingWindowSpec):
                target += cdiv(spec.sliding_window, spec.block_size) + 1
            else:
                target += cdiv(max_num_batched_tokens, spec.block_size)
        return int(target * (1 + WATERMARK_RATIO))
    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """Return (num_new_tokens, is_async) from consecutive CPU cache hits."""
        # Blocks already covered by GPU-computed tokens are skipped; only the
        # remaining suffix of block hashes is matched against the CPU cache.
        skipped = num_computed_tokens // self.block_size
        remaining_hashes = request.block_hashes[skipped:]

        if not remaining_hashes:
            return 0, False
        # Must recompute at least the last token, matching the logic in
        # kv_cache_manager.get_computed_blocks().
        max_hit_len = request.num_tokens - 1 - num_computed_tokens
        if max_hit_len <= 0:
            return 0, False
        _, hit_length = self.cpu_coordinator.find_longest_cache_hit(
            remaining_hashes, max_hit_len
        )

        # A positive hit is always loaded asynchronously (is_async=True).
        if hit_length > 0:
            return hit_length, True
        return 0, False

    # TODO(yifan): this API now only matches the suffix part of the prefix cache. A more
    # general API should scan blocks in both GPU and CPU block pool in a single pass.
    def update_state_after_alloc(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
        num_external_tokens: int,
    ) -> None:
        """Register store tracking and schedule the CPU->GPU load (if any).

        Called after the scheduler allocated GPU blocks for ``request``.
        ``num_external_tokens`` is the CPU cache hit length previously
        reported by get_num_new_matched_tokens(); when non-zero, the matching
        CPU blocks are paired with the freshly allocated GPU blocks and both
        sides are ref-pinned until the async load completes.
        """
        req_id = request.request_id
        block_ids_by_group = blocks.get_block_ids()
        num_groups = len(block_ids_by_group)

        # Store tracking (eager mode only). Register the request;
        # block IDs are accumulated from scheduler_output in
        # _prepare_eager_store_specs via yield_req_data.
        if not self._lazy_mode and req_id not in self._reqs_to_store:
            self._reqs_to_store[req_id] = StoreRequestState(
                request=request,
                block_ids=tuple([] for _ in range(num_groups)),
                num_stored_blocks=[0] * num_groups,
            )

        if num_external_tokens == 0:
            return

        num_blocks_to_load = num_external_tokens // self.block_size
        assert num_blocks_to_load > 0

        # GPU-cached prefix length, measured on the full-attention group
        # (blocks with a hash were prefix-cache hits on the GPU side).
        skipped = sum(blk.block_hash is not None for blk in blocks.blocks[self.fa_gidx])
        num_computed_tokens = skipped * self.block_size
        hashes_to_load = request.block_hashes[skipped : skipped + num_blocks_to_load]

        # Find CPU cached blocks across all groups.
        max_hit_len = len(hashes_to_load) * self.block_size
        cpu_hit_blocks, hit_length = self.cpu_coordinator.find_longest_cache_hit(
            hashes_to_load, max_hit_len
        )
        # The hit must still match what was promised earlier; the CPU blocks
        # are protected against eviction only from here on.
        assert hit_length == num_external_tokens, (
            f"Expected {num_external_tokens} hit tokens, got {hit_length}"
        )

        # Build transfer pairs across all groups.
        total_computed_tokens = num_computed_tokens + num_external_tokens
        kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups

        gpu_block_ids: list[int] = []
        cpu_block_ids: list[int] = []
        cpu_blocks_to_touch: list[KVCacheBlock] = []

        for g in range(num_groups):
            cpu_blocks_g = cpu_hit_blocks[g]
            n_ext_g = len(cpu_blocks_g)
            if n_ext_g == 0:
                continue

            # Number of blocks in the computed range for this group.
            # Groups may use different block sizes (e.g. sliding window).
            g_block_size = kv_cache_groups[g].kv_cache_spec.block_size
            n_computed_g = cdiv(total_computed_tokens, g_block_size)

            # Back-trace: ext blocks sit at the tail of the computed range.
            gpu_ext_start = n_computed_g - n_ext_g
            group_gpu_ids = block_ids_by_group[g]

            for i, cpu_blk in enumerate(cpu_blocks_g):
                # Skip null blocks (e.g. sliding window or mamba padding).
                if cpu_blk.is_null:
                    continue
                gpu_block_ids.append(group_gpu_ids[gpu_ext_start + i])
                cpu_block_ids.append(cpu_blk.block_id)
                cpu_blocks_to_touch.append(cpu_blk)

        # Touch CPU blocks to prevent eviction during async load.
        self.cpu_block_pool.touch(cpu_blocks_to_touch)

        # Touch GPU blocks to prevent freeing during async load
        assert self._gpu_block_pool is not None
        self._gpu_block_pool.touch(
            [self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids]
        )

        # A request must not have two concurrent loads.
        assert self._reqs_to_load.get(req_id) is None
        self._reqs_to_load[req_id] = LoadRequestState(
            request=request, transfer_meta=TransferMeta(gpu_block_ids, cpu_block_ids)
        )
    def prepare_store_specs(
        self, scheduler_output: SchedulerOutput
    ) -> tuple[list[int], list[int], list[str]]:
        """Prepare store specs for the store event.

        Returns (gpu_block_ids, cpu_block_ids, req_ids); req_ids is empty in
        lazy mode, where stores are not tied to individual requests.
        """
        if self._lazy_mode:
            return self._prepare_lazy_store_specs()
        else:
            return self._prepare_eager_store_specs(scheduler_output)

    def _prepare_lazy_store_specs(
        self,
    ) -> tuple[list[int], list[int], list[str]]:
        """Single-pass cursor walk: offload cached GPU blocks near eviction.

        Walks the GPU free queue from the cursor, counting blocks that are
        free-or-offloaded (safe for the allocator to evict). Stops when
        target_free blocks are covered or CPU capacity is reached.
        """
        gpu_pool = self._gpu_block_pool
        if gpu_pool is None or self._target_free <= 0:
            return [], [], []

        free_queue = gpu_pool.free_block_queue
        cpu_pool = self.cpu_block_pool
        num_cpu_free = cpu_pool.get_num_free_blocks()

        # Validate cursor: stale if block was removed from free queue.
        if self._cursor is not None and self._cursor.ref_cnt > 0:
            self._cursor = None

        # Determine start node.
        if self._cursor is None:
            node = free_queue.fake_free_list_head.next_free_block
        else:
            node = self._cursor.next_free_block

        tail = free_queue.fake_free_list_tail
        gpu_ids: list[int] = []
        block_hashes: list[bytes] = []
        covered = 0
        last_visited = self._cursor

        while (
            node is not None
            and node is not tail
            and covered < self._target_free
            and len(gpu_ids) < num_cpu_free
        ):
            last_visited = node
            bhash = node.block_hash

            # Offload only hashed, non-null blocks not already in the CPU
            # cache; every visited block still counts toward the target.
            if (
                bhash is not None
                and not node.is_null
                and cpu_pool.cached_block_hash_to_block.get_one_block(bhash) is None
            ):
                gpu_ids.append(node.block_id)
                block_hashes.append(bhash)

            covered += 1
            node = node.next_free_block

        self._cursor = last_visited

        # Batch-allocate CPU blocks and stamp hashes.
        if gpu_ids:
            cpu_blocks = cpu_pool.get_new_blocks(len(gpu_ids))
            cpu_ids = [blk.block_id for blk in cpu_blocks]
            for cpu_blk, bhash in zip(cpu_blocks, block_hashes):  # type: ignore[assignment]
                cpu_blk._block_hash = bhash  # type: ignore[assignment]
            # Touch GPU blocks to prevent eviction during async copy.
            gpu_pool.touch([gpu_pool.blocks[bid] for bid in gpu_ids])
        else:
            cpu_ids = []

        return gpu_ids, cpu_ids, []
+ if preempted: + state.block_ids = tuple([] for _ in range(num_groups)) + state.num_stored_blocks = [0] * num_groups + if new_block_id_groups: + for g in range(min(num_groups, len(new_block_id_groups))): + if new_block_id_groups[g] is not None: + state.block_ids[g].extend(new_block_id_groups[g]) + + num_new_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0) + if num_new_tokens == 0: + continue + + block_ids_by_group = state.block_ids + if not block_ids_by_group: + continue + + # --- Phase 1: Scan blocks, classify as cached vs to-store --- + gpu_block_ids: list[int] = [] + block_hashes_to_store: list[bytes] = [] + advanced_per_group: list[int] = [0] * num_groups + out_of_space = False + # Confirmed tokens: KV data written and visible to all streams. + req = state.request + confirmed_tokens = req.num_computed_tokens - req.num_output_placeholders + + for g in range(num_groups): + # FIXME (yifan): handle CPU cache eviction, where + # num_stored_blocks can be stale and omit evicted blocks in + # the middle of the request. + already_stored_g = state.num_stored_blocks[g] + group_gpu_ids = block_ids_by_group[g] + + # Cap to blocks with confirmed KV data. + g_block_size = kv_cache_groups[g].kv_cache_spec.block_size + ready_blocks_g = confirmed_tokens // g_block_size + scannable = group_gpu_ids[already_stored_g:ready_blocks_g] + + for gpu_block_id in scannable: + gpu_block = gpu_block_pool.blocks[gpu_block_id] + if gpu_block.is_null: + advanced_per_group[g] += 1 + continue + + bhash_with_group = gpu_block.block_hash + if bhash_with_group is None: + break + + # Check if this group's data is already scheduled for store + # in this step or already cached in CPU. 
+ if ( + gpu_block_id in gpu_blocks_this_step + or cpu_block_pool.cached_block_hash_to_block.get_one_block( + bhash_with_group + ) + is not None + ): + advanced_per_group[g] += 1 + continue + + if num_free <= 0: + out_of_space = True + break + num_free -= 1 + + gpu_block_ids.append(gpu_block_id) + block_hashes_to_store.append(bhash_with_group) + advanced_per_group[g] += 1 + + if out_of_space: + break + + # --- Phase 2: Batch allocate CPU blocks and stamp hashes --- + n_to_alloc = len(gpu_block_ids) + if n_to_alloc > 0: + cpu_blocks_alloc = cpu_block_pool.get_new_blocks(n_to_alloc) + cpu_block_ids = [blk.block_id for blk in cpu_blocks_alloc] + for cpu_blk, bhash in zip(cpu_blocks_alloc, block_hashes_to_store): + cpu_blk._block_hash = bhash # type: ignore[assignment] + else: + cpu_block_ids = [] + + if cpu_block_ids: + req_ids.append(req_id) + merged_gpu_block_ids.extend(gpu_block_ids) + merged_cpu_block_ids.extend(cpu_block_ids) + gpu_blocks_this_step.update(gpu_block_ids) + + # Touch GPU blocks to prevent freeing during async copy + gpu_block_pool.touch( + [gpu_block_pool.blocks[bid] for bid in gpu_block_ids] + ) + + logger.debug( + "Request %s: Scheduling store of %d blocks to CPU (%d groups)", + req_id, + len(cpu_block_ids), + num_groups, + ) + + # Advance per-group cursors (includes cached hits + newly stored) + for g in range(num_groups): + state.num_stored_blocks[g] += advanced_per_group[g] + + return merged_gpu_block_ids, merged_cpu_block_ids, req_ids + + def update_connector_output(self, connector_output: KVConnectorOutput) -> None: + """Handle async transfer completions from worker. + + Load completions arrive via finished_recving (real req_ids). + Store completions arrive via kv_connector_worker_meta as + per-event worker counts. We accumulate across steps and process + a store event only when all workers have reported completion. 
+ """ + # --- Load completions --- + for req_id in list(connector_output.finished_recving or []): + self._cleanup_load_request(req_id) + + # --- Store completions --- + meta = connector_output.kv_connector_worker_meta + if not isinstance(meta, SimpleCPUOffloadWorkerMetadata): + return + for event_idx, count in meta.completed_store_events.items(): + total = self._store_event_pending_counts.get(event_idx, 0) + count + if total >= self._expected_worker_count: + self._store_event_pending_counts.pop(event_idx, None) + self._process_store_event(event_idx) + else: + self._store_event_pending_counts[event_idx] = total + + def _process_store_event(self, event_idx: int) -> None: + """Process a fully-completed store event.""" + transfer = self._store_event_to_blocks.pop(event_idx) + self._process_store_completion(transfer.gpu_block_ids, transfer.cpu_block_ids) + logger.debug( + "Store event %d completed: cached %d blocks to CPU", + event_idx, + len(transfer.cpu_block_ids), + ) + + # Eager only: update per-req state + if not self._lazy_mode: + for req_id in self._store_event_to_reqs.pop(event_idx, []): + state = self._reqs_to_store.get(req_id) + if state is None: + continue + state.store_events.discard(event_idx) + if state.finished and not state.store_events: + self._cleanup_store_request(req_id) + + def _process_store_completion( + self, gpu_block_ids: list[int], cpu_block_ids: list[int] + ) -> None: + """Cache CPU blocks per-group and release GPU refs. + + Block hashes were stamped on CPU blocks at allocation time (in + ``_prepare_*_store_specs``). Here we just register them in the + cache map so they become discoverable by the load path. 
+ """ + assert len(cpu_block_ids) == len(gpu_block_ids) + + cpu_blocks = [self.cpu_block_pool.blocks[bid] for bid in cpu_block_ids] + + for cpu_block in cpu_blocks: + bhash = cpu_block.block_hash + assert bhash is not None + self.cpu_block_pool.cached_block_hash_to_block.insert(bhash, cpu_block) + + # Free CPU and GPU blocks' ref counts to turn them into prefix cache + self.cpu_block_pool.free_blocks(cpu_blocks) + assert self._gpu_block_pool is not None + self._gpu_block_pool.free_blocks( + self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids + ) + + def has_pending_stores(self) -> bool: + """Return True if there are in-flight store transfers.""" + return bool(self._store_event_to_blocks) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """Always returns (False, None). GPU blocks are protected by ref_cnt, + so the scheduler can free blocks immediately.""" + req_id = request.request_id + + # Handle load: defer cleanup if load is in-flight + load_state = self._reqs_to_load.get(req_id) + if load_state is not None: + if load_state.load_event is not None: + load_state.finished = True # Defer: load in-flight + else: + self._cleanup_load_request(req_id) + + # Handle store (eager mode only): defer cleanup if stores in-flight + if not self._lazy_mode: + store_state = self._reqs_to_store.get(req_id) + if store_state is not None: + if store_state.store_events: + store_state.finished = True # Defer: stores in-flight + else: + self._cleanup_store_request(req_id) + + return False, None + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + return self.request_finished(request, block_ids=[]) + + def _cleanup_load_request(self, req_id: str) -> None: + """Release all load resources for a request. + + Shared between request_finished() and update_connector_output() paths. 
+ Removes the request from _reqs_to_load, cleans up event mappings, + and frees CPU/GPU touch refs. + """ + state = self._reqs_to_load.pop(req_id, None) + if state is None: + return + # Remove from load event mapping (only this req, not whole event) + if state.load_event is not None: + reqs = self._load_event_to_reqs.get(state.load_event) + if reqs is not None: + with contextlib.suppress(ValueError): + reqs.remove(req_id) + if not reqs: + self._load_event_to_reqs.pop(state.load_event, None) + + if state.transfer_meta is not None: + # Free CPU touch refs + self.cpu_block_pool.free_blocks( + self.cpu_block_pool.blocks[bid] + for bid in state.transfer_meta.cpu_block_ids + ) + # Free GPU touch refs + assert self._gpu_block_pool is not None + self._gpu_block_pool.free_blocks( + self._gpu_block_pool.blocks[bid] + for bid in state.transfer_meta.gpu_block_ids + ) + + def _cleanup_store_request(self, req_id: str) -> None: + """Release store metadata for a request. + + Metadata-only cleanup but no block freeing. Job completion handles + block caching and GPU ref freeing via _process_store_completion(). 
+ """ + state = self._reqs_to_store.pop(req_id, None) + if state is None: + return + for event_idx in list(state.store_events): + if (reqs := self._store_event_to_reqs.get(event_idx)) is not None: + with contextlib.suppress(ValueError): + reqs.remove(req_id) + if not reqs: + self._store_event_to_reqs.pop(event_idx, None) + state.store_events.clear() + + def take_events(self) -> Iterable[KVCacheEvent]: + return self.cpu_block_pool.take_events() diff --git a/vllm/v1/simple_kv_offload/metadata.py b/vllm/v1/simple_kv_offload/metadata.py new file mode 100644 index 000000000..8c8d4511e --- /dev/null +++ b/vllm/v1/simple_kv_offload/metadata.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Metadata for SimpleCPUOffloadConnector.""" + +from dataclasses import dataclass, field + +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, + KVConnectorWorkerMetadata, +) + +INVALID_JOB_ID = -1 + + +@dataclass +class SimpleCPUOffloadMetadata(KVConnectorMetadata): + """ + Metadata passed from scheduler to worker for CPU offload operations. + + The worker receives flat block lists keyed by a monotonic event_idx. + Job->req_id translation is handled by the scheduler-side manager + (via inverse maps), so the worker never knows about request identities. + """ + + # Load event per step. INVALID_JOB_ID means no blocks to load this step. + load_event: int = INVALID_JOB_ID + load_gpu_blocks: list[int] = field(default_factory=list) + load_cpu_blocks: list[int] = field(default_factory=list) + # Reverse map: load_event->req_ids, for tracking requests with finished load events + load_event_to_reqs: dict[int, list[str]] = field(default_factory=dict) + + # Store event per step. INVALID_JOB_ID means no blocks to store this step. 
@dataclass
class SimpleCPUOffloadWorkerMetadata(KVConnectorWorkerMetadata):
    """Worker -> Scheduler metadata for completed store events.

    Each worker reports {event_idx: 1} for newly completed stores.
    ``aggregate()`` sums counts across workers within a step.
    The scheduler-side manager accumulates across steps and processes
    a store completion only when count reaches ``world_size``.
    """

    completed_store_events: dict[int, int]

    def aggregate(
        self, other: "KVConnectorWorkerMetadata"
    ) -> "KVConnectorWorkerMetadata":
        assert isinstance(other, SimpleCPUOffloadWorkerMetadata)
        # Per-event sum over both reports; keys from ``self`` keep their
        # original order, new keys from ``other`` are appended.
        totals: dict[int, int] = {}
        for report in (self.completed_store_events, other.completed_store_events):
            for event_idx, worker_count in report.items():
                totals[event_idx] = totals.get(event_idx, 0) + worker_count
        return SimpleCPUOffloadWorkerMetadata(completed_store_events=totals)
class SimpleCPUOffloadWorker:
    """Worker-side handler for CPU offloading transfers."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_config: "KVCacheConfig | None",
        cpu_capacity_bytes: int,
    ):
        """Set up transfer bookkeeping; caches are bound later in
        register_kv_caches().

        Args:
            vllm_config: Global engine configuration.
            kv_cache_config: GPU KV cache layout (provides num_blocks).
            cpu_capacity_bytes: CPU memory budget for offloaded blocks.
        """
        self.vllm_config = vllm_config
        self.kv_cache_config = kv_cache_config
        self.cpu_capacity_bytes = cpu_capacity_bytes

        # Populated by register_kv_caches().
        self.gpu_kv_caches: dict[str, torch.Tensor] | None = None
        self.cpu_kv_caches: dict[str, torch.Tensor] | None = None
        self.device: torch.device | None = None
        self.num_cpu_blocks: int = 0

        # CUDA streams for the async transfers
        self.load_stream: torch.cuda.Stream | None = None
        self.store_stream: torch.cuda.Stream | None = None

        self._backend = DmaCopyBackend()

        # Ordered (event_idx, Event). Events pre-allocated on main thread.
        self._load_events: list[tuple[int, torch.Event]] = []
        self._store_events: list[tuple[int, torch.Event]] = []
        # High-water marks: highest event_idx completed per stream.
        # When the event list is empty, the hwm covers all prior events.
        self._load_hwm: int = -1
        self._store_hwm: int = -1

        # Metadata for the current step
        self._connector_metadata: SimpleCPUOffloadMetadata | None = None

        # Pending event index sets, populated in bind_connector_metadata
        self._pending_load_event_indices: set[int] = set()
        self._pending_store_event_indices: set[int] = set()
        # Completed store events to report via build_connector_worker_meta
        self._completed_store_events: dict[int, int] = {}
+ """ + if not kv_caches: + logger.warning("No KV caches to offload.") + return + + # Resolve each entry to a representative tensor for storage + # deduplication. For attention layers the value is already a tensor; + # for Mamba layers it is a list of tensors that all share the same + # underlying raw storage, so we take the first one. + def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor: + assert isinstance(v, torch.Tensor | list) + return v if isinstance(v, torch.Tensor) else v[0] + + any_tensor = _repr_tensor(next(iter(kv_caches.values()))) + self.device = any_tensor.device + + assert self.kv_cache_config is not None + num_blocks = self.kv_cache_config.num_blocks + + # Deduplicate: multiple layers may share the same backing storage. + seen_ptrs: dict[int, tuple[str, torch.Tensor]] = {} + for name, value in kv_caches.items(): + tensor = _repr_tensor(value) + ptr = tensor.untyped_storage().data_ptr() + if ptr not in seen_ptrs: + seen_ptrs[ptr] = (name, tensor) + + # Build [num_blocks, block_bytes] int8 views from each unique + # storage so that stride(0) gives block_bytes for the copy op. + # + # The physical layout varies across attention backends: + # FlashAttn/ROCm: (2, num_blocks, ...) -> K/V outermost, 2 segments + # FlashInfer/MLA: (num_blocks, ...) -> blocks outermost, 1 segment + # We derive page_size_bytes = storage.nbytes() // num_blocks, then + # classify dims: any dim whose byte-stride exceeds page_size_bytes + # must be an outer segment dim (e.g. the K/V dim of size 2). A less + # hacky way is to update the interface with the layout. 
        unique_gpu_caches: dict[str, torch.Tensor] = {}
        for name, tensor in seen_ptrs.values():
            storage = tensor.untyped_storage()
            # Flat int8 alias over the entire raw storage (no copy).
            raw = torch.empty(0, dtype=torch.int8, device=self.device).set_(
                storage, 0, (storage.nbytes(),)
            )
            el = tensor.element_size()
            page_size_bytes = storage.nbytes() // num_blocks
            # Dims whose byte-stride exceeds one page must be outer segment
            # dims (see layout comment above).
            outer_dims = [
                d for d in range(tensor.ndim) if tensor.stride(d) * el > page_size_bytes
            ]
            if not outer_dims:
                # Blocks are outermost: one [num_blocks, block_bytes] view.
                unique_gpu_caches[name] = raw.view(num_blocks, -1)
            else:
                # Split along the first outer dim (e.g. K/V) into separate
                # per-segment views, suffixed ".0", ".1", ...
                seg_stride = tensor.stride(outer_dims[0]) * el
                for idx in range(tensor.shape[outer_dims[0]]):
                    offset = idx * seg_stride
                    chunk = raw[offset : offset + seg_stride]
                    unique_gpu_caches[f"{name}.{idx}"] = chunk.view(num_blocks, -1)

        # Compute per-tensor bytes_per_block. Tensors may have different
        # page_size_bytes (e.g., UniformTypeKVCacheSpecs with varying head_size).
        per_tensor_bpb = [
            t.stride(0) * t.element_size() for t in unique_gpu_caches.values()
        ]
        total_bytes_per_block = sum(per_tensor_bpb)

        # At least one CPU block, even when the budget is below one block.
        self.num_cpu_blocks = max(1, self.cpu_capacity_bytes // total_bytes_per_block)

        logger.info(
            "SimpleCPUOffloadWorker: %d unique GPU KV tensors, "
            "allocating %d CPU blocks (%.2f GB)",
            len(unique_gpu_caches),
            self.num_cpu_blocks,
            (self.num_cpu_blocks * total_bytes_per_block) / (1024**3),
        )

        pin_memory = is_pin_memory_available()
        if not pin_memory:
            logger.warning(
                "Pinned memory not available. CPU offload performance may be degraded."
            )

        self.gpu_kv_caches = unique_gpu_caches
        self.cpu_kv_caches = {}
        for name, gpu_tensor in unique_gpu_caches.items():
            # CPU mirror: same per-block layout, num_cpu_blocks rows.
            cpu_shape = (self.num_cpu_blocks,) + gpu_tensor.shape[1:]
            # Allocate non-pinned first, then pin via cudaHostRegister to
            # bypass PyTorch's CUDACachingHostAllocator which rounds up to
            # the next power of 2 (e.g. 100 GB -> 128 GB).
            tensor = torch.zeros(cpu_shape, dtype=gpu_tensor.dtype, device="cpu")
            if pin_memory:
                pin_tensor(tensor)
            self.cpu_kv_caches[name] = tensor

        # Use lowest priority so KV cache I/O yields to compute streams.
        low_pri, _ = torch.cuda.Stream.priority_range()
        self.load_stream = torch.cuda.Stream(priority=low_pri)
        self.store_stream = torch.cuda.Stream(priority=low_pri)

        # Initialize copy backend with caches and streams.
        self._backend.init(
            self.gpu_kv_caches,
            self.cpu_kv_caches,
            self.device,
            self.load_stream,
            self.store_stream,
        )

    def bind_connector_metadata(self, metadata: SimpleCPUOffloadMetadata) -> None:
        """Bind this step's metadata and record its pending event indices.

        Negative event indices (INVALID_JOB_ID) mean "no transfer this step".
        """
        self._connector_metadata = metadata
        if metadata.load_event >= 0:
            self._pending_load_event_indices.add(metadata.load_event)
        if metadata.store_event >= 0:
            self._pending_store_event_indices.add(metadata.store_event)

    def clear_connector_metadata(self) -> None:
        """Drop the per-step metadata (pending event indices are kept)."""
        self._connector_metadata = None

    def start_load_kv(self) -> None:
        # NOTE: we defer launching both load and store to get_finished(),
        # which runs after model execution. This hides the CPU-side
        # block copy op overhead (~5ms) behind GPU compute.
        pass

    def wait_for_save(self) -> None:
        # No-op: stores are launched in get_finished() and completion is
        # reported asynchronously via worker metadata.
        pass

    def get_finished(
        self,
        finished_req_ids: set[str],
    ) -> tuple[set[str] | None, set[str] | None]:
        """Submit transfers and report completed events to the scheduler.

        Called after model execution. The manager only schedules stores for
        blocks whose KV data is confirmed computed, so we launch both loads
        and stores immediately — no deferral or cross-stream sync needed.

        Returns:
            tuple of (finished_sending, finished_recving).
            - finished_sending: always None (stores use worker metadata).
            - finished_recving: req_ids whose loads have completed.
        """
        # (1) Submit transfers
        metadata = self._connector_metadata
        if metadata is not None:
            # Launch loads (CPU->GPU).
            if metadata.load_cpu_blocks:
                self._backend.launch_copy(
                    metadata.load_cpu_blocks,
                    metadata.load_gpu_blocks,
                    is_store=False,
                    event_idx=metadata.load_event,
                    events_list=self._load_events,
                )
            # Launch stores (GPU->CPU).
            if metadata.store_gpu_blocks:
                self._backend.launch_copy(
                    metadata.store_gpu_blocks,
                    metadata.store_cpu_blocks,
                    is_store=True,
                    event_idx=metadata.store_event,
                    events_list=self._store_events,
                )

        # (2) Track completed transfer events
        finished_recving: set[str] = set()

        if self._pending_load_event_indices:
            load_wm = self._poll_stream_events(is_store=False)
            # Snapshot into a list: we discard from the set while iterating.
            for j in [j for j in self._pending_load_event_indices if j <= load_wm]:
                self._pending_load_event_indices.discard(j)
                # metadata may be None for events pending from earlier steps;
                # NOTE(review): in that case the req_ids for event j appear
                # to be dropped — confirm the manager tolerates this.
                req_ids = (
                    metadata.load_event_to_reqs.get(j) if metadata is not None else None
                )
                if req_ids:
                    finished_recving.update(req_ids)

        if self._pending_store_event_indices:
            store_wm = self._poll_stream_events(is_store=True)
            for j in [j for j in self._pending_store_event_indices if j <= store_wm]:
                self._pending_store_event_indices.discard(j)
                # Each worker contributes a count of 1 per completed event
                # (summed across workers by aggregate()). Assignment rather
                # than += is safe because j is discarded from the pending set.
                self._completed_store_events[j] = 1

        return None, finished_recving or None

    def build_connector_worker_meta(self) -> SimpleCPUOffloadWorkerMetadata | None:
        """Return completed store events since the last call."""
        if not self._completed_store_events:
            return None
        meta = SimpleCPUOffloadWorkerMetadata(
            completed_store_events=self._completed_store_events,
        )
        # Hand ownership of the dict to the metadata; start a fresh one.
        self._completed_store_events = {}
        return meta

    def handle_preemptions(
        self, kv_connector_metadata: SimpleCPUOffloadMetadata
    ) -> None:
        """Sync all in-flight transfers before preempted blocks are reused."""
        if not kv_connector_metadata.need_flush:
            return
        self._flush_and_sync_all()

    def _flush_and_sync_all(self) -> None:
        """Synchronize all in-flight transfer events.

        Blocks the caller until every queued load/store event has completed,
        then advances the high-water marks past them and empties the lists.
        """
        for event_idx, event in self._load_events:
            event.synchronize()
            self._load_hwm = event_idx
        self._load_events.clear()

        for event_idx, event in self._store_events:
            event.synchronize()
            self._store_hwm = event_idx
        self._store_events.clear()

    def _poll_stream_events(self, is_store: bool) -> int:
        """Non-blocking poll for completed events and return the high-water mark."""
        events = self._store_events if is_store else self._load_events
        hwm = self._store_hwm if is_store else self._load_hwm
        # Events are kept in submission order, so stop at the first
        # incomplete one; everything before it is done.
        while events:
            event_idx, event = events[0]
            if not event.query():
                break
            hwm = event_idx
            # NOTE(review): list.pop(0) is O(n); a deque would be O(1) —
            # presumably fine while the event lists stay short, confirm.
            events.pop(0)
        if is_store:
            self._store_hwm = hwm
        else:
            self._load_hwm = hwm
        return hwm