[Quality] Add code formatter and linter (#326)
This commit is contained in:
@@ -26,8 +26,9 @@ def run_copy_blocks(
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_caches = []
|
||||
for _ in range(num_layers):
|
||||
key_cache = torch.randn(
|
||||
size=key_cache_shape, dtype=dtype, device='cuda')
|
||||
key_cache = torch.randn(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
key_caches.append(key_cache)
|
||||
cloned_key_caches = []
|
||||
for key_cache in key_caches:
|
||||
@@ -36,8 +37,9 @@ def run_copy_blocks(
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_caches = []
|
||||
for _ in range(num_layers):
|
||||
value_cache = torch.randn(
|
||||
size=value_cache_shape, dtype=dtype, device='cuda')
|
||||
value_cache = torch.randn(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
value_caches.append(value_cache)
|
||||
cloned_value_caches = []
|
||||
for value_cache in value_caches:
|
||||
@@ -49,15 +51,18 @@ def run_copy_blocks(
|
||||
# Reference implementation.
|
||||
for src, dsts in block_mapping.items():
|
||||
for dst in dsts:
|
||||
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
|
||||
for key_cache, cloned_key_cache in zip(key_caches,
|
||||
cloned_key_caches):
|
||||
cloned_key_cache[dst] = cloned_key_cache[src]
|
||||
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
|
||||
for value_cache, cloned_value_cache in zip(value_caches,
|
||||
cloned_value_caches):
|
||||
cloned_value_cache[dst] = cloned_value_cache[src]
|
||||
|
||||
# Compare the results.
|
||||
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
|
||||
assert torch.allclose(key_cache, cloned_key_cache)
|
||||
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
|
||||
for value_cache, cloned_value_cache in zip(value_caches,
|
||||
cloned_value_caches):
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@@ -74,8 +79,12 @@ def run_reshape_and_cache(
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
@@ -84,15 +93,19 @@ def run_reshape_and_cache(
|
||||
cloned_key_cache = key_cache.clone()
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_cache = torch.randn(
|
||||
size=value_cache_shape, dtype=dtype, device='cuda')
|
||||
value_cache = torch.randn(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||
slot_mapping)
|
||||
|
||||
for i in range(num_tokens):
|
||||
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
||||
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
|
||||
block_idx = torch.div(slot_mapping[i],
|
||||
block_size,
|
||||
rounding_mode='floor')
|
||||
block_offset = slot_mapping[i] % block_size
|
||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||
@@ -114,8 +127,12 @@ def run_gather_cached_kv(
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
qkv_clone = qkv.clone()
|
||||
@@ -126,15 +143,20 @@ def run_gather_cached_kv(
|
||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_cache = torch.randn(
|
||||
size=value_cache_shape, dtype=dtype, device='cuda')
|
||||
value_cache = torch.randn(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
|
||||
cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)
|
||||
cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
|
||||
slot_mapping)
|
||||
|
||||
# Reference implementation.
|
||||
for i in range(num_tokens):
|
||||
reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
|
||||
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
|
||||
reshaped_key = cloned_key.reshape(num_tokens, num_heads,
|
||||
head_size // x, x)
|
||||
block_idx = torch.div(slot_mapping[i],
|
||||
block_size,
|
||||
rounding_mode='floor')
|
||||
block_offset = slot_mapping[i] % block_size
|
||||
reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
|
||||
cloned_value[i] = value_cache[block_idx, :, :, block_offset]
|
||||
@@ -145,20 +167,30 @@ def run_gather_cached_kv(
|
||||
|
||||
def test_copy_blocks() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
run_copy_blocks(
|
||||
num_mappings=23, num_layers=7, num_heads=17, head_size=16,
|
||||
block_size=8, num_blocks=1024, dtype=dtype)
|
||||
run_copy_blocks(num_mappings=23,
|
||||
num_layers=7,
|
||||
num_heads=17,
|
||||
head_size=16,
|
||||
block_size=8,
|
||||
num_blocks=1024,
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
def test_reshape_and_cache() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
run_reshape_and_cache(
|
||||
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
|
||||
dtype=dtype)
|
||||
run_reshape_and_cache(num_tokens=3,
|
||||
num_heads=2,
|
||||
head_size=16,
|
||||
block_size=8,
|
||||
num_blocks=2,
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
def test_gather_cached_kv() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
run_gather_cached_kv(
|
||||
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
|
||||
dtype=dtype)
|
||||
run_gather_cached_kv(num_tokens=3,
|
||||
num_heads=2,
|
||||
head_size=16,
|
||||
block_size=8,
|
||||
num_blocks=2,
|
||||
dtype=dtype)
|
||||
|
||||
Reference in New Issue
Block a user