Optimize data movement (#20)

This commit is contained in:
Woosuk Kwon
2023-04-02 00:30:17 -07:00
committed by GitHub
parent 1f01a18d39
commit 897cb2ae28
17 changed files with 275 additions and 135 deletions

View File

@@ -17,10 +17,10 @@ def test_reshape_and_cache(
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
kv_shape = (num_tokens, num_heads, head_size)
key = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
value = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
qkv = torch.randn(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
_, key, value = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
@@ -35,7 +35,7 @@ def test_reshape_and_cache(
for i in range(num_tokens):
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
block_idx = slot_mapping[i] // block_size
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
block_offset = slot_mapping[i] % block_size
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i]