[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)
This commit is contained in:
@@ -315,7 +315,10 @@ def test_swap_blocks(
|
||||
else:
|
||||
dst_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
|
||||
block_mapping = dict(zip(src_blocks, dst_blocks))
|
||||
block_mapping = list(zip(src_blocks, dst_blocks))
|
||||
block_mapping_tensor = torch.tensor(block_mapping,
|
||||
dtype=torch.int64,
|
||||
device="cpu").view(-1, 2)
|
||||
|
||||
# Create the KV caches on the first device.
|
||||
src_key_caches, src_value_caches = kv_cache_factory(
|
||||
@@ -331,10 +334,12 @@ def test_swap_blocks(
|
||||
src_value_caches_clone = src_value_caches[0].clone()
|
||||
|
||||
# Call the swap_blocks kernel.
|
||||
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
|
||||
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
|
||||
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
|
||||
block_mapping_tensor)
|
||||
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
|
||||
block_mapping_tensor)
|
||||
|
||||
for src, dst in block_mapping.items():
|
||||
for src, dst in block_mapping:
|
||||
assert torch.allclose(src_key_caches_clone[src].cpu(),
|
||||
dist_key_caches[0][dst].cpu())
|
||||
assert torch.allclose(src_value_caches_clone[src].cpu(),
|
||||
|
||||
Reference in New Issue
Block a user