[kv_offload+HMA][6/N]: Split offloading_connector.py (#37405)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -13,10 +13,14 @@ from vllm import SamplingParams
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
|
||||
OffloadingConnectorMetadata,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
|
||||
OffloadingConnectorStats,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
|
||||
OffloadingConnector,
|
||||
OffloadingConnectorMetadata,
|
||||
OffloadingConnectorStats,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils.hashing import sha256
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.v1.kv_offload.worker.worker import TransferSpec
|
||||
|
||||
ReqId = str
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
reqs_to_load: dict[ReqId, TransferSpec]
|
||||
reqs_to_store: dict[ReqId, TransferSpec]
|
||||
reqs_to_flush: set[str] | None = None
|
||||
@@ -0,0 +1,165 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_offload.worker.worker import TransferType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingOperationMetrics:
|
||||
op_size: int
|
||||
op_time: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingConnectorStats(KVConnectorStats):
|
||||
def __post_init__(self):
|
||||
if not self.data:
|
||||
# Empty container init, no data is passed in.
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.data: dict[str, list[OffloadingOperationMetrics]] = {}
|
||||
|
||||
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
||||
if not other.is_empty():
|
||||
for k, v in other.data.items():
|
||||
if k not in self.data:
|
||||
self.data[k] = v
|
||||
else:
|
||||
accumulator = self.data[k]
|
||||
assert isinstance(accumulator, list)
|
||||
accumulator.extend(v)
|
||||
return self
|
||||
|
||||
def reduce(self) -> dict[str, int | float]:
|
||||
"""
|
||||
Reduce the observations collected during a time interval to one or
|
||||
more representative values (eg avg/median/sum of the series).
|
||||
This is meant to be called by the logger to produce a summary of the
|
||||
stats for the last time interval.
|
||||
"""
|
||||
return_dict: dict[str, int | float] = {}
|
||||
for transfer_type, ops_list in self.data.items():
|
||||
assert isinstance(ops_list, list)
|
||||
total_bytes = 0
|
||||
total_time = 0.0
|
||||
for op in ops_list:
|
||||
assert isinstance(op, dict)
|
||||
total_bytes += op["op_size"]
|
||||
total_time += op["op_time"]
|
||||
return_dict[f"{transfer_type}_total_bytes"] = total_bytes
|
||||
return_dict[f"{transfer_type}_total_time"] = total_time
|
||||
return return_dict
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.data
|
||||
|
||||
def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
|
||||
src, dst = transfer_type
|
||||
transfer_type_key = src + "_to_" + dst
|
||||
op = OffloadingOperationMetrics(num_bytes, time)
|
||||
if transfer_type_key in self.data:
|
||||
self.data[transfer_type_key].append(op)
|
||||
else:
|
||||
self.data[transfer_type_key] = [op]
|
||||
|
||||
|
||||
class OffloadPromMetrics(KVConnectorPromMetrics):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
):
|
||||
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
||||
# (engine_idx, transfer_type) -> (metric with bounded labels)
|
||||
self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
|
||||
self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
|
||||
self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}
|
||||
buckets = [ # In bytes
|
||||
1e6,
|
||||
5e6,
|
||||
10e6,
|
||||
20e6,
|
||||
40e6,
|
||||
60e6,
|
||||
80e6,
|
||||
100e6,
|
||||
150e6,
|
||||
200e6,
|
||||
]
|
||||
|
||||
self._counter_kv_bytes = self._counter_cls(
|
||||
name="vllm:kv_offload_total_bytes",
|
||||
documentation="Number of bytes offloaded by KV connector",
|
||||
labelnames=labelnames + ["transfer_type"],
|
||||
)
|
||||
|
||||
self._counter_kv_transfer_time = self._counter_cls(
|
||||
name="vllm:kv_offload_total_time",
|
||||
documentation="Total time measured by all KV offloading operations",
|
||||
labelnames=labelnames + ["transfer_type"],
|
||||
)
|
||||
|
||||
self._histogram_transfer_size = self._histogram_cls(
|
||||
name="vllm:kv_offload_size",
|
||||
documentation="Histogram of KV offload transfer size, in bytes.",
|
||||
buckets=buckets[:],
|
||||
labelnames=labelnames + ["transfer_type"],
|
||||
)
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||
"""
|
||||
Observe transfer statistics from the new data structure.
|
||||
transfer_stats_data is expected to be a dict where:
|
||||
- keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
|
||||
- values are lists of OffloadingOperationMetrics objects
|
||||
"""
|
||||
|
||||
for transfer_type, ops in transfer_stats_data.items():
|
||||
# Cache:
|
||||
if (engine_idx, transfer_type) not in self.histogram_transfer_size:
|
||||
self.histogram_transfer_size[(engine_idx, transfer_type)] = (
|
||||
self._histogram_transfer_size.labels(
|
||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||
)
|
||||
)
|
||||
self.counter_kv_bytes[(engine_idx, transfer_type)] = (
|
||||
self._counter_kv_bytes.labels(
|
||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||
)
|
||||
)
|
||||
self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
|
||||
self._counter_kv_transfer_time.labels(
|
||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||
)
|
||||
)
|
||||
|
||||
# Process ops:
|
||||
assert isinstance(ops, list)
|
||||
for op in ops: # ops is a list of serialized OffloadingOperationMetrics
|
||||
assert isinstance(op, dict)
|
||||
# Observe size histogram
|
||||
self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
|
||||
op["op_size"]
|
||||
)
|
||||
|
||||
# Increment byte and time counters
|
||||
self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])
|
||||
|
||||
self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
|
||||
op["op_time"]
|
||||
)
|
||||
@@ -0,0 +1,347 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
|
||||
OffloadingConnectorMetadata,
|
||||
ReqId,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_offload.abstract import OffloadingManager
|
||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import TransferSpec
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OffloadingConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def __init__(self, spec: OffloadingSpec):
|
||||
assert len(spec.gpu_block_size) == 1
|
||||
self.gpu_block_size = spec.gpu_block_size[0]
|
||||
self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor
|
||||
self.block_size_factor = spec.block_size_factor
|
||||
self.manager: OffloadingManager = spec.get_manager()
|
||||
|
||||
self._requests: dict[ReqId, Request] = {}
|
||||
# list of GPU block IDs per request
|
||||
self._request_block_ids: dict[ReqId, list[int]] = {}
|
||||
# requests to load for the current scheduler step
|
||||
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
|
||||
# request blocks are stored in order
|
||||
# index of next block (of size offloaded_block_size) to offload
|
||||
self._next_stored_block_idx: dict[ReqId, int] = {}
|
||||
# if GPU prefix caching is enabled,
|
||||
# track loaded blocks to avoid redundant loads
|
||||
self._blocks_being_loaded: set[BlockHash] | None = (
|
||||
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
|
||||
)
|
||||
|
||||
# request ID -> set(block hashes being stored/load)
|
||||
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
|
||||
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
|
||||
|
||||
def _get_block_hashes(
|
||||
self,
|
||||
req: Request,
|
||||
start_idx: int = 0,
|
||||
end_idx: int | None = None,
|
||||
) -> Iterable[BlockHash]:
|
||||
return islice(
|
||||
req.block_hashes,
|
||||
self.block_size_factor * start_idx + self.block_size_factor - 1,
|
||||
self.block_size_factor * end_idx if end_idx else None,
|
||||
self.block_size_factor,
|
||||
)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: Request, num_computed_tokens: int
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded beyond the
|
||||
num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- The number of tokens that can be loaded beyond what is
|
||||
already computed.
|
||||
If None, it means that the connector needs more time to
|
||||
determine the number of matched tokens, and the scheduler
|
||||
should query for this request again later.
|
||||
- `True` if tokens will be loaded asynchronously
|
||||
(between scheduler steps).
|
||||
"""
|
||||
num_blocks = request.num_tokens // self.offloaded_block_size
|
||||
|
||||
assert len(request.block_hashes) // self.block_size_factor == num_blocks
|
||||
block_hashes = self._get_block_hashes(request)
|
||||
|
||||
self.manager.touch(block_hashes)
|
||||
|
||||
full_block_tokens = self.offloaded_block_size * num_blocks
|
||||
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
|
||||
# we can load less than a block, skip
|
||||
return 0, False
|
||||
|
||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||
hits = self.manager.lookup(
|
||||
self._get_block_hashes(request, start_idx=start_block_idx)
|
||||
)
|
||||
if hits is None:
|
||||
# indicates a lookup that should be tried later
|
||||
return None, False
|
||||
if hits == 0:
|
||||
return 0, False
|
||||
|
||||
num_hit_tokens = (
|
||||
self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
|
||||
)
|
||||
logger.debug(
|
||||
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
|
||||
request.request_id,
|
||||
num_hit_tokens,
|
||||
num_computed_tokens,
|
||||
)
|
||||
if num_hit_tokens < self.offloaded_block_size:
|
||||
return 0, False
|
||||
|
||||
if self._blocks_being_loaded:
|
||||
block_hashes = self._get_block_hashes(
|
||||
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
|
||||
)
|
||||
|
||||
if any(
|
||||
block_hash in self._blocks_being_loaded for block_hash in block_hashes
|
||||
):
|
||||
# hit blocks are being loaded, delay request
|
||||
logger.debug(
|
||||
"Delaying request %s since some of its blocks are already"
|
||||
" being loaded",
|
||||
request.request_id,
|
||||
)
|
||||
return None, False
|
||||
|
||||
return num_hit_tokens, True
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
|
||||
):
|
||||
self._requests[request.request_id] = request
|
||||
# the block ids are updated in _get_reqs_to_store
|
||||
self._request_block_ids[request.request_id] = []
|
||||
|
||||
if num_external_tokens == 0:
|
||||
return
|
||||
|
||||
block_groups = blocks.get_block_ids()
|
||||
block_ids = block_groups[0]
|
||||
|
||||
num_computed_gpu_blocks = sum(
|
||||
block.block_hash is not None for block in blocks.blocks[0]
|
||||
)
|
||||
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
|
||||
full_block_tokens = num_computed_tokens + num_external_tokens
|
||||
assert full_block_tokens % self.offloaded_block_size == 0
|
||||
|
||||
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
|
||||
assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size
|
||||
|
||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||
num_blocks = full_block_tokens // self.offloaded_block_size
|
||||
|
||||
assert len(request.block_hashes) // self.block_size_factor >= num_blocks
|
||||
block_hashes = self._get_block_hashes(
|
||||
request, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
|
||||
src_spec = self.manager.prepare_load(block_hashes)
|
||||
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
|
||||
|
||||
block_hashes = self._get_block_hashes(
|
||||
request, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
|
||||
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
|
||||
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
|
||||
req_blocks_being_loaded.update(block_hashes)
|
||||
self._next_stored_block_idx[request.request_id] = num_blocks
|
||||
|
||||
if self._blocks_being_loaded is not None:
|
||||
self._blocks_being_loaded.update(req_blocks_being_loaded)
|
||||
|
||||
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
|
||||
reqs_to_store: dict[ReqId, TransferSpec] = {}
|
||||
# iterate over both new and cached requests
|
||||
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
|
||||
if preempted:
|
||||
self._request_block_ids[req_id] = []
|
||||
|
||||
if new_block_id_groups:
|
||||
new_block_ids = new_block_id_groups[0]
|
||||
self._request_block_ids[req_id] += new_block_ids
|
||||
|
||||
block_ids = self._request_block_ids[req_id]
|
||||
|
||||
req = self._requests[req_id]
|
||||
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
expected_tokens = req.num_computed_tokens + new_tokens
|
||||
# with async scheduling, some tokens may be missing
|
||||
total_tokens = min(expected_tokens, req.num_tokens)
|
||||
num_blocks = total_tokens // self.offloaded_block_size
|
||||
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
|
||||
num_new_blocks = num_blocks - start_block_idx
|
||||
|
||||
if num_new_blocks <= 0:
|
||||
continue
|
||||
|
||||
num_gpu_blocks = num_blocks * self.block_size_factor
|
||||
assert len(req.block_hashes) >= num_gpu_blocks
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
store_output = self.manager.prepare_store(new_block_hashes)
|
||||
if store_output is None:
|
||||
logger.warning(
|
||||
"Request %s: cannot store %s blocks", req_id, num_new_blocks
|
||||
)
|
||||
continue
|
||||
|
||||
self._next_stored_block_idx[req_id] = num_blocks
|
||||
|
||||
if not store_output.block_hashes_to_store:
|
||||
continue
|
||||
block_hashes_to_store = set(store_output.block_hashes_to_store)
|
||||
|
||||
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
|
||||
self.manager.touch(block_hashes)
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
dst_spec = store_output.store_spec
|
||||
src_block_ids: list[int] = []
|
||||
for idx, blk_hash in enumerate(new_block_hashes):
|
||||
if blk_hash not in block_hashes_to_store:
|
||||
continue
|
||||
offloaded_block_idx = start_block_idx + idx
|
||||
gpu_block_idx = offloaded_block_idx * self.block_size_factor
|
||||
for i in range(self.block_size_factor):
|
||||
src_block_ids.append(block_ids[gpu_block_idx + i])
|
||||
src_spec = GPULoadStoreSpec(src_block_ids)
|
||||
|
||||
reqs_to_store[req_id] = (src_spec, dst_spec)
|
||||
self._reqs_being_stored[req_id] |= block_hashes_to_store
|
||||
|
||||
logger.debug(
|
||||
"Request %s offloading %s blocks starting from block #%d",
|
||||
req_id,
|
||||
len(block_hashes_to_store),
|
||||
start_block_idx,
|
||||
)
|
||||
|
||||
return reqs_to_store
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
meta = OffloadingConnectorMetadata(
|
||||
reqs_to_load=self._reqs_to_load,
|
||||
reqs_to_store=self._get_reqs_to_store(scheduler_output),
|
||||
reqs_to_flush=scheduler_output.preempted_req_ids,
|
||||
)
|
||||
self._reqs_to_load = {}
|
||||
|
||||
# NOTE (orozery): we should move this logic to update_connector_output
|
||||
# once KVConnectorOutput allows us to report completed transfers
|
||||
for req_id in scheduler_output.preempted_req_ids or ():
|
||||
block_hashes = self._reqs_being_stored.get(req_id)
|
||||
if block_hashes:
|
||||
self.manager.complete_store(block_hashes)
|
||||
block_hashes.clear()
|
||||
|
||||
return meta
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
"""
|
||||
Update KVConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (KVConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
for req_id in connector_output.finished_sending or []:
|
||||
block_hashes = self._reqs_being_stored.pop(req_id, None)
|
||||
if block_hashes:
|
||||
self.manager.complete_store(block_hashes)
|
||||
|
||||
for req_id in connector_output.finished_recving or []:
|
||||
block_hashes = self._reqs_being_loaded.pop(req_id, None)
|
||||
if block_hashes:
|
||||
if self._blocks_being_loaded:
|
||||
self._blocks_being_loaded.difference_update(block_hashes)
|
||||
self.manager.complete_load(block_hashes)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: Request,
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
req_id = request.request_id
|
||||
self._requests.pop(req_id, None)
|
||||
self._request_block_ids.pop(req_id, None)
|
||||
|
||||
# TODO(orozery): possibly kickoff offload for last block
|
||||
# which may have been deferred due to async scheduling
|
||||
self._next_stored_block_idx.pop(req_id, None)
|
||||
|
||||
request_being_stored = req_id in self._reqs_being_stored
|
||||
return request_being_stored, None
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
"""Take the KV cache events from the connector.
|
||||
|
||||
Returns:
|
||||
A list of KV cache events.
|
||||
"""
|
||||
for event in self.manager.take_events():
|
||||
if event.removed:
|
||||
yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
|
||||
else:
|
||||
yield BlockStored(
|
||||
block_hashes=event.block_hashes,
|
||||
parent_block_hash=None,
|
||||
token_ids=[],
|
||||
lora_id=None,
|
||||
block_size=event.block_size,
|
||||
medium=event.medium,
|
||||
lora_name=None,
|
||||
)
|
||||
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorStats,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
|
||||
OffloadingConnectorMetadata,
|
||||
ReqId,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
|
||||
OffloadingConnectorStats,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import (
|
||||
OffloadingWorker,
|
||||
TransferSpec,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OffloadingConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, spec: OffloadingSpec):
|
||||
self.spec = spec
|
||||
self.worker = OffloadingWorker()
|
||||
|
||||
self._job_counter = 0
|
||||
|
||||
self.kv_connector_stats = OffloadingConnectorStats()
|
||||
# req_id -> (job_id, store)
|
||||
self._jobs: dict[int, tuple[ReqId, bool]] = {}
|
||||
# req_id -> active job IDs
|
||||
self._load_job: dict[ReqId, int] = {}
|
||||
# req_id -> set(active job IDs)
|
||||
self._store_jobs = defaultdict[ReqId, set[int]](set)
|
||||
# list of store jobs pending submission (job_id, transfer_spec)
|
||||
self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
|
||||
|
||||
self._finished_reqs_waiting_for_store: set[ReqId] = set()
|
||||
|
||||
def _generate_job_id(self) -> int:
|
||||
job_id = self._job_counter
|
||||
self._job_counter = job_id + 1
|
||||
return job_id
|
||||
|
||||
def _register_handlers(
|
||||
self,
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
attn_backends: dict[str, type[AttentionBackend]],
|
||||
):
|
||||
for src_cls, dst_cls, handler in self.spec.get_handlers(
|
||||
kv_caches, attn_backends
|
||||
):
|
||||
self.worker.register_handler(src_cls, dst_cls, handler)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
layer_names = list(kv_caches.keys())
|
||||
layers = get_layers_from_vllm_config(
|
||||
self.spec.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
layer_names,
|
||||
)
|
||||
attn_backends = {
|
||||
layer_name: layers[layer_name].get_attn_backend()
|
||||
for layer_name in layer_names
|
||||
}
|
||||
self._register_handlers(kv_caches, attn_backends)
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
):
|
||||
cross_layer_name = "ALL_LAYERS"
|
||||
kv_caches = {cross_layer_name: kv_cache}
|
||||
attn_backends = {cross_layer_name: attn_backend}
|
||||
self._register_handlers(kv_caches, attn_backends)
|
||||
|
||||
def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
self._unsubmitted_store_jobs.clear()
|
||||
|
||||
for req_id in kv_connector_metadata.reqs_to_flush or ():
|
||||
job_ids = self._store_jobs.get(req_id)
|
||||
if job_ids:
|
||||
self.worker.wait(job_ids)
|
||||
|
||||
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
self._unsubmitted_store_jobs.clear()
|
||||
|
||||
for req_id, transfer_spec in metadata.reqs_to_load.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, False)
|
||||
assert req_id not in self._load_job
|
||||
self._load_job[req_id] = job_id
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
|
||||
def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
|
||||
for req_id, transfer_spec in metadata.reqs_to_store.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, True)
|
||||
self._store_jobs[req_id].add(job_id)
|
||||
# NOTE(orozery): defer the store to the beginning of the next engine step,
|
||||
# so that offloading starts AFTER transfers related to token sampling,
|
||||
# thereby avoiding delays to token generation due to offloading.
|
||||
self._unsubmitted_store_jobs.append((job_id, transfer_spec))
|
||||
|
||||
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
Returns a list of request IDs that finished loading or storing.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
"""
|
||||
finished_sending = set()
|
||||
finished_recving = set()
|
||||
for transfer_result in self.worker.get_finished():
|
||||
# we currently do not support job failures
|
||||
job_id = transfer_result.job_id
|
||||
assert transfer_result.success
|
||||
req_id, store = self._jobs.pop(job_id)
|
||||
if (
|
||||
transfer_result.transfer_time
|
||||
and transfer_result.transfer_size is not None
|
||||
and transfer_result.transfer_type is not None
|
||||
):
|
||||
self.kv_connector_stats.record_transfer(
|
||||
num_bytes=transfer_result.transfer_size,
|
||||
time=transfer_result.transfer_time,
|
||||
transfer_type=transfer_result.transfer_type,
|
||||
)
|
||||
if store:
|
||||
req_jobs = self._store_jobs[req_id]
|
||||
req_jobs.remove(job_id)
|
||||
if req_jobs:
|
||||
continue
|
||||
|
||||
if req_id in self._finished_reqs_waiting_for_store:
|
||||
self._finished_reqs_waiting_for_store.remove(req_id)
|
||||
finished_sending.add(req_id)
|
||||
del self._store_jobs[req_id]
|
||||
else:
|
||||
req_job = self._load_job[req_id]
|
||||
assert job_id == req_job
|
||||
del self._load_job[req_id]
|
||||
finished_recving.add(req_id)
|
||||
|
||||
for req_id in finished_req_ids:
|
||||
pending_req_jobs = self._store_jobs.get(req_id)
|
||||
if pending_req_jobs:
|
||||
self._finished_reqs_waiting_for_store.add(req_id)
|
||||
elif pending_req_jobs is not None:
|
||||
finished_sending.add(req_id)
|
||||
del self._store_jobs[req_id]
|
||||
|
||||
return finished_sending, finished_recving
|
||||
|
||||
def get_kv_connector_stats(self) -> KVConnectorStats | None:
|
||||
"""
|
||||
Get the KV transfer stats for the connector.
|
||||
"""
|
||||
|
||||
if self.kv_connector_stats.is_empty():
|
||||
return None
|
||||
# Clear stats for next iteration
|
||||
kv_connector_stats = self.kv_connector_stats
|
||||
self.kv_connector_stats = OffloadingConnectorStats()
|
||||
return kv_connector_stats
|
||||
@@ -1,16 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
@@ -22,97 +18,28 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
|
||||
OffloadingConnectorMetadata,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
|
||||
OffloadingConnectorStats,
|
||||
OffloadPromMetrics,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler import (
|
||||
OffloadingConnectorScheduler,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.worker import (
|
||||
OffloadingConnectorWorker,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.kv_offload.abstract import OffloadingManager
|
||||
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
|
||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import (
|
||||
OffloadingWorker,
|
||||
TransferSpec,
|
||||
TransferType,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
|
||||
ReqId = str
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingOperationMetrics:
|
||||
op_size: int
|
||||
op_time: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingConnectorStats(KVConnectorStats):
|
||||
def __post_init__(self):
|
||||
if not self.data:
|
||||
# Empty container init, no data is passed in.
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.data: dict[str, list[OffloadingOperationMetrics]] = {}
|
||||
|
||||
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
||||
if not other.is_empty():
|
||||
for k, v in other.data.items():
|
||||
if k not in self.data:
|
||||
self.data[k] = v
|
||||
else:
|
||||
accumulator = self.data[k]
|
||||
assert isinstance(accumulator, list)
|
||||
accumulator.extend(v)
|
||||
return self
|
||||
|
||||
def reduce(self) -> dict[str, int | float]:
|
||||
"""
|
||||
Reduce the observations collected during a time interval to one or
|
||||
more representative values (eg avg/median/sum of the series).
|
||||
This is meant to be called by the logger to produce a summary of the
|
||||
stats for the last time interval.
|
||||
"""
|
||||
return_dict: dict[str, int | float] = {}
|
||||
for transfer_type, ops_list in self.data.items():
|
||||
assert isinstance(ops_list, list)
|
||||
total_bytes = 0
|
||||
total_time = 0.0
|
||||
for op in ops_list:
|
||||
assert isinstance(op, dict)
|
||||
total_bytes += op["op_size"]
|
||||
total_time += op["op_time"]
|
||||
return_dict[f"{transfer_type}_total_bytes"] = total_bytes
|
||||
return_dict[f"{transfer_type}_total_time"] = total_time
|
||||
return return_dict
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.data
|
||||
|
||||
def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
|
||||
src, dst = transfer_type
|
||||
transfer_type_key = src + "_to_" + dst
|
||||
op = OffloadingOperationMetrics(num_bytes, time)
|
||||
if transfer_type_key in self.data:
|
||||
self.data[transfer_type_key].append(op)
|
||||
else:
|
||||
self.data[transfer_type_key] = [op]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
reqs_to_load: dict[ReqId, TransferSpec]
|
||||
reqs_to_store: dict[ReqId, TransferSpec]
|
||||
reqs_to_flush: set[str] | None = None
|
||||
|
||||
|
||||
class OffloadingConnector(KVConnectorBase_V1):
|
||||
@property
|
||||
@@ -242,571 +169,3 @@ class OffloadingConnector(KVConnectorBase_V1):
|
||||
return OffloadPromMetrics(
|
||||
vllm_config, metric_types, labelnames, per_engine_labelvalues
|
||||
)
|
||||
|
||||
|
||||
class OffloadingConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def __init__(self, spec: OffloadingSpec):
|
||||
assert len(spec.gpu_block_size) == 1
|
||||
self.gpu_block_size = spec.gpu_block_size[0]
|
||||
self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor
|
||||
self.block_size_factor = spec.block_size_factor
|
||||
self.manager: OffloadingManager = spec.get_manager()
|
||||
|
||||
self._requests: dict[ReqId, Request] = {}
|
||||
# list of GPU block IDs per request
|
||||
self._request_block_ids: dict[ReqId, list[int]] = {}
|
||||
# requests to load for the current scheduler step
|
||||
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
|
||||
# request blocks are stored in order
|
||||
# index of next block (of size offloaded_block_size) to offload
|
||||
self._next_stored_block_idx: dict[ReqId, int] = {}
|
||||
# if GPU prefix caching is enabled,
|
||||
# track loaded blocks to avoid redundant loads
|
||||
self._blocks_being_loaded: set[BlockHash] | None = (
|
||||
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
|
||||
)
|
||||
|
||||
# request ID -> set(block hashes being stored/load)
|
||||
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
|
||||
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
|
||||
|
||||
def _get_block_hashes(
|
||||
self,
|
||||
req: Request,
|
||||
start_idx: int = 0,
|
||||
end_idx: int | None = None,
|
||||
) -> Iterable[BlockHash]:
|
||||
return islice(
|
||||
req.block_hashes,
|
||||
self.block_size_factor * start_idx + self.block_size_factor - 1,
|
||||
self.block_size_factor * end_idx if end_idx else None,
|
||||
self.block_size_factor,
|
||||
)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: Request, num_computed_tokens: int
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded beyond the
|
||||
num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- The number of tokens that can be loaded beyond what is
|
||||
already computed.
|
||||
If None, it means that the connector needs more time to
|
||||
determine the number of matched tokens, and the scheduler
|
||||
should query for this request again later.
|
||||
- `True` if tokens will be loaded asynchronously
|
||||
(between scheduler steps).
|
||||
"""
|
||||
num_blocks = request.num_tokens // self.offloaded_block_size
|
||||
|
||||
assert len(request.block_hashes) // self.block_size_factor == num_blocks
|
||||
block_hashes = self._get_block_hashes(request)
|
||||
|
||||
self.manager.touch(block_hashes)
|
||||
|
||||
full_block_tokens = self.offloaded_block_size * num_blocks
|
||||
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
|
||||
# we can load less than a block, skip
|
||||
return 0, False
|
||||
|
||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||
hits = self.manager.lookup(
|
||||
self._get_block_hashes(request, start_idx=start_block_idx)
|
||||
)
|
||||
if hits is None:
|
||||
# indicates a lookup that should be tried later
|
||||
return None, False
|
||||
if hits == 0:
|
||||
return 0, False
|
||||
|
||||
num_hit_tokens = (
|
||||
self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
|
||||
)
|
||||
logger.debug(
|
||||
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
|
||||
request.request_id,
|
||||
num_hit_tokens,
|
||||
num_computed_tokens,
|
||||
)
|
||||
if num_hit_tokens < self.offloaded_block_size:
|
||||
return 0, False
|
||||
|
||||
if self._blocks_being_loaded:
|
||||
block_hashes = self._get_block_hashes(
|
||||
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
|
||||
)
|
||||
|
||||
if any(
|
||||
block_hash in self._blocks_being_loaded for block_hash in block_hashes
|
||||
):
|
||||
# hit blocks are being loaded, delay request
|
||||
logger.debug(
|
||||
"Delaying request %s since some of its blocks are already"
|
||||
" being loaded",
|
||||
request.request_id,
|
||||
)
|
||||
return None, False
|
||||
|
||||
return num_hit_tokens, True
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
|
||||
):
|
||||
self._requests[request.request_id] = request
|
||||
# the block ids are updated in _get_reqs_to_store
|
||||
self._request_block_ids[request.request_id] = []
|
||||
|
||||
if num_external_tokens == 0:
|
||||
return
|
||||
|
||||
block_groups = blocks.get_block_ids()
|
||||
block_ids = block_groups[0]
|
||||
|
||||
num_computed_gpu_blocks = sum(
|
||||
block.block_hash is not None for block in blocks.blocks[0]
|
||||
)
|
||||
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
|
||||
full_block_tokens = num_computed_tokens + num_external_tokens
|
||||
assert full_block_tokens % self.offloaded_block_size == 0
|
||||
|
||||
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
|
||||
assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size
|
||||
|
||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||
num_blocks = full_block_tokens // self.offloaded_block_size
|
||||
|
||||
assert len(request.block_hashes) // self.block_size_factor >= num_blocks
|
||||
block_hashes = self._get_block_hashes(
|
||||
request, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
|
||||
src_spec = self.manager.prepare_load(block_hashes)
|
||||
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
|
||||
|
||||
block_hashes = self._get_block_hashes(
|
||||
request, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
|
||||
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
|
||||
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
|
||||
req_blocks_being_loaded.update(block_hashes)
|
||||
self._next_stored_block_idx[request.request_id] = num_blocks
|
||||
|
||||
if self._blocks_being_loaded is not None:
|
||||
self._blocks_being_loaded.update(req_blocks_being_loaded)
|
||||
|
||||
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
|
||||
reqs_to_store: dict[ReqId, TransferSpec] = {}
|
||||
# iterate over both new and cached requests
|
||||
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
|
||||
if preempted:
|
||||
self._request_block_ids[req_id] = []
|
||||
|
||||
if new_block_id_groups:
|
||||
new_block_ids = new_block_id_groups[0]
|
||||
self._request_block_ids[req_id] += new_block_ids
|
||||
|
||||
block_ids = self._request_block_ids[req_id]
|
||||
|
||||
req = self._requests[req_id]
|
||||
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
expected_tokens = req.num_computed_tokens + new_tokens
|
||||
# with async scheduling, some tokens may be missing
|
||||
total_tokens = min(expected_tokens, req.num_tokens)
|
||||
num_blocks = total_tokens // self.offloaded_block_size
|
||||
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
|
||||
num_new_blocks = num_blocks - start_block_idx
|
||||
|
||||
if num_new_blocks <= 0:
|
||||
continue
|
||||
|
||||
num_gpu_blocks = num_blocks * self.block_size_factor
|
||||
assert len(req.block_hashes) >= num_gpu_blocks
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
store_output = self.manager.prepare_store(new_block_hashes)
|
||||
if store_output is None:
|
||||
logger.warning(
|
||||
"Request %s: cannot store %s blocks", req_id, num_new_blocks
|
||||
)
|
||||
continue
|
||||
|
||||
self._next_stored_block_idx[req_id] = num_blocks
|
||||
|
||||
if not store_output.block_hashes_to_store:
|
||||
continue
|
||||
block_hashes_to_store = set(store_output.block_hashes_to_store)
|
||||
|
||||
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
|
||||
self.manager.touch(block_hashes)
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks
|
||||
)
|
||||
dst_spec = store_output.store_spec
|
||||
src_block_ids: list[int] = []
|
||||
for idx, blk_hash in enumerate(new_block_hashes):
|
||||
if blk_hash not in block_hashes_to_store:
|
||||
continue
|
||||
offloaded_block_idx = start_block_idx + idx
|
||||
gpu_block_idx = offloaded_block_idx * self.block_size_factor
|
||||
for i in range(self.block_size_factor):
|
||||
src_block_ids.append(block_ids[gpu_block_idx + i])
|
||||
src_spec = GPULoadStoreSpec(src_block_ids)
|
||||
|
||||
reqs_to_store[req_id] = (src_spec, dst_spec)
|
||||
self._reqs_being_stored[req_id] |= block_hashes_to_store
|
||||
|
||||
logger.debug(
|
||||
"Request %s offloading %s blocks starting from block #%d",
|
||||
req_id,
|
||||
len(block_hashes_to_store),
|
||||
start_block_idx,
|
||||
)
|
||||
|
||||
return reqs_to_store
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
meta = OffloadingConnectorMetadata(
|
||||
reqs_to_load=self._reqs_to_load,
|
||||
reqs_to_store=self._get_reqs_to_store(scheduler_output),
|
||||
reqs_to_flush=scheduler_output.preempted_req_ids,
|
||||
)
|
||||
self._reqs_to_load = {}
|
||||
|
||||
# NOTE (orozery): we should move this logic to update_connector_output
|
||||
# once KVConnectorOutput allows us to report completed transfers
|
||||
for req_id in scheduler_output.preempted_req_ids or ():
|
||||
block_hashes = self._reqs_being_stored.get(req_id)
|
||||
if block_hashes:
|
||||
self.manager.complete_store(block_hashes)
|
||||
block_hashes.clear()
|
||||
|
||||
return meta
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
"""
|
||||
Update KVConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (KVConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
for req_id in connector_output.finished_sending or []:
|
||||
block_hashes = self._reqs_being_stored.pop(req_id, None)
|
||||
if block_hashes:
|
||||
self.manager.complete_store(block_hashes)
|
||||
|
||||
for req_id in connector_output.finished_recving or []:
|
||||
block_hashes = self._reqs_being_loaded.pop(req_id, None)
|
||||
if block_hashes:
|
||||
if self._blocks_being_loaded:
|
||||
self._blocks_being_loaded.difference_update(block_hashes)
|
||||
self.manager.complete_load(block_hashes)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: Request,
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
req_id = request.request_id
|
||||
self._requests.pop(req_id, None)
|
||||
self._request_block_ids.pop(req_id, None)
|
||||
|
||||
# TODO(orozery): possibly kickoff offload for last block
|
||||
# which may have been deferred due to async scheduling
|
||||
self._next_stored_block_idx.pop(req_id, None)
|
||||
|
||||
request_being_stored = req_id in self._reqs_being_stored
|
||||
return request_being_stored, None
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
"""Take the KV cache events from the connector.
|
||||
|
||||
Returns:
|
||||
A list of KV cache events.
|
||||
"""
|
||||
for event in self.manager.take_events():
|
||||
if event.removed:
|
||||
yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
|
||||
else:
|
||||
yield BlockStored(
|
||||
block_hashes=event.block_hashes,
|
||||
parent_block_hash=None,
|
||||
token_ids=[],
|
||||
lora_id=None,
|
||||
block_size=event.block_size,
|
||||
medium=event.medium,
|
||||
lora_name=None,
|
||||
)
|
||||
|
||||
|
||||
class OffloadingConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, spec: OffloadingSpec):
|
||||
self.spec = spec
|
||||
self.worker = OffloadingWorker()
|
||||
|
||||
self._job_counter = 0
|
||||
|
||||
self.kv_connector_stats = OffloadingConnectorStats()
|
||||
# req_id -> (job_id, store)
|
||||
self._jobs: dict[int, tuple[ReqId, bool]] = {}
|
||||
# req_id -> active job IDs
|
||||
self._load_job: dict[ReqId, int] = {}
|
||||
# req_id -> set(active job IDs)
|
||||
self._store_jobs = defaultdict[ReqId, set[int]](set)
|
||||
# list of store jobs pending submission (job_id, transfer_spec)
|
||||
self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
|
||||
|
||||
self._finished_reqs_waiting_for_store: set[ReqId] = set()
|
||||
|
||||
def _generate_job_id(self) -> int:
|
||||
job_id = self._job_counter
|
||||
self._job_counter = job_id + 1
|
||||
return job_id
|
||||
|
||||
def _register_handlers(
|
||||
self,
|
||||
kv_caches: dict[str, torch.Tensor],
|
||||
attn_backends: dict[str, type[AttentionBackend]],
|
||||
):
|
||||
for src_cls, dst_cls, handler in self.spec.get_handlers(
|
||||
kv_caches, attn_backends
|
||||
):
|
||||
self.worker.register_handler(src_cls, dst_cls, handler)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
layer_names = list(kv_caches.keys())
|
||||
layers = get_layers_from_vllm_config(
|
||||
self.spec.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
layer_names,
|
||||
)
|
||||
attn_backends = {
|
||||
layer_name: layers[layer_name].get_attn_backend()
|
||||
for layer_name in layer_names
|
||||
}
|
||||
self._register_handlers(kv_caches, attn_backends)
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
):
|
||||
cross_layer_name = "ALL_LAYERS"
|
||||
kv_caches = {cross_layer_name: kv_cache}
|
||||
attn_backends = {cross_layer_name: attn_backend}
|
||||
self._register_handlers(kv_caches, attn_backends)
|
||||
|
||||
def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
self._unsubmitted_store_jobs.clear()
|
||||
|
||||
for req_id in kv_connector_metadata.reqs_to_flush or ():
|
||||
job_ids = self._store_jobs.get(req_id)
|
||||
if job_ids:
|
||||
self.worker.wait(job_ids)
|
||||
|
||||
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
self._unsubmitted_store_jobs.clear()
|
||||
|
||||
for req_id, transfer_spec in metadata.reqs_to_load.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, False)
|
||||
assert req_id not in self._load_job
|
||||
self._load_job[req_id] = job_id
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
|
||||
def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
|
||||
for req_id, transfer_spec in metadata.reqs_to_store.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, True)
|
||||
self._store_jobs[req_id].add(job_id)
|
||||
# NOTE(orozery): defer the store to the beginning of the next engine step,
|
||||
# so that offloading starts AFTER transfers related to token sampling,
|
||||
# thereby avoiding delays to token generation due to offloading.
|
||||
self._unsubmitted_store_jobs.append((job_id, transfer_spec))
|
||||
|
||||
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
Returns a list of request IDs that finished loading or storing.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
"""
|
||||
finished_sending = set()
|
||||
finished_recving = set()
|
||||
for transfer_result in self.worker.get_finished():
|
||||
# we currently do not support job failures
|
||||
job_id = transfer_result.job_id
|
||||
assert transfer_result.success
|
||||
req_id, store = self._jobs.pop(job_id)
|
||||
if (
|
||||
transfer_result.transfer_time
|
||||
and transfer_result.transfer_size is not None
|
||||
and transfer_result.transfer_type is not None
|
||||
):
|
||||
self.kv_connector_stats.record_transfer(
|
||||
num_bytes=transfer_result.transfer_size,
|
||||
time=transfer_result.transfer_time,
|
||||
transfer_type=transfer_result.transfer_type,
|
||||
)
|
||||
if store:
|
||||
req_jobs = self._store_jobs[req_id]
|
||||
req_jobs.remove(job_id)
|
||||
if req_jobs:
|
||||
continue
|
||||
|
||||
if req_id in self._finished_reqs_waiting_for_store:
|
||||
self._finished_reqs_waiting_for_store.remove(req_id)
|
||||
finished_sending.add(req_id)
|
||||
del self._store_jobs[req_id]
|
||||
else:
|
||||
req_job = self._load_job[req_id]
|
||||
assert job_id == req_job
|
||||
del self._load_job[req_id]
|
||||
finished_recving.add(req_id)
|
||||
|
||||
for req_id in finished_req_ids:
|
||||
pending_req_jobs = self._store_jobs.get(req_id)
|
||||
if pending_req_jobs:
|
||||
self._finished_reqs_waiting_for_store.add(req_id)
|
||||
elif pending_req_jobs is not None:
|
||||
finished_sending.add(req_id)
|
||||
del self._store_jobs[req_id]
|
||||
|
||||
return finished_sending, finished_recving
|
||||
|
||||
def get_kv_connector_stats(self) -> KVConnectorStats | None:
|
||||
"""
|
||||
Get the KV transfer stats for the connector.
|
||||
"""
|
||||
|
||||
if self.kv_connector_stats.is_empty():
|
||||
return None
|
||||
# Clear stats for next iteration
|
||||
kv_connector_stats = self.kv_connector_stats
|
||||
self.kv_connector_stats = OffloadingConnectorStats()
|
||||
return kv_connector_stats
|
||||
|
||||
|
||||
class OffloadPromMetrics(KVConnectorPromMetrics):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
):
|
||||
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
||||
# (engine_idx, transfer_type) -> (metric with bounded labels)
|
||||
self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
|
||||
self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
|
||||
self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}
|
||||
buckets = [ # In bytes
|
||||
1e6,
|
||||
5e6,
|
||||
10e6,
|
||||
20e6,
|
||||
40e6,
|
||||
60e6,
|
||||
80e6,
|
||||
100e6,
|
||||
150e6,
|
||||
200e6,
|
||||
]
|
||||
|
||||
self._counter_kv_bytes = self._counter_cls(
|
||||
name="vllm:kv_offload_total_bytes",
|
||||
documentation="Number of bytes offloaded by KV connector",
|
||||
labelnames=labelnames + ["transfer_type"],
|
||||
)
|
||||
|
||||
self._counter_kv_transfer_time = self._counter_cls(
|
||||
name="vllm:kv_offload_total_time",
|
||||
documentation="Total time measured by all KV offloading operations",
|
||||
labelnames=labelnames + ["transfer_type"],
|
||||
)
|
||||
|
||||
self._histogram_transfer_size = self._histogram_cls(
|
||||
name="vllm:kv_offload_size",
|
||||
documentation="Histogram of KV offload transfer size, in bytes.",
|
||||
buckets=buckets[:],
|
||||
labelnames=labelnames + ["transfer_type"],
|
||||
)
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||
"""
|
||||
Observe transfer statistics from the new data structure.
|
||||
transfer_stats_data is expected to be a dict where:
|
||||
- keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
|
||||
- values are lists of OffloadingOperationMetrics objects
|
||||
"""
|
||||
|
||||
for transfer_type, ops in transfer_stats_data.items():
|
||||
# Cache:
|
||||
if (engine_idx, transfer_type) not in self.histogram_transfer_size:
|
||||
self.histogram_transfer_size[(engine_idx, transfer_type)] = (
|
||||
self._histogram_transfer_size.labels(
|
||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||
)
|
||||
)
|
||||
self.counter_kv_bytes[(engine_idx, transfer_type)] = (
|
||||
self._counter_kv_bytes.labels(
|
||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||
)
|
||||
)
|
||||
self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
|
||||
self._counter_kv_transfer_time.labels(
|
||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||
)
|
||||
)
|
||||
|
||||
# Process ops:
|
||||
assert isinstance(ops, list)
|
||||
for op in ops: # ops is a list of serialized OffloadingOperationMetrics
|
||||
assert isinstance(op, dict)
|
||||
# Observe size histogram
|
||||
self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
|
||||
op["op_size"]
|
||||
)
|
||||
|
||||
# Increment byte and time counters
|
||||
self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])
|
||||
|
||||
self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
|
||||
op["op_time"]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user