OffloadingConnector: Support kernel_block_size != block_size (#30692)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-01-22 14:30:04 +02:00
committed by GitHub
parent 841d53aaa8
commit 421012b63a
7 changed files with 158 additions and 116 deletions

View File

@@ -405,19 +405,41 @@ def test_swap_blocks(
# Call the swap_blocks kernel.
do_opcheck = head_size == HEAD_SIZES[0]
src_cache = src_key_caches[0]
block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
(
src_key_caches[0],
dist_key_caches[0],
block_size_in_bytes,
block_mapping_tensor,
),
cond=do_opcheck,
)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
(
src_value_caches[0],
dist_value_caches[0],
block_size_in_bytes,
block_mapping_tensor,
),
cond=do_opcheck,
)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)
ops.swap_blocks(
src_key_caches[0],
dist_key_caches[0],
block_size_in_bytes,
block_mapping_tensor,
)
ops.swap_blocks(
src_value_caches[0],
dist_value_caches[0],
block_size_in_bytes,
block_mapping_tensor,
)
for src, dst in block_mapping:
torch.testing.assert_close(
@@ -723,13 +745,14 @@ def test_swap_blocks_mla(
block_mapping, dtype=torch.int64, device="cpu"
).view(-1, 2)
block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_cache, dst_cache, block_mapping_tensor),
(src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
ops.swap_blocks(src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor)
for src, dst in block_mapping:
torch.testing.assert_close(