[kv_offload+HMA][6/N]: Split offloading_connector.py (#37405)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-03-18 15:42:46 +02:00
committed by GitHub
parent 918b7890a1
commit 525f2eeb0b
7 changed files with 733 additions and 658 deletions

View File

@@ -13,10 +13,14 @@ from vllm import SamplingParams
from vllm.config import KVTransferConfig, VllmConfig from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
OffloadingConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
OffloadingConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
OffloadingConnector, OffloadingConnector,
OffloadingConnectorMetadata,
OffloadingConnectorStats,
) )
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256 from vllm.utils.hashing import sha256

View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.v1.kv_offload.worker.worker import TransferSpec
ReqId = str
@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
reqs_to_load: dict[ReqId, TransferSpec]
reqs_to_store: dict[ReqId, TransferSpec]
reqs_to_flush: set[str] | None = None

View File

@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.logger import init_logger
from vllm.v1.kv_offload.worker.worker import TransferType
logger = init_logger(__name__)
@dataclass
class OffloadingOperationMetrics:
op_size: int
op_time: float
@dataclass
class OffloadingConnectorStats(KVConnectorStats):
def __post_init__(self):
if not self.data:
# Empty container init, no data is passed in.
self.reset()
def reset(self):
self.data: dict[str, list[OffloadingOperationMetrics]] = {}
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
if not other.is_empty():
for k, v in other.data.items():
if k not in self.data:
self.data[k] = v
else:
accumulator = self.data[k]
assert isinstance(accumulator, list)
accumulator.extend(v)
return self
def reduce(self) -> dict[str, int | float]:
"""
Reduce the observations collected during a time interval to one or
more representative values (eg avg/median/sum of the series).
This is meant to be called by the logger to produce a summary of the
stats for the last time interval.
"""
return_dict: dict[str, int | float] = {}
for transfer_type, ops_list in self.data.items():
assert isinstance(ops_list, list)
total_bytes = 0
total_time = 0.0
for op in ops_list:
assert isinstance(op, dict)
total_bytes += op["op_size"]
total_time += op["op_time"]
return_dict[f"{transfer_type}_total_bytes"] = total_bytes
return_dict[f"{transfer_type}_total_time"] = total_time
return return_dict
def is_empty(self) -> bool:
return not self.data
def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
src, dst = transfer_type
transfer_type_key = src + "_to_" + dst
op = OffloadingOperationMetrics(num_bytes, time)
if transfer_type_key in self.data:
self.data[transfer_type_key].append(op)
else:
self.data[transfer_type_key] = [op]
class OffloadPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
# (engine_idx, transfer_type) -> (metric with bounded labels)
self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}
buckets = [ # In bytes
1e6,
5e6,
10e6,
20e6,
40e6,
60e6,
80e6,
100e6,
150e6,
200e6,
]
self._counter_kv_bytes = self._counter_cls(
name="vllm:kv_offload_total_bytes",
documentation="Number of bytes offloaded by KV connector",
labelnames=labelnames + ["transfer_type"],
)
self._counter_kv_transfer_time = self._counter_cls(
name="vllm:kv_offload_total_time",
documentation="Total time measured by all KV offloading operations",
labelnames=labelnames + ["transfer_type"],
)
self._histogram_transfer_size = self._histogram_cls(
name="vllm:kv_offload_size",
documentation="Histogram of KV offload transfer size, in bytes.",
buckets=buckets[:],
labelnames=labelnames + ["transfer_type"],
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""
Observe transfer statistics from the new data structure.
transfer_stats_data is expected to be a dict where:
- keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
- values are lists of OffloadingOperationMetrics objects
"""
for transfer_type, ops in transfer_stats_data.items():
# Cache:
if (engine_idx, transfer_type) not in self.histogram_transfer_size:
self.histogram_transfer_size[(engine_idx, transfer_type)] = (
self._histogram_transfer_size.labels(
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
)
)
self.counter_kv_bytes[(engine_idx, transfer_type)] = (
self._counter_kv_bytes.labels(
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
)
)
self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
self._counter_kv_transfer_time.labels(
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
)
)
# Process ops:
assert isinstance(ops, list)
for op in ops: # ops is a list of serialized OffloadingOperationMetrics
assert isinstance(op, dict)
# Observe size histogram
self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
op["op_size"]
)
# Increment byte and time counters
self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])
self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
op["op_time"]
)

View File

@@ -0,0 +1,347 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable
from itertools import islice
from typing import Any
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
OffloadingConnectorMetadata,
ReqId,
)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_offload.abstract import OffloadingManager
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import TransferSpec
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import Request
logger = init_logger(__name__)
class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, spec: OffloadingSpec):
assert len(spec.gpu_block_size) == 1
self.gpu_block_size = spec.gpu_block_size[0]
self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor
self.block_size_factor = spec.block_size_factor
self.manager: OffloadingManager = spec.get_manager()
self._requests: dict[ReqId, Request] = {}
# list of GPU block IDs per request
self._request_block_ids: dict[ReqId, list[int]] = {}
# requests to load for the current scheduler step
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
# request blocks are stored in order
# index of next block (of size offloaded_block_size) to offload
self._next_stored_block_idx: dict[ReqId, int] = {}
# if GPU prefix caching is enabled,
# track loaded blocks to avoid redundant loads
self._blocks_being_loaded: set[BlockHash] | None = (
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
)
# request ID -> set(block hashes being stored/load)
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
def _get_block_hashes(
self,
req: Request,
start_idx: int = 0,
end_idx: int | None = None,
) -> Iterable[BlockHash]:
return islice(
req.block_hashes,
self.block_size_factor * start_idx + self.block_size_factor - 1,
self.block_size_factor * end_idx if end_idx else None,
self.block_size_factor,
)
def get_num_new_matched_tokens(
self, request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded beyond the
num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- The number of tokens that can be loaded beyond what is
already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if tokens will be loaded asynchronously
(between scheduler steps).
"""
num_blocks = request.num_tokens // self.offloaded_block_size
assert len(request.block_hashes) // self.block_size_factor == num_blocks
block_hashes = self._get_block_hashes(request)
self.manager.touch(block_hashes)
full_block_tokens = self.offloaded_block_size * num_blocks
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
# we can load less than a block, skip
return 0, False
start_block_idx = num_computed_tokens // self.offloaded_block_size
hits = self.manager.lookup(
self._get_block_hashes(request, start_idx=start_block_idx)
)
if hits is None:
# indicates a lookup that should be tried later
return None, False
if hits == 0:
return 0, False
num_hit_tokens = (
self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
)
logger.debug(
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
request.request_id,
num_hit_tokens,
num_computed_tokens,
)
if num_hit_tokens < self.offloaded_block_size:
return 0, False
if self._blocks_being_loaded:
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
)
if any(
block_hash in self._blocks_being_loaded for block_hash in block_hashes
):
# hit blocks are being loaded, delay request
logger.debug(
"Delaying request %s since some of its blocks are already"
" being loaded",
request.request_id,
)
return None, False
return num_hit_tokens, True
def update_state_after_alloc(
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
):
self._requests[request.request_id] = request
# the block ids are updated in _get_reqs_to_store
self._request_block_ids[request.request_id] = []
if num_external_tokens == 0:
return
block_groups = blocks.get_block_ids()
block_ids = block_groups[0]
num_computed_gpu_blocks = sum(
block.block_hash is not None for block in blocks.blocks[0]
)
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
full_block_tokens = num_computed_tokens + num_external_tokens
assert full_block_tokens % self.offloaded_block_size == 0
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size
start_block_idx = num_computed_tokens // self.offloaded_block_size
num_blocks = full_block_tokens // self.offloaded_block_size
assert len(request.block_hashes) // self.block_size_factor >= num_blocks
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=num_blocks
)
src_spec = self.manager.prepare_load(block_hashes)
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=num_blocks
)
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
req_blocks_being_loaded.update(block_hashes)
self._next_stored_block_idx[request.request_id] = num_blocks
if self._blocks_being_loaded is not None:
self._blocks_being_loaded.update(req_blocks_being_loaded)
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
reqs_to_store: dict[ReqId, TransferSpec] = {}
# iterate over both new and cached requests
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
if preempted:
self._request_block_ids[req_id] = []
if new_block_id_groups:
new_block_ids = new_block_id_groups[0]
self._request_block_ids[req_id] += new_block_ids
block_ids = self._request_block_ids[req_id]
req = self._requests[req_id]
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
expected_tokens = req.num_computed_tokens + new_tokens
# with async scheduling, some tokens may be missing
total_tokens = min(expected_tokens, req.num_tokens)
num_blocks = total_tokens // self.offloaded_block_size
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
num_new_blocks = num_blocks - start_block_idx
if num_new_blocks <= 0:
continue
num_gpu_blocks = num_blocks * self.block_size_factor
assert len(req.block_hashes) >= num_gpu_blocks
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks
)
store_output = self.manager.prepare_store(new_block_hashes)
if store_output is None:
logger.warning(
"Request %s: cannot store %s blocks", req_id, num_new_blocks
)
continue
self._next_stored_block_idx[req_id] = num_blocks
if not store_output.block_hashes_to_store:
continue
block_hashes_to_store = set(store_output.block_hashes_to_store)
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
self.manager.touch(block_hashes)
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks
)
dst_spec = store_output.store_spec
src_block_ids: list[int] = []
for idx, blk_hash in enumerate(new_block_hashes):
if blk_hash not in block_hashes_to_store:
continue
offloaded_block_idx = start_block_idx + idx
gpu_block_idx = offloaded_block_idx * self.block_size_factor
for i in range(self.block_size_factor):
src_block_ids.append(block_ids[gpu_block_idx + i])
src_spec = GPULoadStoreSpec(src_block_ids)
reqs_to_store[req_id] = (src_spec, dst_spec)
self._reqs_being_stored[req_id] |= block_hashes_to_store
logger.debug(
"Request %s offloading %s blocks starting from block #%d",
req_id,
len(block_hashes_to_store),
start_block_idx,
)
return reqs_to_store
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
meta = OffloadingConnectorMetadata(
reqs_to_load=self._reqs_to_load,
reqs_to_store=self._get_reqs_to_store(scheduler_output),
reqs_to_flush=scheduler_output.preempted_req_ids,
)
self._reqs_to_load = {}
# NOTE (orozery): we should move this logic to update_connector_output
# once KVConnectorOutput allows us to report completed transfers
for req_id in scheduler_output.preempted_req_ids or ():
block_hashes = self._reqs_being_stored.get(req_id)
if block_hashes:
self.manager.complete_store(block_hashes)
block_hashes.clear()
return meta
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
for req_id in connector_output.finished_sending or []:
block_hashes = self._reqs_being_stored.pop(req_id, None)
if block_hashes:
self.manager.complete_store(block_hashes)
for req_id in connector_output.finished_recving or []:
block_hashes = self._reqs_being_loaded.pop(req_id, None)
if block_hashes:
if self._blocks_being_loaded:
self._blocks_being_loaded.difference_update(block_hashes)
self.manager.complete_load(block_hashes)
def request_finished(
self,
request: Request,
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
req_id = request.request_id
self._requests.pop(req_id, None)
self._request_block_ids.pop(req_id, None)
# TODO(orozery): possibly kickoff offload for last block
# which may have been deferred due to async scheduling
self._next_stored_block_idx.pop(req_id, None)
request_being_stored = req_id in self._reqs_being_stored
return request_being_stored, None
def take_events(self) -> Iterable[KVCacheEvent]:
"""Take the KV cache events from the connector.
Returns:
A list of KV cache events.
"""
for event in self.manager.take_events():
if event.removed:
yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
else:
yield BlockStored(
block_hashes=event.block_hashes,
parent_block_hash=None,
token_ids=[],
lora_id=None,
block_size=event.block_size,
medium=event.medium,
lora_name=None,
)

View File

@@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
import torch
from vllm.config import get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
OffloadingConnectorMetadata,
ReqId,
)
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
OffloadingConnectorStats,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import (
OffloadingWorker,
TransferSpec,
)
logger = init_logger(__name__)
class OffloadingConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, spec: OffloadingSpec):
self.spec = spec
self.worker = OffloadingWorker()
self._job_counter = 0
self.kv_connector_stats = OffloadingConnectorStats()
# req_id -> (job_id, store)
self._jobs: dict[int, tuple[ReqId, bool]] = {}
# req_id -> active job IDs
self._load_job: dict[ReqId, int] = {}
# req_id -> set(active job IDs)
self._store_jobs = defaultdict[ReqId, set[int]](set)
# list of store jobs pending submission (job_id, transfer_spec)
self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
self._finished_reqs_waiting_for_store: set[ReqId] = set()
def _generate_job_id(self) -> int:
job_id = self._job_counter
self._job_counter = job_id + 1
return job_id
def _register_handlers(
self,
kv_caches: dict[str, torch.Tensor],
attn_backends: dict[str, type[AttentionBackend]],
):
for src_cls, dst_cls, handler in self.spec.get_handlers(
kv_caches, attn_backends
):
self.worker.register_handler(src_cls, dst_cls, handler)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
layer_names = list(kv_caches.keys())
layers = get_layers_from_vllm_config(
self.spec.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
layer_names,
)
attn_backends = {
layer_name: layers[layer_name].get_attn_backend()
for layer_name in layer_names
}
self._register_handlers(kv_caches, attn_backends)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
cross_layer_name = "ALL_LAYERS"
kv_caches = {cross_layer_name: kv_cache}
attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends)
def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
self._unsubmitted_store_jobs.clear()
for req_id in kv_connector_metadata.reqs_to_flush or ():
job_ids = self._store_jobs.get(req_id)
if job_ids:
self.worker.wait(job_ids)
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
self._unsubmitted_store_jobs.clear()
for req_id, transfer_spec in metadata.reqs_to_load.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, False)
assert req_id not in self._load_job
self._load_job[req_id] = job_id
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_store.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, True)
self._store_jobs[req_id].add(job_id)
# NOTE(orozery): defer the store to the beginning of the next engine step,
# so that offloading starts AFTER transfers related to token sampling,
# thereby avoiding delays to token generation due to offloading.
self._unsubmitted_store_jobs.append((job_id, transfer_spec))
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns a list of request IDs that finished loading or storing.
Returns:
ids of requests that have finished asynchronous transfer
tuple of (sending/saving ids, recving/loading ids).
"""
finished_sending = set()
finished_recving = set()
for transfer_result in self.worker.get_finished():
# we currently do not support job failures
job_id = transfer_result.job_id
assert transfer_result.success
req_id, store = self._jobs.pop(job_id)
if (
transfer_result.transfer_time
and transfer_result.transfer_size is not None
and transfer_result.transfer_type is not None
):
self.kv_connector_stats.record_transfer(
num_bytes=transfer_result.transfer_size,
time=transfer_result.transfer_time,
transfer_type=transfer_result.transfer_type,
)
if store:
req_jobs = self._store_jobs[req_id]
req_jobs.remove(job_id)
if req_jobs:
continue
if req_id in self._finished_reqs_waiting_for_store:
self._finished_reqs_waiting_for_store.remove(req_id)
finished_sending.add(req_id)
del self._store_jobs[req_id]
else:
req_job = self._load_job[req_id]
assert job_id == req_job
del self._load_job[req_id]
finished_recving.add(req_id)
for req_id in finished_req_ids:
pending_req_jobs = self._store_jobs.get(req_id)
if pending_req_jobs:
self._finished_reqs_waiting_for_store.add(req_id)
elif pending_req_jobs is not None:
finished_sending.add(req_id)
del self._store_jobs[req_id]
return finished_sending, finished_recving
def get_kv_connector_stats(self) -> KVConnectorStats | None:
"""
Get the KV transfer stats for the connector.
"""
if self.kv_connector_stats.is_empty():
return None
# Clear stats for next iteration
kv_connector_stats = self.kv_connector_stats
self.kv_connector_stats = OffloadingConnectorStats()
return kv_connector_stats

View File

@@ -1,16 +1,12 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import Any from typing import Any
import torch import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
from vllm.distributed.kv_transfer.kv_connector.v1 import ( from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorRole, KVConnectorRole,
@@ -22,97 +18,28 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
PromMetric, PromMetric,
PromMetricT, PromMetricT,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
OffloadingConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
OffloadingConnectorStats,
OffloadPromMetrics,
)
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler import (
OffloadingConnectorScheduler,
)
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.worker import (
OffloadingConnectorWorker,
)
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_offload.abstract import OffloadingManager
from vllm.v1.kv_offload.factory import OffloadingSpecFactory from vllm.v1.kv_offload.factory import OffloadingSpecFactory
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import (
OffloadingWorker,
TransferSpec,
TransferType,
)
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import Request from vllm.v1.request import Request
ReqId = str
logger = init_logger(__name__)
@dataclass
class OffloadingOperationMetrics:
op_size: int
op_time: float
@dataclass
class OffloadingConnectorStats(KVConnectorStats):
def __post_init__(self):
if not self.data:
# Empty container init, no data is passed in.
self.reset()
def reset(self):
self.data: dict[str, list[OffloadingOperationMetrics]] = {}
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
if not other.is_empty():
for k, v in other.data.items():
if k not in self.data:
self.data[k] = v
else:
accumulator = self.data[k]
assert isinstance(accumulator, list)
accumulator.extend(v)
return self
def reduce(self) -> dict[str, int | float]:
"""
Reduce the observations collected during a time interval to one or
more representative values (eg avg/median/sum of the series).
This is meant to be called by the logger to produce a summary of the
stats for the last time interval.
"""
return_dict: dict[str, int | float] = {}
for transfer_type, ops_list in self.data.items():
assert isinstance(ops_list, list)
total_bytes = 0
total_time = 0.0
for op in ops_list:
assert isinstance(op, dict)
total_bytes += op["op_size"]
total_time += op["op_time"]
return_dict[f"{transfer_type}_total_bytes"] = total_bytes
return_dict[f"{transfer_type}_total_time"] = total_time
return return_dict
def is_empty(self) -> bool:
return not self.data
def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
src, dst = transfer_type
transfer_type_key = src + "_to_" + dst
op = OffloadingOperationMetrics(num_bytes, time)
if transfer_type_key in self.data:
self.data[transfer_type_key].append(op)
else:
self.data[transfer_type_key] = [op]
@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
reqs_to_load: dict[ReqId, TransferSpec]
reqs_to_store: dict[ReqId, TransferSpec]
reqs_to_flush: set[str] | None = None
class OffloadingConnector(KVConnectorBase_V1): class OffloadingConnector(KVConnectorBase_V1):
@property @property
@@ -242,571 +169,3 @@ class OffloadingConnector(KVConnectorBase_V1):
return OffloadPromMetrics( return OffloadPromMetrics(
vllm_config, metric_types, labelnames, per_engine_labelvalues vllm_config, metric_types, labelnames, per_engine_labelvalues
) )
class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, spec: OffloadingSpec):
assert len(spec.gpu_block_size) == 1
self.gpu_block_size = spec.gpu_block_size[0]
self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor
self.block_size_factor = spec.block_size_factor
self.manager: OffloadingManager = spec.get_manager()
self._requests: dict[ReqId, Request] = {}
# list of GPU block IDs per request
self._request_block_ids: dict[ReqId, list[int]] = {}
# requests to load for the current scheduler step
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
# request blocks are stored in order
# index of next block (of size offloaded_block_size) to offload
self._next_stored_block_idx: dict[ReqId, int] = {}
# if GPU prefix caching is enabled,
# track loaded blocks to avoid redundant loads
self._blocks_being_loaded: set[BlockHash] | None = (
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
)
# request ID -> set(block hashes being stored/load)
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
def _get_block_hashes(
self,
req: Request,
start_idx: int = 0,
end_idx: int | None = None,
) -> Iterable[BlockHash]:
return islice(
req.block_hashes,
self.block_size_factor * start_idx + self.block_size_factor - 1,
self.block_size_factor * end_idx if end_idx else None,
self.block_size_factor,
)
def get_num_new_matched_tokens(
self, request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded beyond the
num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- The number of tokens that can be loaded beyond what is
already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if tokens will be loaded asynchronously
(between scheduler steps).
"""
num_blocks = request.num_tokens // self.offloaded_block_size
assert len(request.block_hashes) // self.block_size_factor == num_blocks
block_hashes = self._get_block_hashes(request)
self.manager.touch(block_hashes)
full_block_tokens = self.offloaded_block_size * num_blocks
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
# we can load less than a block, skip
return 0, False
start_block_idx = num_computed_tokens // self.offloaded_block_size
hits = self.manager.lookup(
self._get_block_hashes(request, start_idx=start_block_idx)
)
if hits is None:
# indicates a lookup that should be tried later
return None, False
if hits == 0:
return 0, False
num_hit_tokens = (
self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
)
logger.debug(
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
request.request_id,
num_hit_tokens,
num_computed_tokens,
)
if num_hit_tokens < self.offloaded_block_size:
return 0, False
if self._blocks_being_loaded:
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
)
if any(
block_hash in self._blocks_being_loaded for block_hash in block_hashes
):
# hit blocks are being loaded, delay request
logger.debug(
"Delaying request %s since some of its blocks are already"
" being loaded",
request.request_id,
)
return None, False
return num_hit_tokens, True
def update_state_after_alloc(
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
):
self._requests[request.request_id] = request
# the block ids are updated in _get_reqs_to_store
self._request_block_ids[request.request_id] = []
if num_external_tokens == 0:
return
block_groups = blocks.get_block_ids()
block_ids = block_groups[0]
num_computed_gpu_blocks = sum(
block.block_hash is not None for block in blocks.blocks[0]
)
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
full_block_tokens = num_computed_tokens + num_external_tokens
assert full_block_tokens % self.offloaded_block_size == 0
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size
start_block_idx = num_computed_tokens // self.offloaded_block_size
num_blocks = full_block_tokens // self.offloaded_block_size
assert len(request.block_hashes) // self.block_size_factor >= num_blocks
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=num_blocks
)
src_spec = self.manager.prepare_load(block_hashes)
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=num_blocks
)
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
req_blocks_being_loaded.update(block_hashes)
self._next_stored_block_idx[request.request_id] = num_blocks
if self._blocks_being_loaded is not None:
self._blocks_being_loaded.update(req_blocks_being_loaded)
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
reqs_to_store: dict[ReqId, TransferSpec] = {}
# iterate over both new and cached requests
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
if preempted:
self._request_block_ids[req_id] = []
if new_block_id_groups:
new_block_ids = new_block_id_groups[0]
self._request_block_ids[req_id] += new_block_ids
block_ids = self._request_block_ids[req_id]
req = self._requests[req_id]
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
expected_tokens = req.num_computed_tokens + new_tokens
# with async scheduling, some tokens may be missing
total_tokens = min(expected_tokens, req.num_tokens)
num_blocks = total_tokens // self.offloaded_block_size
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
num_new_blocks = num_blocks - start_block_idx
if num_new_blocks <= 0:
continue
num_gpu_blocks = num_blocks * self.block_size_factor
assert len(req.block_hashes) >= num_gpu_blocks
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks
)
store_output = self.manager.prepare_store(new_block_hashes)
if store_output is None:
logger.warning(
"Request %s: cannot store %s blocks", req_id, num_new_blocks
)
continue
self._next_stored_block_idx[req_id] = num_blocks
if not store_output.block_hashes_to_store:
continue
block_hashes_to_store = set(store_output.block_hashes_to_store)
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
self.manager.touch(block_hashes)
new_block_hashes = self._get_block_hashes(
req, start_idx=start_block_idx, end_idx=num_blocks
)
dst_spec = store_output.store_spec
src_block_ids: list[int] = []
for idx, blk_hash in enumerate(new_block_hashes):
if blk_hash not in block_hashes_to_store:
continue
offloaded_block_idx = start_block_idx + idx
gpu_block_idx = offloaded_block_idx * self.block_size_factor
for i in range(self.block_size_factor):
src_block_ids.append(block_ids[gpu_block_idx + i])
src_spec = GPULoadStoreSpec(src_block_ids)
reqs_to_store[req_id] = (src_spec, dst_spec)
self._reqs_being_stored[req_id] |= block_hashes_to_store
logger.debug(
"Request %s offloading %s blocks starting from block #%d",
req_id,
len(block_hashes_to_store),
start_block_idx,
)
return reqs_to_store
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
meta = OffloadingConnectorMetadata(
reqs_to_load=self._reqs_to_load,
reqs_to_store=self._get_reqs_to_store(scheduler_output),
reqs_to_flush=scheduler_output.preempted_req_ids,
)
self._reqs_to_load = {}
# NOTE (orozery): we should move this logic to update_connector_output
# once KVConnectorOutput allows us to report completed transfers
for req_id in scheduler_output.preempted_req_ids or ():
block_hashes = self._reqs_being_stored.get(req_id)
if block_hashes:
self.manager.complete_store(block_hashes)
block_hashes.clear()
return meta
def update_connector_output(self, connector_output: KVConnectorOutput):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
for req_id in connector_output.finished_sending or []:
block_hashes = self._reqs_being_stored.pop(req_id, None)
if block_hashes:
self.manager.complete_store(block_hashes)
for req_id in connector_output.finished_recving or []:
block_hashes = self._reqs_being_loaded.pop(req_id, None)
if block_hashes:
if self._blocks_being_loaded:
self._blocks_being_loaded.difference_update(block_hashes)
self.manager.complete_load(block_hashes)
def request_finished(
self,
request: Request,
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
req_id = request.request_id
self._requests.pop(req_id, None)
self._request_block_ids.pop(req_id, None)
# TODO(orozery): possibly kickoff offload for last block
# which may have been deferred due to async scheduling
self._next_stored_block_idx.pop(req_id, None)
request_being_stored = req_id in self._reqs_being_stored
return request_being_stored, None
def take_events(self) -> Iterable[KVCacheEvent]:
"""Take the KV cache events from the connector.
Returns:
A list of KV cache events.
"""
for event in self.manager.take_events():
if event.removed:
yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
else:
yield BlockStored(
block_hashes=event.block_hashes,
parent_block_hash=None,
token_ids=[],
lora_id=None,
block_size=event.block_size,
medium=event.medium,
lora_name=None,
)
class OffloadingConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, spec: OffloadingSpec):
self.spec = spec
self.worker = OffloadingWorker()
self._job_counter = 0
self.kv_connector_stats = OffloadingConnectorStats()
# req_id -> (job_id, store)
self._jobs: dict[int, tuple[ReqId, bool]] = {}
# req_id -> active job IDs
self._load_job: dict[ReqId, int] = {}
# req_id -> set(active job IDs)
self._store_jobs = defaultdict[ReqId, set[int]](set)
# list of store jobs pending submission (job_id, transfer_spec)
self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
self._finished_reqs_waiting_for_store: set[ReqId] = set()
def _generate_job_id(self) -> int:
job_id = self._job_counter
self._job_counter = job_id + 1
return job_id
def _register_handlers(
self,
kv_caches: dict[str, torch.Tensor],
attn_backends: dict[str, type[AttentionBackend]],
):
for src_cls, dst_cls, handler in self.spec.get_handlers(
kv_caches, attn_backends
):
self.worker.register_handler(src_cls, dst_cls, handler)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
layer_names = list(kv_caches.keys())
layers = get_layers_from_vllm_config(
self.spec.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
layer_names,
)
attn_backends = {
layer_name: layers[layer_name].get_attn_backend()
for layer_name in layer_names
}
self._register_handlers(kv_caches, attn_backends)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
cross_layer_name = "ALL_LAYERS"
kv_caches = {cross_layer_name: kv_cache}
attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends)
def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
self._unsubmitted_store_jobs.clear()
for req_id in kv_connector_metadata.reqs_to_flush or ():
job_ids = self._store_jobs.get(req_id)
if job_ids:
self.worker.wait(job_ids)
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
self._unsubmitted_store_jobs.clear()
for req_id, transfer_spec in metadata.reqs_to_load.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, False)
assert req_id not in self._load_job
self._load_job[req_id] = job_id
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_store.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, True)
self._store_jobs[req_id].add(job_id)
# NOTE(orozery): defer the store to the beginning of the next engine step,
# so that offloading starts AFTER transfers related to token sampling,
# thereby avoiding delays to token generation due to offloading.
self._unsubmitted_store_jobs.append((job_id, transfer_spec))
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns a list of request IDs that finished loading or storing.
Returns:
ids of requests that have finished asynchronous transfer
tuple of (sending/saving ids, recving/loading ids).
"""
finished_sending = set()
finished_recving = set()
for transfer_result in self.worker.get_finished():
# we currently do not support job failures
job_id = transfer_result.job_id
assert transfer_result.success
req_id, store = self._jobs.pop(job_id)
if (
transfer_result.transfer_time
and transfer_result.transfer_size is not None
and transfer_result.transfer_type is not None
):
self.kv_connector_stats.record_transfer(
num_bytes=transfer_result.transfer_size,
time=transfer_result.transfer_time,
transfer_type=transfer_result.transfer_type,
)
if store:
req_jobs = self._store_jobs[req_id]
req_jobs.remove(job_id)
if req_jobs:
continue
if req_id in self._finished_reqs_waiting_for_store:
self._finished_reqs_waiting_for_store.remove(req_id)
finished_sending.add(req_id)
del self._store_jobs[req_id]
else:
req_job = self._load_job[req_id]
assert job_id == req_job
del self._load_job[req_id]
finished_recving.add(req_id)
for req_id in finished_req_ids:
pending_req_jobs = self._store_jobs.get(req_id)
if pending_req_jobs:
self._finished_reqs_waiting_for_store.add(req_id)
elif pending_req_jobs is not None:
finished_sending.add(req_id)
del self._store_jobs[req_id]
return finished_sending, finished_recving
def get_kv_connector_stats(self) -> KVConnectorStats | None:
"""
Get the KV transfer stats for the connector.
"""
if self.kv_connector_stats.is_empty():
return None
# Clear stats for next iteration
kv_connector_stats = self.kv_connector_stats
self.kv_connector_stats = OffloadingConnectorStats()
return kv_connector_stats
class OffloadPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
# (engine_idx, transfer_type) -> (metric with bounded labels)
self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}
buckets = [ # In bytes
1e6,
5e6,
10e6,
20e6,
40e6,
60e6,
80e6,
100e6,
150e6,
200e6,
]
self._counter_kv_bytes = self._counter_cls(
name="vllm:kv_offload_total_bytes",
documentation="Number of bytes offloaded by KV connector",
labelnames=labelnames + ["transfer_type"],
)
self._counter_kv_transfer_time = self._counter_cls(
name="vllm:kv_offload_total_time",
documentation="Total time measured by all KV offloading operations",
labelnames=labelnames + ["transfer_type"],
)
self._histogram_transfer_size = self._histogram_cls(
name="vllm:kv_offload_size",
documentation="Histogram of KV offload transfer size, in bytes.",
buckets=buckets[:],
labelnames=labelnames + ["transfer_type"],
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""
Observe transfer statistics from the new data structure.
transfer_stats_data is expected to be a dict where:
- keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
- values are lists of OffloadingOperationMetrics objects
"""
for transfer_type, ops in transfer_stats_data.items():
# Cache:
if (engine_idx, transfer_type) not in self.histogram_transfer_size:
self.histogram_transfer_size[(engine_idx, transfer_type)] = (
self._histogram_transfer_size.labels(
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
)
)
self.counter_kv_bytes[(engine_idx, transfer_type)] = (
self._counter_kv_bytes.labels(
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
)
)
self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
self._counter_kv_transfer_time.labels(
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
)
)
# Process ops:
assert isinstance(ops, list)
for op in ops: # ops is a list of serialized OffloadingOperationMetrics
assert isinstance(op, dict)
# Observe size histogram
self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
op["op_size"]
)
# Increment byte and time counters
self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])
self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
op["op_time"]
)