OffloadingConnector: Support kernel_block_size != block_size (#30692)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-01-22 14:30:04 +02:00
committed by GitHub
parent 841d53aaa8
commit 421012b63a
7 changed files with 158 additions and 116 deletions

View File

@@ -2455,9 +2455,32 @@ def concat_and_cache_mla_rope_fused(
def swap_blocks(
src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
src: torch.Tensor,
dst: torch.Tensor,
block_size_in_bytes: int,
block_mapping: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
"""
Copy specific blocks from one tensor to another.
This method assumes each of the two input tensors is composed of
consecutive contiguous blocks, of size block_size_in_bytes.
i.e. the memory layout for each tensor is:
[block0] [block1] ... [block N]
block_mapping determines the subset of blocks to copy of the source tensor,
and their matching destination block number on the destination tensor.
block_mapping is expected to be a tensor of shape (num_blocks_to_copy, 2)
where each block_mapping[i] represents a single copy operation, copying
block #block_mapping[i][0] from the source tensor
to block #block_mapping[i][1] on the destination tensor.
block_mapping should have dtype int64.
The source and the destination tensors can be either on cpu or gpu,
but not both on cpu.
the block mapping tensor must on cpu.
"""
torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping)
def convert_fp8(

View File

@@ -65,7 +65,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
self,
src_tensors: list[torch.Tensor],
dst_tensors: list[torch.Tensor],
kv_dim_before_num_blocks: list[bool],
src_block_size_factor: int,
dst_block_size_factor: int,
):
@@ -76,22 +75,23 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
src_tensors: list of KV cache tensors to copy from.
dst_tensors: list of KV cache tensors to copy to.
Order should match src_tensors.
kv_dim_before_num_blocks: list of bools, indicating
whether the respective KV cache tensor has a KV
dimension before its num_blocks dimension.
e.g. (2, num_blocks, ...)
src_block_size_factor: The number of kernel blocks
per KV block in a source tensor.
dst_block_size_factor: The number of kernel blocks
per KV block in a destination tensor.
"""
assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks)
assert len(src_tensors) == len(dst_tensors)
self.src_tensors: list[torch.Tensor] = src_tensors
self.dst_tensors: list[torch.Tensor] = dst_tensors
self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks
self.src_block_size_factor: int = src_block_size_factor
self.dst_block_size_factor: int = dst_block_size_factor
min_block_size_factor = min(src_block_size_factor, dst_block_size_factor)
self.src_block_size_factor: int = src_block_size_factor // min_block_size_factor
self.dst_block_size_factor: int = dst_block_size_factor // min_block_size_factor
self.block_size_in_bytes = [
tensor.element_size() * tensor.stride(0) * min_block_size_factor
for tensor in src_tensors
]
assert len(src_tensors) > 0
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
@@ -142,16 +142,17 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
# assure job will start only after the previous one completes
stream.wait_event(last_event)
with torch.cuda.stream(stream):
for src_tensor, dst_tensor, kv_dim in zip(
self.src_tensors, self.dst_tensors, self.kv_dim_before_num_blocks
for src_tensor, dst_tensor, block_size_in_bytes in zip(
self.src_tensors,
self.dst_tensors,
self.block_size_in_bytes,
):
if kv_dim:
src_key_cache, src_value_cache = src_tensor
dst_key_cache, dst_value_cache = dst_tensor
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)
else:
ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
ops.swap_blocks(
src_tensor,
dst_tensor,
block_size_in_bytes,
src_to_dst_tensor,
)
event.record(stream)
self._transfer_events[job_id] = event
@@ -188,19 +189,12 @@ class CpuGpuOffloadingHandlers:
):
assert gpu_caches
assert cpu_block_size % gpu_block_size == 0
block_size_factor = cpu_block_size // gpu_block_size
pin_memory = is_pin_memory_available()
# allocate cpu tensors
logger.info("Allocating %d CPU tensors...", len(gpu_caches))
gpu_tensors: list[torch.Tensor] = []
cpu_tensors: list[torch.Tensor] = []
kv_dim_before_num_blocks: list[bool] = []
# find kernel block size and determine layout per each gpu tensor
kernel_block_size: int | None = None
# list of (gpu_tensor, split_k_and_v)
parsed_gpu_tensors: list[tuple[torch.Tensor, bool]] = []
for layer_name, gpu_tensor in gpu_caches.items():
gpu_tensors.append(gpu_tensor)
gpu_shape = gpu_tensor.shape
attn_backend = attn_backends[layer_name]
test_shape = attn_backend.get_kv_cache_shape(
@@ -208,28 +202,20 @@ class CpuGpuOffloadingHandlers:
)
has_layers_dim = False
split_k_and_v = False
if len(gpu_shape) != len(test_shape):
# cross-layers tensor
# shape is (num_blocks, ...)
assert len(gpu_shape) == len(test_shape) + 1
num_blocks_idx = 0
has_layers_dim = True
kv_dim_before_num_blocks.append(False)
# prepend a dummy num_layers=80 to test_shape
test_shape = (80,) + test_shape
elif test_shape[0] == 1234:
# shape is (num_blocks, ...)
num_blocks_idx = 0
kv_dim_before_num_blocks.append(False)
else:
elif test_shape[0] != 1234:
# shape should be (2, num_blocks, ...)
assert test_shape[0] == 2
assert test_shape[1] == 1234
assert gpu_shape[0] == 2
num_blocks_idx = 1
kv_dim_before_num_blocks.append(True)
split_k_and_v = True
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
@@ -250,30 +236,36 @@ class CpuGpuOffloadingHandlers:
kernel_block_size = gpu_shape[block_size_idx]
assert gpu_block_size % kernel_block_size == 0
cpu_shape = list(gpu_shape)
cpu_shape[num_blocks_idx] = num_cpu_blocks * block_size_factor
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
cpu_tensors.append(
torch.zeros(
cpu_shape,
dtype=gpu_tensor.dtype,
device="cpu",
pin_memory=pin_memory,
)
)
parsed_gpu_tensors.append((gpu_tensor, split_k_and_v))
assert kernel_block_size is not None
gpu_block_size_factor = gpu_block_size // kernel_block_size
cpu_block_size_factor = cpu_block_size // kernel_block_size
gpu_block_size_factor = gpu_block_size // kernel_block_size
num_cpu_kernel_blocks = num_cpu_blocks * cpu_block_size_factor
# TODO (orozery): adapt swap_blocks to support gpu_block_size_factor
assert gpu_block_size_factor == 1
# allocate cpu tensors
pin_memory = is_pin_memory_available()
logger.info("Allocating %d CPU tensors...", len(parsed_gpu_tensors))
gpu_tensors: list[torch.Tensor] = []
cpu_tensors: list[torch.Tensor] = []
for gpu_tensor, split_k_and_v in parsed_gpu_tensors:
cpu_shape = list(gpu_tensor.shape)
cpu_shape[1 if split_k_and_v else 0] = num_cpu_kernel_blocks
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
cpu_tensor = torch.zeros(
cpu_shape,
dtype=gpu_tensor.dtype,
device="cpu",
pin_memory=pin_memory,
)
gpu_tensors.extend(gpu_tensor.unbind(0) if split_k_and_v else [gpu_tensor])
cpu_tensors.extend(cpu_tensor.unbind(0) if split_k_and_v else [cpu_tensor])
self.gpu_to_cpu_handler = SingleDirectionOffloadingHandler(
src_tensors=gpu_tensors,
dst_tensors=cpu_tensors,
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=gpu_block_size_factor,
dst_block_size_factor=cpu_block_size_factor,
)
@@ -281,7 +273,6 @@ class CpuGpuOffloadingHandlers:
self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
src_tensors=cpu_tensors,
dst_tensors=gpu_tensors,
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=cpu_block_size_factor,
dst_block_size_factor=gpu_block_size_factor,
)