[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)

This commit is contained in:
youkaichao
2024-05-08 12:07:05 -07:00
committed by GitHub
parent ad932a221d
commit 20cfcdec99
21 changed files with 137 additions and 109 deletions

View File

@@ -315,7 +315,10 @@ def test_swap_blocks(
else:
dst_blocks = random.sample(range(num_blocks), num_mappings)
block_mapping = dict(zip(src_blocks, dst_blocks))
block_mapping = list(zip(src_blocks, dst_blocks))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device="cpu").view(-1, 2)
# Create the KV caches on the first device.
src_key_caches, src_value_caches = kv_cache_factory(
@@ -331,10 +334,12 @@ def test_swap_blocks(
src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel.
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
block_mapping_tensor)
for src, dst in block_mapping.items():
for src, dst in block_mapping:
assert torch.allclose(src_key_caches_clone[src].cpu(),
dist_key_caches[0][dst].cpu())
assert torch.allclose(src_value_caches_clone[src].cpu(),