OffloadingConnector: Support kernel_block_size != block_size (#30692)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -405,19 +405,41 @@ def test_swap_blocks(
|
||||
|
||||
# Call the swap_blocks kernel.
|
||||
do_opcheck = head_size == HEAD_SIZES[0]
|
||||
+ src_cache = src_key_caches[0]
+ block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
- (src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
+ (
+     src_key_caches[0],
+     dist_key_caches[0],
+     block_size_in_bytes,
+     block_mapping_tensor,
+ ),
|
||||
cond=do_opcheck,
|
||||
)
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
- (src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
+ (
+     src_value_caches[0],
+     dist_value_caches[0],
+     block_size_in_bytes,
+     block_mapping_tensor,
+ ),
|
||||
cond=do_opcheck,
|
||||
)
|
||||
|
||||
- ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
- ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)
+ ops.swap_blocks(
+     src_key_caches[0],
+     dist_key_caches[0],
+     block_size_in_bytes,
+     block_mapping_tensor,
+ )
+ ops.swap_blocks(
+     src_value_caches[0],
+     dist_value_caches[0],
+     block_size_in_bytes,
+     block_mapping_tensor,
+ )
|
||||
|
||||
for src, dst in block_mapping:
|
||||
torch.testing.assert_close(
|
||||
@@ -723,13 +745,14 @@ def test_swap_blocks_mla(
|
||||
block_mapping, dtype=torch.int64, device="cpu"
|
||||
).view(-1, 2)
|
||||
|
||||
+ block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
- (src_cache, dst_cache, block_mapping_tensor),
+ (src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
- ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
+ ops.swap_blocks(src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor)
|
||||
|
||||
for src, dst in block_mapping:
|
||||
torch.testing.assert_close(
|
||||
|
||||
Reference in New Issue
Block a user