OffloadingConnector: Support kernel_block_size != block_size (#30692)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-01-22 14:30:04 +02:00
committed by GitHub
parent 841d53aaa8
commit 421012b63a
7 changed files with 158 additions and 116 deletions

View File

@@ -7,6 +7,7 @@
#include <vector>
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,

View File

@@ -25,6 +25,7 @@ typedef __hip_bfloat16 __nv_bfloat16;
#endif
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
@@ -49,10 +50,6 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());
// We use the stride instead of numel in case the cache is padded for memory
// alignment reasons, we assume the blocks data (inclusive of any padding)
// is contiguous in memory
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

View File

@@ -692,7 +692,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
"swap_blocks(Tensor src, Tensor! dst,"
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Reshape the key and value tensors and cache them.

View File

@@ -405,19 +405,41 @@ def test_swap_blocks(
# Call the swap_blocks kernel.
do_opcheck = head_size == HEAD_SIZES[0]
src_cache = src_key_caches[0]
block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
(
src_key_caches[0],
dist_key_caches[0],
block_size_in_bytes,
block_mapping_tensor,
),
cond=do_opcheck,
)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
(
src_value_caches[0],
dist_value_caches[0],
block_size_in_bytes,
block_mapping_tensor,
),
cond=do_opcheck,
)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)
ops.swap_blocks(
src_key_caches[0],
dist_key_caches[0],
block_size_in_bytes,
block_mapping_tensor,
)
ops.swap_blocks(
src_value_caches[0],
dist_value_caches[0],
block_size_in_bytes,
block_mapping_tensor,
)
for src, dst in block_mapping:
torch.testing.assert_close(
@@ -723,13 +745,14 @@ def test_swap_blocks_mla(
block_mapping, dtype=torch.int64, device="cpu"
).view(-1, 2)
block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_cache, dst_cache, block_mapping_tensor),
(src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
ops.swap_blocks(src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor)
for src, dst in block_mapping:
torch.testing.assert_close(

View File

@@ -25,8 +25,9 @@ if not current_platform.is_rocm():
NUM_GPU_BLOCKS = [64]
NUM_CPU_BLOCKS = [256]
GPU_BLOCK_SIZES = [16]
GPU_BLOCKS_PER_CPU_BLOCK = [1, 3]
KERNEL_BLOCK_SIZES = [16]
LOGICAL_BLOCK_SIZES = [16, 32]
LOGICAL_BLOCKS_PER_CPU_BLOCK = [1, 3]
HEAD_SIZES = [64]
NUM_HEADS = [8]
NUM_LAYERS = [4]
@@ -40,8 +41,9 @@ NUM_MAPPINGS = [3]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES)
@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK)
@pytest.mark.parametrize("kernel_block_size", KERNEL_BLOCK_SIZES)
@pytest.mark.parametrize("logical_block_size", LOGICAL_BLOCK_SIZES)
@pytest.mark.parametrize("logical_blocks_per_cpu_block", LOGICAL_BLOCKS_PER_CPU_BLOCK)
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@@ -55,8 +57,9 @@ def test_transfer(
num_mappings: int,
head_size: int,
num_heads: int,
gpu_block_size: int,
gpu_blocks_per_cpu_block: int,
kernel_block_size: int,
logical_block_size: int,
logical_blocks_per_cpu_block: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
num_layers: int,
@@ -69,6 +72,10 @@ def test_transfer(
# create per-layer GPU KV caches based on available attn_backends
attn_backends_list = BACKENDS_TO_TEST
assert logical_block_size % kernel_block_size == 0
kernel_blocks_per_gpu_block = logical_block_size // kernel_block_size
num_gpu_kernel_blocks = num_gpu_blocks * kernel_blocks_per_gpu_block
gpu_caches = {}
attn_backends = {}
for i in range(num_layers):
@@ -78,15 +85,16 @@ def test_transfer(
attn_backends[layer_name] = attn_backend
gpu_cache_shape = attn_backend.get_kv_cache_shape(
num_gpu_blocks, gpu_block_size, num_heads, head_size
num_gpu_kernel_blocks, kernel_block_size, num_heads, head_size
)
gpu_caches[layer_name] = torch.rand(gpu_cache_shape, dtype=dtype, device=device)
# create handler
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
cpu_block_size = logical_blocks_per_cpu_block * logical_block_size
kernel_blocks_per_cpu_block = cpu_block_size // kernel_block_size
handlers = CpuGpuOffloadingHandlers(
attn_backends=attn_backends,
gpu_block_size=gpu_block_size,
gpu_block_size=logical_block_size,
cpu_block_size=cpu_block_size,
num_cpu_blocks=num_cpu_blocks,
gpu_caches=gpu_caches,
@@ -94,22 +102,34 @@ def test_transfer(
# select block mappings
gpu_blocks = random.sample(
range(num_gpu_blocks), num_mappings * gpu_blocks_per_cpu_block
range(num_gpu_blocks), num_mappings * logical_blocks_per_cpu_block
)
cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)
# convert cpu blocks to gpu block size
cpu_blocks_in_gpu_block_size = []
for cpu_block in cpu_blocks:
base_block_id = cpu_block * gpu_blocks_per_cpu_block
for i in range(gpu_blocks_per_cpu_block):
cpu_blocks_in_gpu_block_size.append(i + base_block_id)
# convert gpu blocks to kernel block size
gpu_blocks_in_kernel_block_size = []
for gpu_block in gpu_blocks:
base_block_id = gpu_block * kernel_blocks_per_gpu_block
for i in range(kernel_blocks_per_gpu_block):
gpu_blocks_in_kernel_block_size.append(i + base_block_id)
# maybe skip a GPU block to test reading from the middle of a CPU block
# convert cpu blocks to kernel block size
cpu_blocks_in_kernel_block_size = []
for cpu_block in cpu_blocks:
base_block_id = cpu_block * kernel_blocks_per_cpu_block
for i in range(kernel_blocks_per_cpu_block):
cpu_blocks_in_kernel_block_size.append(i + base_block_id)
# maybe skip some GPU blocks to test reading from the middle of a CPU block
if not gpu_to_cpu:
gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1 :]
cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[
gpu_blocks_per_cpu_block - 1 :
gpu_blocks_to_skip = logical_blocks_per_cpu_block - 1
gpu_blocks = gpu_blocks[gpu_blocks_to_skip:]
kernel_blocks_to_skip = gpu_blocks_to_skip * kernel_blocks_per_gpu_block
gpu_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size[
kernel_blocks_to_skip:
]
cpu_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size[
kernel_blocks_to_skip:
]
# set transfer direction
@@ -119,23 +139,23 @@ def test_transfer(
dst_spec_class = CPULoadStoreSpec
src_blocks = gpu_blocks
dst_blocks = cpu_blocks
src_blocks_in_gpu_block_size = gpu_blocks
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
src_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size
dst_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size
dst_size_in_kernel_blocks = num_cpu_blocks * kernel_blocks_per_cpu_block
else:
handler = handlers.cpu_to_gpu_handler
src_spec_class = CPULoadStoreSpec
dst_spec_class = GPULoadStoreSpec
src_blocks = cpu_blocks
dst_blocks = gpu_blocks
src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
dst_blocks_in_gpu_block_size = gpu_blocks
dst_size_in_gpu_blocks = num_gpu_blocks
src_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size
dst_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size
dst_size_in_kernel_blocks = num_gpu_blocks * kernel_blocks_per_gpu_block
# build dst -> src mapping
dst_to_src = {}
for src_block, dst_block in zip(
src_blocks_in_gpu_block_size, dst_blocks_in_gpu_block_size
src_blocks_in_kernel_block_size, dst_blocks_in_kernel_block_size
):
dst_to_src[dst_block] = src_block
@@ -165,29 +185,15 @@ def test_transfer(
assert torch.equal(orig_tensor, tensor)
# verify dst tensors
for dst_block in range(dst_size_in_gpu_blocks):
for dst_block in range(dst_size_in_kernel_blocks):
src_block_candidate = dst_to_src.get(dst_block)
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
for src_cache, dst_cache, orig_dst_cache in zip(
handler.src_tensors,
handler.dst_tensors,
orig_dst_caches,
handler.kv_dim_before_num_blocks,
):
if kv_dim:
# iterate over key, value
for i in range(2):
if src_block_candidate is not None:
expected_value = src_cache[i][src_block_candidate]
else:
expected_value = orig_dst_cache[i][dst_block]
torch.testing.assert_close(
dst_cache[i][dst_block].cpu(), expected_value.cpu()
)
if src_block_candidate is not None:
expected_value = src_cache[src_block_candidate]
else:
if src_block_candidate is not None:
expected_value = src_cache[src_block_candidate]
else:
expected_value = orig_dst_cache[dst_block]
torch.testing.assert_close(
dst_cache[dst_block].cpu(), expected_value.cpu()
)
expected_value = orig_dst_cache[dst_block]
torch.testing.assert_close(dst_cache[dst_block].cpu(), expected_value.cpu())

View File

@@ -2455,9 +2455,32 @@ def concat_and_cache_mla_rope_fused(
def swap_blocks(
src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
src: torch.Tensor,
dst: torch.Tensor,
block_size_in_bytes: int,
block_mapping: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
"""
Copy specific blocks from one tensor to another.
This method assumes each of the two input tensors is composed of
consecutive contiguous blocks, of size block_size_in_bytes.
i.e. the memory layout for each tensor is:
[block0] [block1] ... [block N]
block_mapping determines which blocks of the source tensor to copy,
and the matching destination block number for each on the destination tensor.
block_mapping is expected to be a tensor of shape (num_blocks_to_copy, 2)
where each block_mapping[i] represents a single copy operation, copying
block #block_mapping[i][0] from the source tensor
to block #block_mapping[i][1] on the destination tensor.
block_mapping should have dtype int64.
The source and the destination tensors can be either on cpu or gpu,
but not both on cpu.
The block mapping tensor must be on cpu.
"""
torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping)
def convert_fp8(

View File

@@ -65,7 +65,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
self,
src_tensors: list[torch.Tensor],
dst_tensors: list[torch.Tensor],
kv_dim_before_num_blocks: list[bool],
src_block_size_factor: int,
dst_block_size_factor: int,
):
@@ -76,22 +75,23 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
src_tensors: list of KV cache tensors to copy from.
dst_tensors: list of KV cache tensors to copy to.
Order should match src_tensors.
kv_dim_before_num_blocks: list of bools, indicating
whether the respective KV cache tensor has a KV
dimension before its num_blocks dimension.
e.g. (2, num_blocks, ...)
src_block_size_factor: The number of kernel blocks
per KV block in a source tensor.
dst_block_size_factor: The number of kernel blocks
per KV block in a destination tensor.
"""
assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks)
assert len(src_tensors) == len(dst_tensors)
self.src_tensors: list[torch.Tensor] = src_tensors
self.dst_tensors: list[torch.Tensor] = dst_tensors
self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks
self.src_block_size_factor: int = src_block_size_factor
self.dst_block_size_factor: int = dst_block_size_factor
min_block_size_factor = min(src_block_size_factor, dst_block_size_factor)
self.src_block_size_factor: int = src_block_size_factor // min_block_size_factor
self.dst_block_size_factor: int = dst_block_size_factor // min_block_size_factor
self.block_size_in_bytes = [
tensor.element_size() * tensor.stride(0) * min_block_size_factor
for tensor in src_tensors
]
assert len(src_tensors) > 0
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
@@ -142,16 +142,17 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
# assure job will start only after the previous one completes
stream.wait_event(last_event)
with torch.cuda.stream(stream):
for src_tensor, dst_tensor, kv_dim in zip(
self.src_tensors, self.dst_tensors, self.kv_dim_before_num_blocks
for src_tensor, dst_tensor, block_size_in_bytes in zip(
self.src_tensors,
self.dst_tensors,
self.block_size_in_bytes,
):
if kv_dim:
src_key_cache, src_value_cache = src_tensor
dst_key_cache, dst_value_cache = dst_tensor
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)
else:
ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
ops.swap_blocks(
src_tensor,
dst_tensor,
block_size_in_bytes,
src_to_dst_tensor,
)
event.record(stream)
self._transfer_events[job_id] = event
@@ -188,19 +189,12 @@ class CpuGpuOffloadingHandlers:
):
assert gpu_caches
assert cpu_block_size % gpu_block_size == 0
block_size_factor = cpu_block_size // gpu_block_size
pin_memory = is_pin_memory_available()
# allocate cpu tensors
logger.info("Allocating %d CPU tensors...", len(gpu_caches))
gpu_tensors: list[torch.Tensor] = []
cpu_tensors: list[torch.Tensor] = []
kv_dim_before_num_blocks: list[bool] = []
# find kernel block size and determine layout per each gpu tensor
kernel_block_size: int | None = None
# list of (gpu_tensor, split_k_and_v)
parsed_gpu_tensors: list[tuple[torch.Tensor, bool]] = []
for layer_name, gpu_tensor in gpu_caches.items():
gpu_tensors.append(gpu_tensor)
gpu_shape = gpu_tensor.shape
attn_backend = attn_backends[layer_name]
test_shape = attn_backend.get_kv_cache_shape(
@@ -208,28 +202,20 @@ class CpuGpuOffloadingHandlers:
)
has_layers_dim = False
split_k_and_v = False
if len(gpu_shape) != len(test_shape):
# cross-layers tensor
# shape is (num_blocks, ...)
assert len(gpu_shape) == len(test_shape) + 1
num_blocks_idx = 0
has_layers_dim = True
kv_dim_before_num_blocks.append(False)
# prepend a dummy num_layers=80 to test_shape
test_shape = (80,) + test_shape
elif test_shape[0] == 1234:
# shape is (num_blocks, ...)
num_blocks_idx = 0
kv_dim_before_num_blocks.append(False)
else:
elif test_shape[0] != 1234:
# shape should be (2, num_blocks, ...)
assert test_shape[0] == 2
assert test_shape[1] == 1234
assert gpu_shape[0] == 2
num_blocks_idx = 1
kv_dim_before_num_blocks.append(True)
split_k_and_v = True
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
@@ -250,30 +236,36 @@ class CpuGpuOffloadingHandlers:
kernel_block_size = gpu_shape[block_size_idx]
assert gpu_block_size % kernel_block_size == 0
cpu_shape = list(gpu_shape)
cpu_shape[num_blocks_idx] = num_cpu_blocks * block_size_factor
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
cpu_tensors.append(
torch.zeros(
cpu_shape,
dtype=gpu_tensor.dtype,
device="cpu",
pin_memory=pin_memory,
)
)
parsed_gpu_tensors.append((gpu_tensor, split_k_and_v))
assert kernel_block_size is not None
gpu_block_size_factor = gpu_block_size // kernel_block_size
cpu_block_size_factor = cpu_block_size // kernel_block_size
gpu_block_size_factor = gpu_block_size // kernel_block_size
num_cpu_kernel_blocks = num_cpu_blocks * cpu_block_size_factor
# TODO (orozery): adapt swap_blocks to support gpu_block_size_factor
assert gpu_block_size_factor == 1
# allocate cpu tensors
pin_memory = is_pin_memory_available()
logger.info("Allocating %d CPU tensors...", len(parsed_gpu_tensors))
gpu_tensors: list[torch.Tensor] = []
cpu_tensors: list[torch.Tensor] = []
for gpu_tensor, split_k_and_v in parsed_gpu_tensors:
cpu_shape = list(gpu_tensor.shape)
cpu_shape[1 if split_k_and_v else 0] = num_cpu_kernel_blocks
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
cpu_tensor = torch.zeros(
cpu_shape,
dtype=gpu_tensor.dtype,
device="cpu",
pin_memory=pin_memory,
)
gpu_tensors.extend(gpu_tensor.unbind(0) if split_k_and_v else [gpu_tensor])
cpu_tensors.extend(cpu_tensor.unbind(0) if split_k_and_v else [cpu_tensor])
self.gpu_to_cpu_handler = SingleDirectionOffloadingHandler(
src_tensors=gpu_tensors,
dst_tensors=cpu_tensors,
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=gpu_block_size_factor,
dst_block_size_factor=cpu_block_size_factor,
)
@@ -281,7 +273,6 @@ class CpuGpuOffloadingHandlers:
self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
src_tensors=cpu_tensors,
dst_tensors=gpu_tensors,
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=cpu_block_size_factor,
dst_block_size_factor=gpu_block_size_factor,
)