[kv_offload+HMA][6/N]: Split offloading_connector.py (#37405)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -13,10 +13,14 @@ from vllm import SamplingParams
|
|||||||
from vllm.config import KVTransferConfig, VllmConfig
|
from vllm.config import KVTransferConfig, VllmConfig
|
||||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored
|
from vllm.distributed.kv_events import BlockRemoved, BlockStored
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
|
||||||
|
OffloadingConnectorMetadata,
|
||||||
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
|
||||||
|
OffloadingConnectorStats,
|
||||||
|
)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
|
||||||
OffloadingConnector,
|
OffloadingConnector,
|
||||||
OffloadingConnectorMetadata,
|
|
||||||
OffloadingConnectorStats,
|
|
||||||
)
|
)
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.utils.hashing import sha256
|
from vllm.utils.hashing import sha256
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||||
|
from vllm.v1.kv_offload.worker.worker import TransferSpec
|
||||||
|
|
||||||
|
ReqId = str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OffloadingConnectorMetadata(KVConnectorMetadata):
|
||||||
|
reqs_to_load: dict[ReqId, TransferSpec]
|
||||||
|
reqs_to_store: dict[ReqId, TransferSpec]
|
||||||
|
reqs_to_flush: set[str] | None = None
|
||||||
@@ -0,0 +1,165 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||||
|
KVConnectorPromMetrics,
|
||||||
|
KVConnectorStats,
|
||||||
|
PromMetric,
|
||||||
|
PromMetricT,
|
||||||
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.kv_offload.worker.worker import TransferType
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OffloadingOperationMetrics:
|
||||||
|
op_size: int
|
||||||
|
op_time: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OffloadingConnectorStats(KVConnectorStats):
|
||||||
|
def __post_init__(self):
|
||||||
|
if not self.data:
|
||||||
|
# Empty container init, no data is passed in.
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.data: dict[str, list[OffloadingOperationMetrics]] = {}
|
||||||
|
|
||||||
|
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
||||||
|
if not other.is_empty():
|
||||||
|
for k, v in other.data.items():
|
||||||
|
if k not in self.data:
|
||||||
|
self.data[k] = v
|
||||||
|
else:
|
||||||
|
accumulator = self.data[k]
|
||||||
|
assert isinstance(accumulator, list)
|
||||||
|
accumulator.extend(v)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def reduce(self) -> dict[str, int | float]:
|
||||||
|
"""
|
||||||
|
Reduce the observations collected during a time interval to one or
|
||||||
|
more representative values (eg avg/median/sum of the series).
|
||||||
|
This is meant to be called by the logger to produce a summary of the
|
||||||
|
stats for the last time interval.
|
||||||
|
"""
|
||||||
|
return_dict: dict[str, int | float] = {}
|
||||||
|
for transfer_type, ops_list in self.data.items():
|
||||||
|
assert isinstance(ops_list, list)
|
||||||
|
total_bytes = 0
|
||||||
|
total_time = 0.0
|
||||||
|
for op in ops_list:
|
||||||
|
assert isinstance(op, dict)
|
||||||
|
total_bytes += op["op_size"]
|
||||||
|
total_time += op["op_time"]
|
||||||
|
return_dict[f"{transfer_type}_total_bytes"] = total_bytes
|
||||||
|
return_dict[f"{transfer_type}_total_time"] = total_time
|
||||||
|
return return_dict
|
||||||
|
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
return not self.data
|
||||||
|
|
||||||
|
def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
|
||||||
|
src, dst = transfer_type
|
||||||
|
transfer_type_key = src + "_to_" + dst
|
||||||
|
op = OffloadingOperationMetrics(num_bytes, time)
|
||||||
|
if transfer_type_key in self.data:
|
||||||
|
self.data[transfer_type_key].append(op)
|
||||||
|
else:
|
||||||
|
self.data[transfer_type_key] = [op]
|
||||||
|
|
||||||
|
|
||||||
|
class OffloadPromMetrics(KVConnectorPromMetrics):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||||
|
labelnames: list[str],
|
||||||
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
|
):
|
||||||
|
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
||||||
|
# (engine_idx, transfer_type) -> (metric with bounded labels)
|
||||||
|
self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
|
||||||
|
self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
|
||||||
|
self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}
|
||||||
|
buckets = [ # In bytes
|
||||||
|
1e6,
|
||||||
|
5e6,
|
||||||
|
10e6,
|
||||||
|
20e6,
|
||||||
|
40e6,
|
||||||
|
60e6,
|
||||||
|
80e6,
|
||||||
|
100e6,
|
||||||
|
150e6,
|
||||||
|
200e6,
|
||||||
|
]
|
||||||
|
|
||||||
|
self._counter_kv_bytes = self._counter_cls(
|
||||||
|
name="vllm:kv_offload_total_bytes",
|
||||||
|
documentation="Number of bytes offloaded by KV connector",
|
||||||
|
labelnames=labelnames + ["transfer_type"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self._counter_kv_transfer_time = self._counter_cls(
|
||||||
|
name="vllm:kv_offload_total_time",
|
||||||
|
documentation="Total time measured by all KV offloading operations",
|
||||||
|
labelnames=labelnames + ["transfer_type"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self._histogram_transfer_size = self._histogram_cls(
|
||||||
|
name="vllm:kv_offload_size",
|
||||||
|
documentation="Histogram of KV offload transfer size, in bytes.",
|
||||||
|
buckets=buckets[:],
|
||||||
|
labelnames=labelnames + ["transfer_type"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||||
|
"""
|
||||||
|
Observe transfer statistics from the new data structure.
|
||||||
|
transfer_stats_data is expected to be a dict where:
|
||||||
|
- keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
|
||||||
|
- values are lists of OffloadingOperationMetrics objects
|
||||||
|
"""
|
||||||
|
|
||||||
|
for transfer_type, ops in transfer_stats_data.items():
|
||||||
|
# Cache:
|
||||||
|
if (engine_idx, transfer_type) not in self.histogram_transfer_size:
|
||||||
|
self.histogram_transfer_size[(engine_idx, transfer_type)] = (
|
||||||
|
self._histogram_transfer_size.labels(
|
||||||
|
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.counter_kv_bytes[(engine_idx, transfer_type)] = (
|
||||||
|
self._counter_kv_bytes.labels(
|
||||||
|
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
|
||||||
|
self._counter_kv_transfer_time.labels(
|
||||||
|
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process ops:
|
||||||
|
assert isinstance(ops, list)
|
||||||
|
for op in ops: # ops is a list of serialized OffloadingOperationMetrics
|
||||||
|
assert isinstance(op, dict)
|
||||||
|
# Observe size histogram
|
||||||
|
self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
|
||||||
|
op["op_size"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Increment byte and time counters
|
||||||
|
self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])
|
||||||
|
|
||||||
|
self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
|
||||||
|
op["op_time"]
|
||||||
|
)
|
||||||
@@ -0,0 +1,347 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from itertools import islice
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
|
||||||
|
OffloadingConnectorMetadata,
|
||||||
|
ReqId,
|
||||||
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
from vllm.v1.kv_offload.abstract import OffloadingManager
|
||||||
|
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||||
|
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||||
|
from vllm.v1.kv_offload.worker.worker import TransferSpec
|
||||||
|
from vllm.v1.outputs import KVConnectorOutput
|
||||||
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OffloadingConnectorScheduler:
|
||||||
|
"""Implementation of Scheduler side methods"""
|
||||||
|
|
||||||
|
def __init__(self, spec: OffloadingSpec):
|
||||||
|
assert len(spec.gpu_block_size) == 1
|
||||||
|
self.gpu_block_size = spec.gpu_block_size[0]
|
||||||
|
self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor
|
||||||
|
self.block_size_factor = spec.block_size_factor
|
||||||
|
self.manager: OffloadingManager = spec.get_manager()
|
||||||
|
|
||||||
|
self._requests: dict[ReqId, Request] = {}
|
||||||
|
# list of GPU block IDs per request
|
||||||
|
self._request_block_ids: dict[ReqId, list[int]] = {}
|
||||||
|
# requests to load for the current scheduler step
|
||||||
|
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
|
||||||
|
# request blocks are stored in order
|
||||||
|
# index of next block (of size offloaded_block_size) to offload
|
||||||
|
self._next_stored_block_idx: dict[ReqId, int] = {}
|
||||||
|
# if GPU prefix caching is enabled,
|
||||||
|
# track loaded blocks to avoid redundant loads
|
||||||
|
self._blocks_being_loaded: set[BlockHash] | None = (
|
||||||
|
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# request ID -> set(block hashes being stored/load)
|
||||||
|
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
|
||||||
|
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
|
||||||
|
|
||||||
|
def _get_block_hashes(
|
||||||
|
self,
|
||||||
|
req: Request,
|
||||||
|
start_idx: int = 0,
|
||||||
|
end_idx: int | None = None,
|
||||||
|
) -> Iterable[BlockHash]:
|
||||||
|
return islice(
|
||||||
|
req.block_hashes,
|
||||||
|
self.block_size_factor * start_idx + self.block_size_factor - 1,
|
||||||
|
self.block_size_factor * end_idx if end_idx else None,
|
||||||
|
self.block_size_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_num_new_matched_tokens(
|
||||||
|
self, request: Request, num_computed_tokens: int
|
||||||
|
) -> tuple[int | None, bool]:
|
||||||
|
"""
|
||||||
|
Get number of new tokens that can be loaded beyond the
|
||||||
|
num_computed_tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): the request object.
|
||||||
|
num_computed_tokens (int): the number of locally
|
||||||
|
computed tokens for this request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple with the following elements:
|
||||||
|
- The number of tokens that can be loaded beyond what is
|
||||||
|
already computed.
|
||||||
|
If None, it means that the connector needs more time to
|
||||||
|
determine the number of matched tokens, and the scheduler
|
||||||
|
should query for this request again later.
|
||||||
|
- `True` if tokens will be loaded asynchronously
|
||||||
|
(between scheduler steps).
|
||||||
|
"""
|
||||||
|
num_blocks = request.num_tokens // self.offloaded_block_size
|
||||||
|
|
||||||
|
assert len(request.block_hashes) // self.block_size_factor == num_blocks
|
||||||
|
block_hashes = self._get_block_hashes(request)
|
||||||
|
|
||||||
|
self.manager.touch(block_hashes)
|
||||||
|
|
||||||
|
full_block_tokens = self.offloaded_block_size * num_blocks
|
||||||
|
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
|
||||||
|
# we can load less than a block, skip
|
||||||
|
return 0, False
|
||||||
|
|
||||||
|
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||||
|
hits = self.manager.lookup(
|
||||||
|
self._get_block_hashes(request, start_idx=start_block_idx)
|
||||||
|
)
|
||||||
|
if hits is None:
|
||||||
|
# indicates a lookup that should be tried later
|
||||||
|
return None, False
|
||||||
|
if hits == 0:
|
||||||
|
return 0, False
|
||||||
|
|
||||||
|
num_hit_tokens = (
|
||||||
|
self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
|
||||||
|
request.request_id,
|
||||||
|
num_hit_tokens,
|
||||||
|
num_computed_tokens,
|
||||||
|
)
|
||||||
|
if num_hit_tokens < self.offloaded_block_size:
|
||||||
|
return 0, False
|
||||||
|
|
||||||
|
if self._blocks_being_loaded:
|
||||||
|
block_hashes = self._get_block_hashes(
|
||||||
|
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(
|
||||||
|
block_hash in self._blocks_being_loaded for block_hash in block_hashes
|
||||||
|
):
|
||||||
|
# hit blocks are being loaded, delay request
|
||||||
|
logger.debug(
|
||||||
|
"Delaying request %s since some of its blocks are already"
|
||||||
|
" being loaded",
|
||||||
|
request.request_id,
|
||||||
|
)
|
||||||
|
return None, False
|
||||||
|
|
||||||
|
return num_hit_tokens, True
|
||||||
|
|
||||||
|
def update_state_after_alloc(
|
||||||
|
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
|
||||||
|
):
|
||||||
|
self._requests[request.request_id] = request
|
||||||
|
# the block ids are updated in _get_reqs_to_store
|
||||||
|
self._request_block_ids[request.request_id] = []
|
||||||
|
|
||||||
|
if num_external_tokens == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
block_groups = blocks.get_block_ids()
|
||||||
|
block_ids = block_groups[0]
|
||||||
|
|
||||||
|
num_computed_gpu_blocks = sum(
|
||||||
|
block.block_hash is not None for block in blocks.blocks[0]
|
||||||
|
)
|
||||||
|
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
|
||||||
|
full_block_tokens = num_computed_tokens + num_external_tokens
|
||||||
|
assert full_block_tokens % self.offloaded_block_size == 0
|
||||||
|
|
||||||
|
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
|
||||||
|
assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size
|
||||||
|
|
||||||
|
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||||
|
num_blocks = full_block_tokens // self.offloaded_block_size
|
||||||
|
|
||||||
|
assert len(request.block_hashes) // self.block_size_factor >= num_blocks
|
||||||
|
block_hashes = self._get_block_hashes(
|
||||||
|
request, start_idx=start_block_idx, end_idx=num_blocks
|
||||||
|
)
|
||||||
|
|
||||||
|
src_spec = self.manager.prepare_load(block_hashes)
|
||||||
|
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
|
||||||
|
|
||||||
|
block_hashes = self._get_block_hashes(
|
||||||
|
request, start_idx=start_block_idx, end_idx=num_blocks
|
||||||
|
)
|
||||||
|
|
||||||
|
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
|
||||||
|
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
|
||||||
|
req_blocks_being_loaded.update(block_hashes)
|
||||||
|
self._next_stored_block_idx[request.request_id] = num_blocks
|
||||||
|
|
||||||
|
if self._blocks_being_loaded is not None:
|
||||||
|
self._blocks_being_loaded.update(req_blocks_being_loaded)
|
||||||
|
|
||||||
|
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
|
||||||
|
reqs_to_store: dict[ReqId, TransferSpec] = {}
|
||||||
|
# iterate over both new and cached requests
|
||||||
|
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
|
||||||
|
if preempted:
|
||||||
|
self._request_block_ids[req_id] = []
|
||||||
|
|
||||||
|
if new_block_id_groups:
|
||||||
|
new_block_ids = new_block_id_groups[0]
|
||||||
|
self._request_block_ids[req_id] += new_block_ids
|
||||||
|
|
||||||
|
block_ids = self._request_block_ids[req_id]
|
||||||
|
|
||||||
|
req = self._requests[req_id]
|
||||||
|
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||||
|
expected_tokens = req.num_computed_tokens + new_tokens
|
||||||
|
# with async scheduling, some tokens may be missing
|
||||||
|
total_tokens = min(expected_tokens, req.num_tokens)
|
||||||
|
num_blocks = total_tokens // self.offloaded_block_size
|
||||||
|
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
|
||||||
|
num_new_blocks = num_blocks - start_block_idx
|
||||||
|
|
||||||
|
if num_new_blocks <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
num_gpu_blocks = num_blocks * self.block_size_factor
|
||||||
|
assert len(req.block_hashes) >= num_gpu_blocks
|
||||||
|
|
||||||
|
new_block_hashes = self._get_block_hashes(
|
||||||
|
req, start_idx=start_block_idx, end_idx=num_blocks
|
||||||
|
)
|
||||||
|
store_output = self.manager.prepare_store(new_block_hashes)
|
||||||
|
if store_output is None:
|
||||||
|
logger.warning(
|
||||||
|
"Request %s: cannot store %s blocks", req_id, num_new_blocks
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._next_stored_block_idx[req_id] = num_blocks
|
||||||
|
|
||||||
|
if not store_output.block_hashes_to_store:
|
||||||
|
continue
|
||||||
|
block_hashes_to_store = set(store_output.block_hashes_to_store)
|
||||||
|
|
||||||
|
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
|
||||||
|
self.manager.touch(block_hashes)
|
||||||
|
|
||||||
|
new_block_hashes = self._get_block_hashes(
|
||||||
|
req, start_idx=start_block_idx, end_idx=num_blocks
|
||||||
|
)
|
||||||
|
dst_spec = store_output.store_spec
|
||||||
|
src_block_ids: list[int] = []
|
||||||
|
for idx, blk_hash in enumerate(new_block_hashes):
|
||||||
|
if blk_hash not in block_hashes_to_store:
|
||||||
|
continue
|
||||||
|
offloaded_block_idx = start_block_idx + idx
|
||||||
|
gpu_block_idx = offloaded_block_idx * self.block_size_factor
|
||||||
|
for i in range(self.block_size_factor):
|
||||||
|
src_block_ids.append(block_ids[gpu_block_idx + i])
|
||||||
|
src_spec = GPULoadStoreSpec(src_block_ids)
|
||||||
|
|
||||||
|
reqs_to_store[req_id] = (src_spec, dst_spec)
|
||||||
|
self._reqs_being_stored[req_id] |= block_hashes_to_store
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Request %s offloading %s blocks starting from block #%d",
|
||||||
|
req_id,
|
||||||
|
len(block_hashes_to_store),
|
||||||
|
start_block_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
return reqs_to_store
|
||||||
|
|
||||||
|
def build_connector_meta(
|
||||||
|
self, scheduler_output: SchedulerOutput
|
||||||
|
) -> KVConnectorMetadata:
|
||||||
|
meta = OffloadingConnectorMetadata(
|
||||||
|
reqs_to_load=self._reqs_to_load,
|
||||||
|
reqs_to_store=self._get_reqs_to_store(scheduler_output),
|
||||||
|
reqs_to_flush=scheduler_output.preempted_req_ids,
|
||||||
|
)
|
||||||
|
self._reqs_to_load = {}
|
||||||
|
|
||||||
|
# NOTE (orozery): we should move this logic to update_connector_output
|
||||||
|
# once KVConnectorOutput allows us to report completed transfers
|
||||||
|
for req_id in scheduler_output.preempted_req_ids or ():
|
||||||
|
block_hashes = self._reqs_being_stored.get(req_id)
|
||||||
|
if block_hashes:
|
||||||
|
self.manager.complete_store(block_hashes)
|
||||||
|
block_hashes.clear()
|
||||||
|
|
||||||
|
return meta
|
||||||
|
|
||||||
|
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||||
|
"""
|
||||||
|
Update KVConnector state from worker-side connectors output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector_output (KVConnectorOutput): the worker-side
|
||||||
|
connectors output.
|
||||||
|
"""
|
||||||
|
for req_id in connector_output.finished_sending or []:
|
||||||
|
block_hashes = self._reqs_being_stored.pop(req_id, None)
|
||||||
|
if block_hashes:
|
||||||
|
self.manager.complete_store(block_hashes)
|
||||||
|
|
||||||
|
for req_id in connector_output.finished_recving or []:
|
||||||
|
block_hashes = self._reqs_being_loaded.pop(req_id, None)
|
||||||
|
if block_hashes:
|
||||||
|
if self._blocks_being_loaded:
|
||||||
|
self._blocks_being_loaded.difference_update(block_hashes)
|
||||||
|
self.manager.complete_load(block_hashes)
|
||||||
|
|
||||||
|
def request_finished(
|
||||||
|
self,
|
||||||
|
request: Request,
|
||||||
|
block_ids: list[int],
|
||||||
|
) -> tuple[bool, dict[str, Any] | None]:
|
||||||
|
"""
|
||||||
|
Called when a request has finished, before its blocks are freed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the request is being saved/sent asynchronously and blocks
|
||||||
|
should not be freed until the request_id is returned from
|
||||||
|
get_finished().
|
||||||
|
Optional KVTransferParams to be included in the request outputs
|
||||||
|
returned by the engine.
|
||||||
|
"""
|
||||||
|
req_id = request.request_id
|
||||||
|
self._requests.pop(req_id, None)
|
||||||
|
self._request_block_ids.pop(req_id, None)
|
||||||
|
|
||||||
|
# TODO(orozery): possibly kickoff offload for last block
|
||||||
|
# which may have been deferred due to async scheduling
|
||||||
|
self._next_stored_block_idx.pop(req_id, None)
|
||||||
|
|
||||||
|
request_being_stored = req_id in self._reqs_being_stored
|
||||||
|
return request_being_stored, None
|
||||||
|
|
||||||
|
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||||
|
"""Take the KV cache events from the connector.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of KV cache events.
|
||||||
|
"""
|
||||||
|
for event in self.manager.take_events():
|
||||||
|
if event.removed:
|
||||||
|
yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
|
||||||
|
else:
|
||||||
|
yield BlockStored(
|
||||||
|
block_hashes=event.block_hashes,
|
||||||
|
parent_block_hash=None,
|
||||||
|
token_ids=[],
|
||||||
|
lora_id=None,
|
||||||
|
block_size=event.block_size,
|
||||||
|
medium=event.medium,
|
||||||
|
lora_name=None,
|
||||||
|
)
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import get_layers_from_vllm_config
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||||
|
KVConnectorStats,
|
||||||
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
|
||||||
|
OffloadingConnectorMetadata,
|
||||||
|
ReqId,
|
||||||
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
|
||||||
|
OffloadingConnectorStats,
|
||||||
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
|
from vllm.v1.attention.backend import AttentionBackend
|
||||||
|
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||||
|
from vllm.v1.kv_offload.worker.worker import (
|
||||||
|
OffloadingWorker,
|
||||||
|
TransferSpec,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OffloadingConnectorWorker:
|
||||||
|
"""Implementation of Worker side methods"""
|
||||||
|
|
||||||
|
def __init__(self, spec: OffloadingSpec):
|
||||||
|
self.spec = spec
|
||||||
|
self.worker = OffloadingWorker()
|
||||||
|
|
||||||
|
self._job_counter = 0
|
||||||
|
|
||||||
|
self.kv_connector_stats = OffloadingConnectorStats()
|
||||||
|
# req_id -> (job_id, store)
|
||||||
|
self._jobs: dict[int, tuple[ReqId, bool]] = {}
|
||||||
|
# req_id -> active job IDs
|
||||||
|
self._load_job: dict[ReqId, int] = {}
|
||||||
|
# req_id -> set(active job IDs)
|
||||||
|
self._store_jobs = defaultdict[ReqId, set[int]](set)
|
||||||
|
# list of store jobs pending submission (job_id, transfer_spec)
|
||||||
|
self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
|
||||||
|
|
||||||
|
self._finished_reqs_waiting_for_store: set[ReqId] = set()
|
||||||
|
|
||||||
|
def _generate_job_id(self) -> int:
|
||||||
|
job_id = self._job_counter
|
||||||
|
self._job_counter = job_id + 1
|
||||||
|
return job_id
|
||||||
|
|
||||||
|
def _register_handlers(
|
||||||
|
self,
|
||||||
|
kv_caches: dict[str, torch.Tensor],
|
||||||
|
attn_backends: dict[str, type[AttentionBackend]],
|
||||||
|
):
|
||||||
|
for src_cls, dst_cls, handler in self.spec.get_handlers(
|
||||||
|
kv_caches, attn_backends
|
||||||
|
):
|
||||||
|
self.worker.register_handler(src_cls, dst_cls, handler)
|
||||||
|
|
||||||
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
|
layer_names = list(kv_caches.keys())
|
||||||
|
layers = get_layers_from_vllm_config(
|
||||||
|
self.spec.vllm_config,
|
||||||
|
AttentionLayerBase, # type: ignore[type-abstract]
|
||||||
|
layer_names,
|
||||||
|
)
|
||||||
|
attn_backends = {
|
||||||
|
layer_name: layers[layer_name].get_attn_backend()
|
||||||
|
for layer_name in layer_names
|
||||||
|
}
|
||||||
|
self._register_handlers(kv_caches, attn_backends)
|
||||||
|
|
||||||
|
def register_cross_layers_kv_cache(
|
||||||
|
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||||
|
):
|
||||||
|
cross_layer_name = "ALL_LAYERS"
|
||||||
|
kv_caches = {cross_layer_name: kv_cache}
|
||||||
|
attn_backends = {cross_layer_name: attn_backend}
|
||||||
|
self._register_handlers(kv_caches, attn_backends)
|
||||||
|
|
||||||
|
def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
|
||||||
|
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||||
|
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||||
|
assert success
|
||||||
|
self._unsubmitted_store_jobs.clear()
|
||||||
|
|
||||||
|
for req_id in kv_connector_metadata.reqs_to_flush or ():
|
||||||
|
job_ids = self._store_jobs.get(req_id)
|
||||||
|
if job_ids:
|
||||||
|
self.worker.wait(job_ids)
|
||||||
|
|
||||||
|
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
|
||||||
|
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||||
|
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||||
|
assert success
|
||||||
|
self._unsubmitted_store_jobs.clear()
|
||||||
|
|
||||||
|
for req_id, transfer_spec in metadata.reqs_to_load.items():
|
||||||
|
job_id = self._generate_job_id()
|
||||||
|
self._jobs[job_id] = (req_id, False)
|
||||||
|
assert req_id not in self._load_job
|
||||||
|
self._load_job[req_id] = job_id
|
||||||
|
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||||
|
assert success
|
||||||
|
|
||||||
|
def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
|
||||||
|
for req_id, transfer_spec in metadata.reqs_to_store.items():
|
||||||
|
job_id = self._generate_job_id()
|
||||||
|
self._jobs[job_id] = (req_id, True)
|
||||||
|
self._store_jobs[req_id].add(job_id)
|
||||||
|
# NOTE(orozery): defer the store to the beginning of the next engine step,
|
||||||
|
# so that offloading starts AFTER transfers related to token sampling,
|
||||||
|
# thereby avoiding delays to token generation due to offloading.
|
||||||
|
self._unsubmitted_store_jobs.append((job_id, transfer_spec))
|
||||||
|
|
||||||
|
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||||
|
"""
|
||||||
|
Notifies worker-side connector ids of requests that have
|
||||||
|
finished generating tokens.
|
||||||
|
Returns a list of request IDs that finished loading or storing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ids of requests that have finished asynchronous transfer
|
||||||
|
tuple of (sending/saving ids, recving/loading ids).
|
||||||
|
"""
|
||||||
|
finished_sending = set()
|
||||||
|
finished_recving = set()
|
||||||
|
for transfer_result in self.worker.get_finished():
|
||||||
|
# we currently do not support job failures
|
||||||
|
job_id = transfer_result.job_id
|
||||||
|
assert transfer_result.success
|
||||||
|
req_id, store = self._jobs.pop(job_id)
|
||||||
|
if (
|
||||||
|
transfer_result.transfer_time
|
||||||
|
and transfer_result.transfer_size is not None
|
||||||
|
and transfer_result.transfer_type is not None
|
||||||
|
):
|
||||||
|
self.kv_connector_stats.record_transfer(
|
||||||
|
num_bytes=transfer_result.transfer_size,
|
||||||
|
time=transfer_result.transfer_time,
|
||||||
|
transfer_type=transfer_result.transfer_type,
|
||||||
|
)
|
||||||
|
if store:
|
||||||
|
req_jobs = self._store_jobs[req_id]
|
||||||
|
req_jobs.remove(job_id)
|
||||||
|
if req_jobs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if req_id in self._finished_reqs_waiting_for_store:
|
||||||
|
self._finished_reqs_waiting_for_store.remove(req_id)
|
||||||
|
finished_sending.add(req_id)
|
||||||
|
del self._store_jobs[req_id]
|
||||||
|
else:
|
||||||
|
req_job = self._load_job[req_id]
|
||||||
|
assert job_id == req_job
|
||||||
|
del self._load_job[req_id]
|
||||||
|
finished_recving.add(req_id)
|
||||||
|
|
||||||
|
for req_id in finished_req_ids:
|
||||||
|
pending_req_jobs = self._store_jobs.get(req_id)
|
||||||
|
if pending_req_jobs:
|
||||||
|
self._finished_reqs_waiting_for_store.add(req_id)
|
||||||
|
elif pending_req_jobs is not None:
|
||||||
|
finished_sending.add(req_id)
|
||||||
|
del self._store_jobs[req_id]
|
||||||
|
|
||||||
|
return finished_sending, finished_recving
|
||||||
|
|
||||||
|
def get_kv_connector_stats(self) -> KVConnectorStats | None:
|
||||||
|
"""
|
||||||
|
Get the KV transfer stats for the connector.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.kv_connector_stats.is_empty():
|
||||||
|
return None
|
||||||
|
# Clear stats for next iteration
|
||||||
|
kv_connector_stats = self.kv_connector_stats
|
||||||
|
self.kv_connector_stats = OffloadingConnectorStats()
|
||||||
|
return kv_connector_stats
|
||||||
@@ -1,16 +1,12 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections import defaultdict
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
|
||||||
from itertools import islice
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
from vllm.distributed.kv_events import KVCacheEvent
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
KVConnectorRole,
|
KVConnectorRole,
|
||||||
@@ -22,97 +18,28 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
|||||||
PromMetric,
|
PromMetric,
|
||||||
PromMetricT,
|
PromMetricT,
|
||||||
)
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
|
||||||
|
OffloadingConnectorMetadata,
|
||||||
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.metrics import (
|
||||||
|
OffloadingConnectorStats,
|
||||||
|
OffloadPromMetrics,
|
||||||
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler import (
|
||||||
|
OffloadingConnectorScheduler,
|
||||||
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.worker import (
|
||||||
|
OffloadingConnectorWorker,
|
||||||
|
)
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
|
||||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.kv_offload.abstract import OffloadingManager
|
|
||||||
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
|
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
|
||||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
|
||||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
|
||||||
from vllm.v1.kv_offload.worker.worker import (
|
|
||||||
OffloadingWorker,
|
|
||||||
TransferSpec,
|
|
||||||
TransferType,
|
|
||||||
)
|
|
||||||
from vllm.v1.outputs import KVConnectorOutput
|
from vllm.v1.outputs import KVConnectorOutput
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
ReqId = str
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OffloadingOperationMetrics:
|
|
||||||
op_size: int
|
|
||||||
op_time: float
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OffloadingConnectorStats(KVConnectorStats):
|
|
||||||
def __post_init__(self):
|
|
||||||
if not self.data:
|
|
||||||
# Empty container init, no data is passed in.
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.data: dict[str, list[OffloadingOperationMetrics]] = {}
|
|
||||||
|
|
||||||
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
|
||||||
if not other.is_empty():
|
|
||||||
for k, v in other.data.items():
|
|
||||||
if k not in self.data:
|
|
||||||
self.data[k] = v
|
|
||||||
else:
|
|
||||||
accumulator = self.data[k]
|
|
||||||
assert isinstance(accumulator, list)
|
|
||||||
accumulator.extend(v)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def reduce(self) -> dict[str, int | float]:
|
|
||||||
"""
|
|
||||||
Reduce the observations collected during a time interval to one or
|
|
||||||
more representative values (eg avg/median/sum of the series).
|
|
||||||
This is meant to be called by the logger to produce a summary of the
|
|
||||||
stats for the last time interval.
|
|
||||||
"""
|
|
||||||
return_dict: dict[str, int | float] = {}
|
|
||||||
for transfer_type, ops_list in self.data.items():
|
|
||||||
assert isinstance(ops_list, list)
|
|
||||||
total_bytes = 0
|
|
||||||
total_time = 0.0
|
|
||||||
for op in ops_list:
|
|
||||||
assert isinstance(op, dict)
|
|
||||||
total_bytes += op["op_size"]
|
|
||||||
total_time += op["op_time"]
|
|
||||||
return_dict[f"{transfer_type}_total_bytes"] = total_bytes
|
|
||||||
return_dict[f"{transfer_type}_total_time"] = total_time
|
|
||||||
return return_dict
|
|
||||||
|
|
||||||
def is_empty(self) -> bool:
|
|
||||||
return not self.data
|
|
||||||
|
|
||||||
def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
|
|
||||||
src, dst = transfer_type
|
|
||||||
transfer_type_key = src + "_to_" + dst
|
|
||||||
op = OffloadingOperationMetrics(num_bytes, time)
|
|
||||||
if transfer_type_key in self.data:
|
|
||||||
self.data[transfer_type_key].append(op)
|
|
||||||
else:
|
|
||||||
self.data[transfer_type_key] = [op]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OffloadingConnectorMetadata(KVConnectorMetadata):
|
|
||||||
reqs_to_load: dict[ReqId, TransferSpec]
|
|
||||||
reqs_to_store: dict[ReqId, TransferSpec]
|
|
||||||
reqs_to_flush: set[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class OffloadingConnector(KVConnectorBase_V1):
|
class OffloadingConnector(KVConnectorBase_V1):
|
||||||
@property
|
@property
|
||||||
@@ -242,571 +169,3 @@ class OffloadingConnector(KVConnectorBase_V1):
|
|||||||
return OffloadPromMetrics(
|
return OffloadPromMetrics(
|
||||||
vllm_config, metric_types, labelnames, per_engine_labelvalues
|
vllm_config, metric_types, labelnames, per_engine_labelvalues
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OffloadingConnectorScheduler:
|
|
||||||
"""Implementation of Scheduler side methods"""
|
|
||||||
|
|
||||||
def __init__(self, spec: OffloadingSpec):
|
|
||||||
assert len(spec.gpu_block_size) == 1
|
|
||||||
self.gpu_block_size = spec.gpu_block_size[0]
|
|
||||||
self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor
|
|
||||||
self.block_size_factor = spec.block_size_factor
|
|
||||||
self.manager: OffloadingManager = spec.get_manager()
|
|
||||||
|
|
||||||
self._requests: dict[ReqId, Request] = {}
|
|
||||||
# list of GPU block IDs per request
|
|
||||||
self._request_block_ids: dict[ReqId, list[int]] = {}
|
|
||||||
# requests to load for the current scheduler step
|
|
||||||
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
|
|
||||||
# request blocks are stored in order
|
|
||||||
# index of next block (of size offloaded_block_size) to offload
|
|
||||||
self._next_stored_block_idx: dict[ReqId, int] = {}
|
|
||||||
# if GPU prefix caching is enabled,
|
|
||||||
# track loaded blocks to avoid redundant loads
|
|
||||||
self._blocks_being_loaded: set[BlockHash] | None = (
|
|
||||||
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# request ID -> set(block hashes being stored/load)
|
|
||||||
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
|
|
||||||
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
|
|
||||||
|
|
||||||
def _get_block_hashes(
|
|
||||||
self,
|
|
||||||
req: Request,
|
|
||||||
start_idx: int = 0,
|
|
||||||
end_idx: int | None = None,
|
|
||||||
) -> Iterable[BlockHash]:
|
|
||||||
return islice(
|
|
||||||
req.block_hashes,
|
|
||||||
self.block_size_factor * start_idx + self.block_size_factor - 1,
|
|
||||||
self.block_size_factor * end_idx if end_idx else None,
|
|
||||||
self.block_size_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
|
||||||
self, request: Request, num_computed_tokens: int
|
|
||||||
) -> tuple[int | None, bool]:
|
|
||||||
"""
|
|
||||||
Get number of new tokens that can be loaded beyond the
|
|
||||||
num_computed_tokens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request (Request): the request object.
|
|
||||||
num_computed_tokens (int): the number of locally
|
|
||||||
computed tokens for this request
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple with the following elements:
|
|
||||||
- The number of tokens that can be loaded beyond what is
|
|
||||||
already computed.
|
|
||||||
If None, it means that the connector needs more time to
|
|
||||||
determine the number of matched tokens, and the scheduler
|
|
||||||
should query for this request again later.
|
|
||||||
- `True` if tokens will be loaded asynchronously
|
|
||||||
(between scheduler steps).
|
|
||||||
"""
|
|
||||||
num_blocks = request.num_tokens // self.offloaded_block_size
|
|
||||||
|
|
||||||
assert len(request.block_hashes) // self.block_size_factor == num_blocks
|
|
||||||
block_hashes = self._get_block_hashes(request)
|
|
||||||
|
|
||||||
self.manager.touch(block_hashes)
|
|
||||||
|
|
||||||
full_block_tokens = self.offloaded_block_size * num_blocks
|
|
||||||
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
|
|
||||||
# we can load less than a block, skip
|
|
||||||
return 0, False
|
|
||||||
|
|
||||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
|
||||||
hits = self.manager.lookup(
|
|
||||||
self._get_block_hashes(request, start_idx=start_block_idx)
|
|
||||||
)
|
|
||||||
if hits is None:
|
|
||||||
# indicates a lookup that should be tried later
|
|
||||||
return None, False
|
|
||||||
if hits == 0:
|
|
||||||
return 0, False
|
|
||||||
|
|
||||||
num_hit_tokens = (
|
|
||||||
self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
|
|
||||||
request.request_id,
|
|
||||||
num_hit_tokens,
|
|
||||||
num_computed_tokens,
|
|
||||||
)
|
|
||||||
if num_hit_tokens < self.offloaded_block_size:
|
|
||||||
return 0, False
|
|
||||||
|
|
||||||
if self._blocks_being_loaded:
|
|
||||||
block_hashes = self._get_block_hashes(
|
|
||||||
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
|
|
||||||
)
|
|
||||||
|
|
||||||
if any(
|
|
||||||
block_hash in self._blocks_being_loaded for block_hash in block_hashes
|
|
||||||
):
|
|
||||||
# hit blocks are being loaded, delay request
|
|
||||||
logger.debug(
|
|
||||||
"Delaying request %s since some of its blocks are already"
|
|
||||||
" being loaded",
|
|
||||||
request.request_id,
|
|
||||||
)
|
|
||||||
return None, False
|
|
||||||
|
|
||||||
return num_hit_tokens, True
|
|
||||||
|
|
||||||
def update_state_after_alloc(
|
|
||||||
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
|
|
||||||
):
|
|
||||||
self._requests[request.request_id] = request
|
|
||||||
# the block ids are updated in _get_reqs_to_store
|
|
||||||
self._request_block_ids[request.request_id] = []
|
|
||||||
|
|
||||||
if num_external_tokens == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
block_groups = blocks.get_block_ids()
|
|
||||||
block_ids = block_groups[0]
|
|
||||||
|
|
||||||
num_computed_gpu_blocks = sum(
|
|
||||||
block.block_hash is not None for block in blocks.blocks[0]
|
|
||||||
)
|
|
||||||
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
|
|
||||||
full_block_tokens = num_computed_tokens + num_external_tokens
|
|
||||||
assert full_block_tokens % self.offloaded_block_size == 0
|
|
||||||
|
|
||||||
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
|
|
||||||
assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size
|
|
||||||
|
|
||||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
|
||||||
num_blocks = full_block_tokens // self.offloaded_block_size
|
|
||||||
|
|
||||||
assert len(request.block_hashes) // self.block_size_factor >= num_blocks
|
|
||||||
block_hashes = self._get_block_hashes(
|
|
||||||
request, start_idx=start_block_idx, end_idx=num_blocks
|
|
||||||
)
|
|
||||||
|
|
||||||
src_spec = self.manager.prepare_load(block_hashes)
|
|
||||||
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
|
|
||||||
|
|
||||||
block_hashes = self._get_block_hashes(
|
|
||||||
request, start_idx=start_block_idx, end_idx=num_blocks
|
|
||||||
)
|
|
||||||
|
|
||||||
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
|
|
||||||
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
|
|
||||||
req_blocks_being_loaded.update(block_hashes)
|
|
||||||
self._next_stored_block_idx[request.request_id] = num_blocks
|
|
||||||
|
|
||||||
if self._blocks_being_loaded is not None:
|
|
||||||
self._blocks_being_loaded.update(req_blocks_being_loaded)
|
|
||||||
|
|
||||||
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
|
|
||||||
reqs_to_store: dict[ReqId, TransferSpec] = {}
|
|
||||||
# iterate over both new and cached requests
|
|
||||||
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
|
|
||||||
if preempted:
|
|
||||||
self._request_block_ids[req_id] = []
|
|
||||||
|
|
||||||
if new_block_id_groups:
|
|
||||||
new_block_ids = new_block_id_groups[0]
|
|
||||||
self._request_block_ids[req_id] += new_block_ids
|
|
||||||
|
|
||||||
block_ids = self._request_block_ids[req_id]
|
|
||||||
|
|
||||||
req = self._requests[req_id]
|
|
||||||
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
|
||||||
expected_tokens = req.num_computed_tokens + new_tokens
|
|
||||||
# with async scheduling, some tokens may be missing
|
|
||||||
total_tokens = min(expected_tokens, req.num_tokens)
|
|
||||||
num_blocks = total_tokens // self.offloaded_block_size
|
|
||||||
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
|
|
||||||
num_new_blocks = num_blocks - start_block_idx
|
|
||||||
|
|
||||||
if num_new_blocks <= 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
num_gpu_blocks = num_blocks * self.block_size_factor
|
|
||||||
assert len(req.block_hashes) >= num_gpu_blocks
|
|
||||||
|
|
||||||
new_block_hashes = self._get_block_hashes(
|
|
||||||
req, start_idx=start_block_idx, end_idx=num_blocks
|
|
||||||
)
|
|
||||||
store_output = self.manager.prepare_store(new_block_hashes)
|
|
||||||
if store_output is None:
|
|
||||||
logger.warning(
|
|
||||||
"Request %s: cannot store %s blocks", req_id, num_new_blocks
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._next_stored_block_idx[req_id] = num_blocks
|
|
||||||
|
|
||||||
if not store_output.block_hashes_to_store:
|
|
||||||
continue
|
|
||||||
block_hashes_to_store = set(store_output.block_hashes_to_store)
|
|
||||||
|
|
||||||
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
|
|
||||||
self.manager.touch(block_hashes)
|
|
||||||
|
|
||||||
new_block_hashes = self._get_block_hashes(
|
|
||||||
req, start_idx=start_block_idx, end_idx=num_blocks
|
|
||||||
)
|
|
||||||
dst_spec = store_output.store_spec
|
|
||||||
src_block_ids: list[int] = []
|
|
||||||
for idx, blk_hash in enumerate(new_block_hashes):
|
|
||||||
if blk_hash not in block_hashes_to_store:
|
|
||||||
continue
|
|
||||||
offloaded_block_idx = start_block_idx + idx
|
|
||||||
gpu_block_idx = offloaded_block_idx * self.block_size_factor
|
|
||||||
for i in range(self.block_size_factor):
|
|
||||||
src_block_ids.append(block_ids[gpu_block_idx + i])
|
|
||||||
src_spec = GPULoadStoreSpec(src_block_ids)
|
|
||||||
|
|
||||||
reqs_to_store[req_id] = (src_spec, dst_spec)
|
|
||||||
self._reqs_being_stored[req_id] |= block_hashes_to_store
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Request %s offloading %s blocks starting from block #%d",
|
|
||||||
req_id,
|
|
||||||
len(block_hashes_to_store),
|
|
||||||
start_block_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
return reqs_to_store
|
|
||||||
|
|
||||||
def build_connector_meta(
|
|
||||||
self, scheduler_output: SchedulerOutput
|
|
||||||
) -> KVConnectorMetadata:
|
|
||||||
meta = OffloadingConnectorMetadata(
|
|
||||||
reqs_to_load=self._reqs_to_load,
|
|
||||||
reqs_to_store=self._get_reqs_to_store(scheduler_output),
|
|
||||||
reqs_to_flush=scheduler_output.preempted_req_ids,
|
|
||||||
)
|
|
||||||
self._reqs_to_load = {}
|
|
||||||
|
|
||||||
# NOTE (orozery): we should move this logic to update_connector_output
|
|
||||||
# once KVConnectorOutput allows us to report completed transfers
|
|
||||||
for req_id in scheduler_output.preempted_req_ids or ():
|
|
||||||
block_hashes = self._reqs_being_stored.get(req_id)
|
|
||||||
if block_hashes:
|
|
||||||
self.manager.complete_store(block_hashes)
|
|
||||||
block_hashes.clear()
|
|
||||||
|
|
||||||
return meta
|
|
||||||
|
|
||||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
|
||||||
"""
|
|
||||||
Update KVConnector state from worker-side connectors output.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
connector_output (KVConnectorOutput): the worker-side
|
|
||||||
connectors output.
|
|
||||||
"""
|
|
||||||
for req_id in connector_output.finished_sending or []:
|
|
||||||
block_hashes = self._reqs_being_stored.pop(req_id, None)
|
|
||||||
if block_hashes:
|
|
||||||
self.manager.complete_store(block_hashes)
|
|
||||||
|
|
||||||
for req_id in connector_output.finished_recving or []:
|
|
||||||
block_hashes = self._reqs_being_loaded.pop(req_id, None)
|
|
||||||
if block_hashes:
|
|
||||||
if self._blocks_being_loaded:
|
|
||||||
self._blocks_being_loaded.difference_update(block_hashes)
|
|
||||||
self.manager.complete_load(block_hashes)
|
|
||||||
|
|
||||||
def request_finished(
|
|
||||||
self,
|
|
||||||
request: Request,
|
|
||||||
block_ids: list[int],
|
|
||||||
) -> tuple[bool, dict[str, Any] | None]:
|
|
||||||
"""
|
|
||||||
Called when a request has finished, before its blocks are freed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the request is being saved/sent asynchronously and blocks
|
|
||||||
should not be freed until the request_id is returned from
|
|
||||||
get_finished().
|
|
||||||
Optional KVTransferParams to be included in the request outputs
|
|
||||||
returned by the engine.
|
|
||||||
"""
|
|
||||||
req_id = request.request_id
|
|
||||||
self._requests.pop(req_id, None)
|
|
||||||
self._request_block_ids.pop(req_id, None)
|
|
||||||
|
|
||||||
# TODO(orozery): possibly kickoff offload for last block
|
|
||||||
# which may have been deferred due to async scheduling
|
|
||||||
self._next_stored_block_idx.pop(req_id, None)
|
|
||||||
|
|
||||||
request_being_stored = req_id in self._reqs_being_stored
|
|
||||||
return request_being_stored, None
|
|
||||||
|
|
||||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
|
||||||
"""Take the KV cache events from the connector.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of KV cache events.
|
|
||||||
"""
|
|
||||||
for event in self.manager.take_events():
|
|
||||||
if event.removed:
|
|
||||||
yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
|
|
||||||
else:
|
|
||||||
yield BlockStored(
|
|
||||||
block_hashes=event.block_hashes,
|
|
||||||
parent_block_hash=None,
|
|
||||||
token_ids=[],
|
|
||||||
lora_id=None,
|
|
||||||
block_size=event.block_size,
|
|
||||||
medium=event.medium,
|
|
||||||
lora_name=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OffloadingConnectorWorker:
|
|
||||||
"""Implementation of Worker side methods"""
|
|
||||||
|
|
||||||
def __init__(self, spec: OffloadingSpec):
|
|
||||||
self.spec = spec
|
|
||||||
self.worker = OffloadingWorker()
|
|
||||||
|
|
||||||
self._job_counter = 0
|
|
||||||
|
|
||||||
self.kv_connector_stats = OffloadingConnectorStats()
|
|
||||||
# req_id -> (job_id, store)
|
|
||||||
self._jobs: dict[int, tuple[ReqId, bool]] = {}
|
|
||||||
# req_id -> active job IDs
|
|
||||||
self._load_job: dict[ReqId, int] = {}
|
|
||||||
# req_id -> set(active job IDs)
|
|
||||||
self._store_jobs = defaultdict[ReqId, set[int]](set)
|
|
||||||
# list of store jobs pending submission (job_id, transfer_spec)
|
|
||||||
self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
|
|
||||||
|
|
||||||
self._finished_reqs_waiting_for_store: set[ReqId] = set()
|
|
||||||
|
|
||||||
def _generate_job_id(self) -> int:
|
|
||||||
job_id = self._job_counter
|
|
||||||
self._job_counter = job_id + 1
|
|
||||||
return job_id
|
|
||||||
|
|
||||||
def _register_handlers(
|
|
||||||
self,
|
|
||||||
kv_caches: dict[str, torch.Tensor],
|
|
||||||
attn_backends: dict[str, type[AttentionBackend]],
|
|
||||||
):
|
|
||||||
for src_cls, dst_cls, handler in self.spec.get_handlers(
|
|
||||||
kv_caches, attn_backends
|
|
||||||
):
|
|
||||||
self.worker.register_handler(src_cls, dst_cls, handler)
|
|
||||||
|
|
||||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
|
||||||
layer_names = list(kv_caches.keys())
|
|
||||||
layers = get_layers_from_vllm_config(
|
|
||||||
self.spec.vllm_config,
|
|
||||||
AttentionLayerBase, # type: ignore[type-abstract]
|
|
||||||
layer_names,
|
|
||||||
)
|
|
||||||
attn_backends = {
|
|
||||||
layer_name: layers[layer_name].get_attn_backend()
|
|
||||||
for layer_name in layer_names
|
|
||||||
}
|
|
||||||
self._register_handlers(kv_caches, attn_backends)
|
|
||||||
|
|
||||||
def register_cross_layers_kv_cache(
|
|
||||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
|
||||||
):
|
|
||||||
cross_layer_name = "ALL_LAYERS"
|
|
||||||
kv_caches = {cross_layer_name: kv_cache}
|
|
||||||
attn_backends = {cross_layer_name: attn_backend}
|
|
||||||
self._register_handlers(kv_caches, attn_backends)
|
|
||||||
|
|
||||||
def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
|
|
||||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
|
||||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
|
||||||
assert success
|
|
||||||
self._unsubmitted_store_jobs.clear()
|
|
||||||
|
|
||||||
for req_id in kv_connector_metadata.reqs_to_flush or ():
|
|
||||||
job_ids = self._store_jobs.get(req_id)
|
|
||||||
if job_ids:
|
|
||||||
self.worker.wait(job_ids)
|
|
||||||
|
|
||||||
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
|
|
||||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
|
||||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
|
||||||
assert success
|
|
||||||
self._unsubmitted_store_jobs.clear()
|
|
||||||
|
|
||||||
for req_id, transfer_spec in metadata.reqs_to_load.items():
|
|
||||||
job_id = self._generate_job_id()
|
|
||||||
self._jobs[job_id] = (req_id, False)
|
|
||||||
assert req_id not in self._load_job
|
|
||||||
self._load_job[req_id] = job_id
|
|
||||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
|
||||||
assert success
|
|
||||||
|
|
||||||
def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
|
|
||||||
for req_id, transfer_spec in metadata.reqs_to_store.items():
|
|
||||||
job_id = self._generate_job_id()
|
|
||||||
self._jobs[job_id] = (req_id, True)
|
|
||||||
self._store_jobs[req_id].add(job_id)
|
|
||||||
# NOTE(orozery): defer the store to the beginning of the next engine step,
|
|
||||||
# so that offloading starts AFTER transfers related to token sampling,
|
|
||||||
# thereby avoiding delays to token generation due to offloading.
|
|
||||||
self._unsubmitted_store_jobs.append((job_id, transfer_spec))
|
|
||||||
|
|
||||||
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
|
||||||
"""
|
|
||||||
Notifies worker-side connector ids of requests that have
|
|
||||||
finished generating tokens.
|
|
||||||
Returns a list of request IDs that finished loading or storing.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ids of requests that have finished asynchronous transfer
|
|
||||||
tuple of (sending/saving ids, recving/loading ids).
|
|
||||||
"""
|
|
||||||
finished_sending = set()
|
|
||||||
finished_recving = set()
|
|
||||||
for transfer_result in self.worker.get_finished():
|
|
||||||
# we currently do not support job failures
|
|
||||||
job_id = transfer_result.job_id
|
|
||||||
assert transfer_result.success
|
|
||||||
req_id, store = self._jobs.pop(job_id)
|
|
||||||
if (
|
|
||||||
transfer_result.transfer_time
|
|
||||||
and transfer_result.transfer_size is not None
|
|
||||||
and transfer_result.transfer_type is not None
|
|
||||||
):
|
|
||||||
self.kv_connector_stats.record_transfer(
|
|
||||||
num_bytes=transfer_result.transfer_size,
|
|
||||||
time=transfer_result.transfer_time,
|
|
||||||
transfer_type=transfer_result.transfer_type,
|
|
||||||
)
|
|
||||||
if store:
|
|
||||||
req_jobs = self._store_jobs[req_id]
|
|
||||||
req_jobs.remove(job_id)
|
|
||||||
if req_jobs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if req_id in self._finished_reqs_waiting_for_store:
|
|
||||||
self._finished_reqs_waiting_for_store.remove(req_id)
|
|
||||||
finished_sending.add(req_id)
|
|
||||||
del self._store_jobs[req_id]
|
|
||||||
else:
|
|
||||||
req_job = self._load_job[req_id]
|
|
||||||
assert job_id == req_job
|
|
||||||
del self._load_job[req_id]
|
|
||||||
finished_recving.add(req_id)
|
|
||||||
|
|
||||||
for req_id in finished_req_ids:
|
|
||||||
pending_req_jobs = self._store_jobs.get(req_id)
|
|
||||||
if pending_req_jobs:
|
|
||||||
self._finished_reqs_waiting_for_store.add(req_id)
|
|
||||||
elif pending_req_jobs is not None:
|
|
||||||
finished_sending.add(req_id)
|
|
||||||
del self._store_jobs[req_id]
|
|
||||||
|
|
||||||
return finished_sending, finished_recving
|
|
||||||
|
|
||||||
def get_kv_connector_stats(self) -> KVConnectorStats | None:
|
|
||||||
"""
|
|
||||||
Get the KV transfer stats for the connector.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.kv_connector_stats.is_empty():
|
|
||||||
return None
|
|
||||||
# Clear stats for next iteration
|
|
||||||
kv_connector_stats = self.kv_connector_stats
|
|
||||||
self.kv_connector_stats = OffloadingConnectorStats()
|
|
||||||
return kv_connector_stats
|
|
||||||
|
|
||||||
|
|
||||||
class OffloadPromMetrics(KVConnectorPromMetrics):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
|
||||||
labelnames: list[str],
|
|
||||||
per_engine_labelvalues: dict[int, list[object]],
|
|
||||||
):
|
|
||||||
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
|
||||||
# (engine_idx, transfer_type) -> (metric with bounded labels)
|
|
||||||
self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
|
|
||||||
self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
|
|
||||||
self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}
|
|
||||||
buckets = [ # In bytes
|
|
||||||
1e6,
|
|
||||||
5e6,
|
|
||||||
10e6,
|
|
||||||
20e6,
|
|
||||||
40e6,
|
|
||||||
60e6,
|
|
||||||
80e6,
|
|
||||||
100e6,
|
|
||||||
150e6,
|
|
||||||
200e6,
|
|
||||||
]
|
|
||||||
|
|
||||||
self._counter_kv_bytes = self._counter_cls(
|
|
||||||
name="vllm:kv_offload_total_bytes",
|
|
||||||
documentation="Number of bytes offloaded by KV connector",
|
|
||||||
labelnames=labelnames + ["transfer_type"],
|
|
||||||
)
|
|
||||||
|
|
||||||
self._counter_kv_transfer_time = self._counter_cls(
|
|
||||||
name="vllm:kv_offload_total_time",
|
|
||||||
documentation="Total time measured by all KV offloading operations",
|
|
||||||
labelnames=labelnames + ["transfer_type"],
|
|
||||||
)
|
|
||||||
|
|
||||||
self._histogram_transfer_size = self._histogram_cls(
|
|
||||||
name="vllm:kv_offload_size",
|
|
||||||
documentation="Histogram of KV offload transfer size, in bytes.",
|
|
||||||
buckets=buckets[:],
|
|
||||||
labelnames=labelnames + ["transfer_type"],
|
|
||||||
)
|
|
||||||
|
|
||||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
|
||||||
"""
|
|
||||||
Observe transfer statistics from the new data structure.
|
|
||||||
transfer_stats_data is expected to be a dict where:
|
|
||||||
- keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
|
|
||||||
- values are lists of OffloadingOperationMetrics objects
|
|
||||||
"""
|
|
||||||
|
|
||||||
for transfer_type, ops in transfer_stats_data.items():
|
|
||||||
# Cache:
|
|
||||||
if (engine_idx, transfer_type) not in self.histogram_transfer_size:
|
|
||||||
self.histogram_transfer_size[(engine_idx, transfer_type)] = (
|
|
||||||
self._histogram_transfer_size.labels(
|
|
||||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.counter_kv_bytes[(engine_idx, transfer_type)] = (
|
|
||||||
self._counter_kv_bytes.labels(
|
|
||||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
|
|
||||||
self._counter_kv_transfer_time.labels(
|
|
||||||
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process ops:
|
|
||||||
assert isinstance(ops, list)
|
|
||||||
for op in ops: # ops is a list of serialized OffloadingOperationMetrics
|
|
||||||
assert isinstance(op, dict)
|
|
||||||
# Observe size histogram
|
|
||||||
self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
|
|
||||||
op["op_size"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Increment byte and time counters
|
|
||||||
self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])
|
|
||||||
|
|
||||||
self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
|
|
||||||
op["op_time"]
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user