[kv_offload+HMA][2/N]: Support multiple KV groups in GPULoadStoreSpec (#36642)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-03-18 19:26:40 +02:00
committed by GitHub
parent 39bfb57b7c
commit 5dd8df0701
3 changed files with 43 additions and 10 deletions

View File

@@ -135,19 +135,19 @@ def test_transfer(
# set transfer direction
if gpu_to_cpu:
handler = handlers.gpu_to_cpu_handler
src_spec_class = GPULoadStoreSpec
dst_spec_class = CPULoadStoreSpec
src_blocks = gpu_blocks
dst_blocks = cpu_blocks
src_spec = GPULoadStoreSpec(src_blocks, group_sizes=(len(src_blocks),))
dst_spec = CPULoadStoreSpec(dst_blocks)
src_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size
dst_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size
dst_size_in_kernel_blocks = num_cpu_blocks * kernel_blocks_per_cpu_block
else:
handler = handlers.cpu_to_gpu_handler
src_spec_class = CPULoadStoreSpec
dst_spec_class = GPULoadStoreSpec
src_blocks = cpu_blocks
dst_blocks = gpu_blocks
src_spec = CPULoadStoreSpec(src_blocks)
dst_spec = GPULoadStoreSpec(dst_blocks, group_sizes=(len(dst_blocks),))
src_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size
dst_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size
dst_size_in_kernel_blocks = num_gpu_blocks * kernel_blocks_per_gpu_block
@@ -159,10 +159,6 @@ def test_transfer(
):
dst_to_src[dst_block] = src_block
# build transfer specs
src_spec = src_spec_class(src_blocks)
dst_spec = dst_spec_class(dst_blocks)
# clone src and dst tensors before transfer
orig_src_caches = [x.clone() for x in handler.src_tensors]
orig_dst_caches = [x.clone() for x in handler.dst_tensors]

View File

@@ -173,7 +173,11 @@ class OffloadingConnectorScheduler:
)
src_spec = self.manager.prepare_load(block_hashes)
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
dst_spec = GPULoadStoreSpec(
block_ids[num_computed_gpu_blocks:],
group_sizes=(num_pending_gpu_blocks,),
block_indices=(num_computed_gpu_blocks,),
)
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=num_blocks
@@ -246,7 +250,9 @@ class OffloadingConnectorScheduler:
gpu_block_idx = offloaded_block_idx * self.block_size_factor
for i in range(self.block_size_factor):
src_block_ids.append(block_ids[gpu_block_idx + i])
src_spec = GPULoadStoreSpec(src_block_ids)
src_spec = GPULoadStoreSpec(
src_block_ids, group_sizes=(len(src_block_ids),)
)
reqs_to_store[req_id] = (src_spec, dst_spec)
self._reqs_being_stored[req_id] |= block_hashes_to_store

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC
from collections.abc import Sequence
import numpy as np
@@ -22,8 +23,38 @@ class BlockIDsLoadStoreSpec(LoadStoreSpec, ABC):
class GPULoadStoreSpec(BlockIDsLoadStoreSpec):
"""
Spec for loading/storing a KV block to GPU memory.
If there are multiple KV groups, the blocks are expected to be
ordered by the group index.
In that case, group_sizes[i] determines the number of blocks
per the i-th KV group, and thus sum(group_sizes) == len(block_ids).
group_sizes=None indicates a single KV group.
If block_indices is given, each group (determined by group_sizes) of block IDs
will correspond to logically contiguous blocks, e.g. blocks 5-10 of a some request.
block_indices[i] will represent the block index of the first block in group #i.
Thus, len(block_indices) == len(group_sizes) = number of KV cache groups.
This information is required in order to support loading from offloaded blocks
which are larger than GPU blocks.
In such cases, the first GPU block per each group may be unaligned to the offloaded
block size, and so knowing block_indices[i] allows the worker to correctly
skip part of the first matching offloaded block.
Offloading from GPU is always aligned to offloaded block size, and so
block_indices will only be set by the offloading connector when loading into GPU.
"""
def __init__(
self,
block_ids: list[int],
group_sizes: Sequence[int],
block_indices: Sequence[int] | None = None,
):
super().__init__(block_ids)
assert sum(group_sizes) == len(block_ids)
assert block_indices is None or len(block_indices) == len(group_sizes)
self.group_sizes: Sequence[int] = group_sizes
self.block_indices: Sequence[int] | None = block_indices
@staticmethod
def medium() -> str:
return "GPU"