[Metrics] [KVConnector] Add Offloading Connector metrics (#27942)

Added queries and hits metrics for the Offloading Connector.

Also added timing metrics for store and load operations, reporting the
average per-token time to store/load.

The metrics are available from Prometheus and from the StatLogger.
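
For illustration only, a minimal sketch of how such counters could be folded into a hit rate and the per-token load time described above. The names here (OffloadingStats, num_queries, num_hits, load_time, loaded_tokens) are assumptions for this sketch, not the connector's actual metric or attribute names.

from dataclasses import dataclass


@dataclass
class OffloadingStats:
    # Hypothetical aggregate mirroring the metrics this commit describes.
    num_queries: int = 0    # block lookups issued against the offloading backend
    num_hits: int = 0       # lookups that found the blocks already offloaded
    load_time: float = 0.0  # total seconds spent loading (CPU -> GPU)
    loaded_tokens: int = 0  # tokens moved by those loads

    @property
    def hit_rate(self) -> float:
        return self.num_hits / self.num_queries if self.num_queries else 0.0

    @property
    def load_time_per_token(self) -> float:
        # the "average per-token time to load" from the commit message
        return self.load_time / self.loaded_tokens if self.loaded_tokens else 0.0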

Signed-off-by: omerpaz95 <omerpaz95@gmail.com>
Co-authored-by: Omer Paz <Omer.Paz@ibm.com>
omerpaz95 committed on 2026-01-27 15:34:49 +02:00 (committed by GitHub)
commit 7227d06156, parent 14385c80fc
6 changed files with 449 additions and 28 deletions


@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import deque
+from dataclasses import dataclass

 import numpy as np
 import torch
@@ -19,6 +20,15 @@ from vllm.v1.kv_offload.worker.worker import (
 logger = init_logger(__name__)

+@dataclass
+class Transfer:
+    job_id: int
+    stream: torch.cuda.Stream
+    start_event: torch.Event
+    end_event: torch.Event
+    num_bytes: int

 def expand_block_ids(
     block_ids: np.ndarray,
     block_size_factor: int,
@@ -92,14 +102,15 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
             tensor.element_size() * tensor.stride(0) * min_block_size_factor
             for tensor in src_tensors
         ]
+        self.total_block_size_in_bytes = sum(self.block_size_in_bytes)

         assert len(src_tensors) > 0
         self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
+        self.transfer_type = ("GPU", "CPU") if self.gpu_to_cpu else ("CPU", "GPU")

         # job_id -> event
         self._transfer_events: dict[int, torch.Event] = {}

         # queue of transfers (job_id, stream, event)
-        self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
+        self._transfers: deque[Transfer] = deque()

         # list of CUDA streams available for re-use
         self._stream_pool: list[torch.cuda.Stream] = []

         # list of CUDA events available for re-use
@@ -132,16 +143,27 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
         src_to_dst_tensor = torch.from_numpy(src_to_dst)

         stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
-        event = self._event_pool.pop() if self._event_pool else torch.Event()
+        start_event = (
+            self._event_pool.pop()
+            if self._event_pool
+            else torch.Event(enable_timing=True)
+        )
+        end_event = (
+            self._event_pool.pop()
+            if self._event_pool
+            else torch.Event(enable_timing=True)
+        )

         if self.gpu_to_cpu:
             # wait for model computation to finish before offloading
             stream.wait_stream(torch.cuda.current_stream())

         if self._transfers:
-            _, _, last_event = self._transfers[-1]
+            last_transfer: Transfer = self._transfers[-1]
+            last_event = last_transfer.end_event
             # assure job will start only after the previous one completes
             stream.wait_event(last_event)

         with torch.cuda.stream(stream):
+            start_event.record(stream)
             for src_tensor, dst_tensor, block_size_in_bytes in zip(
                 self.src_tensors,
                 self.dst_tensors,
@@ -153,22 +175,42 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
                     block_size_in_bytes,
                     src_to_dst_tensor,
                 )
-            event.record(stream)
+            end_event.record(stream)

-        self._transfer_events[job_id] = event
-        self._transfers.append((job_id, stream, event))
+        self._transfer_events[job_id] = end_event
+        self._transfers.append(
+            Transfer(
+                job_id=job_id,
+                stream=stream,
+                start_event=start_event,
+                end_event=end_event,
+                num_bytes=dst_sub_block_count * self.total_block_size_in_bytes,
+            )
+        )

         # success
         return True

     def get_finished(self) -> list[TransferResult]:
         results: list[TransferResult] = []
-        while self._transfers and self._transfers[0][2].query():
-            job_id, stream, event = self._transfers.popleft()
-            results.append((job_id, True))
-            self._stream_pool.append(stream)
-            self._event_pool.append(event)
-            del self._transfer_events[job_id]
+        while self._transfers and self._transfers[0].end_event.query():
+            transfer = self._transfers.popleft()
+            transfer_time = (
+                transfer.start_event.elapsed_time(transfer.end_event) * 1e-3
+            )  # elapsed_time is in milliseconds
+            result = TransferResult(
+                job_id=transfer.job_id,
+                success=True,
+                transfer_size=transfer.num_bytes,
+                transfer_time=transfer_time,
+                transfer_type=self.transfer_type,
+            )
+            results.append(result)
+            self._stream_pool.append(transfer.stream)
+            self._event_pool.append(transfer.end_event)
+            self._event_pool.append(transfer.start_event)
+            del self._transfer_events[transfer.job_id]
         return results

     def wait(self, job_ids: set[int]):
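
For illustration, a rough sketch of how a consumer might turn the TransferResult fields produced above (transfer_time in seconds, transfer_size in bytes, transfer_type as a direction pair) into the per-token averages mentioned in the commit message. The helper name and the bytes_per_token conversion are assumptions for this sketch, not the actual StatLogger code.

def per_token_transfer_time(
    results, bytes_per_token: int
) -> dict[tuple[str, str], float]:
    # Hypothetical helper: average transfer_time per token for each direction,
    # approximating the token count of a transfer as transfer_size // bytes_per_token.
    total_time: dict[tuple[str, str], float] = {}
    total_tokens: dict[tuple[str, str], int] = {}
    for r in results:
        total_time[r.transfer_type] = (
            total_time.get(r.transfer_type, 0.0) + r.transfer_time
        )
        total_tokens[r.transfer_type] = (
            total_tokens.get(r.transfer_type, 0) + r.transfer_size // bytes_per_token
        )
    return {
        direction: total_time[direction] / tokens
        for direction, tokens in total_tokens.items()
        if tokens
    }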