Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -22,7 +22,7 @@ NUM_HEADS = [8]
|
||||
NUM_LAYERS = [4]
|
||||
DTYPES = [torch.bfloat16]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = ['cuda:0']
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
NUM_MAPPINGS = [3]
|
||||
|
||||
|
||||
@@ -56,35 +56,35 @@ def test_transfer(
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
# create per-layer GPU KV caches
|
||||
attn_backends_list = [
|
||||
FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend
|
||||
]
|
||||
attn_backends_list = [FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend]
|
||||
|
||||
gpu_caches = {}
|
||||
attn_backends = {}
|
||||
for i in range(num_layers):
|
||||
layer_name = f'layer {i}'
|
||||
layer_name = f"layer {i}"
|
||||
|
||||
attn_backend = attn_backends_list[i % len(attn_backends_list)]
|
||||
attn_backends[layer_name] = attn_backend
|
||||
|
||||
gpu_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_gpu_blocks, gpu_block_size, num_heads, head_size)
|
||||
gpu_caches[layer_name] = torch.rand(gpu_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
num_gpu_blocks, gpu_block_size, num_heads, head_size
|
||||
)
|
||||
gpu_caches[layer_name] = torch.rand(gpu_cache_shape, dtype=dtype, device=device)
|
||||
|
||||
# create handler
|
||||
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
|
||||
handler = CpuGpuOffloadingHandler(attn_backends=attn_backends,
|
||||
gpu_block_size=gpu_block_size,
|
||||
cpu_block_size=cpu_block_size,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
gpu_caches=gpu_caches)
|
||||
handler = CpuGpuOffloadingHandler(
|
||||
attn_backends=attn_backends,
|
||||
gpu_block_size=gpu_block_size,
|
||||
cpu_block_size=cpu_block_size,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
gpu_caches=gpu_caches,
|
||||
)
|
||||
|
||||
# select block mappings
|
||||
gpu_blocks = random.sample(range(num_gpu_blocks),
|
||||
num_mappings * gpu_blocks_per_cpu_block)
|
||||
gpu_blocks = random.sample(
|
||||
range(num_gpu_blocks), num_mappings * gpu_blocks_per_cpu_block
|
||||
)
|
||||
cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)
|
||||
|
||||
# convert cpu blocks to gpu block size
|
||||
@@ -96,9 +96,10 @@ def test_transfer(
|
||||
|
||||
# maybe skip a GPU block to test writing to the middle of a CPU block
|
||||
if gpu_to_cpu:
|
||||
gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1:]
|
||||
gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1 :]
|
||||
cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[
|
||||
gpu_blocks_per_cpu_block - 1:]
|
||||
gpu_blocks_per_cpu_block - 1 :
|
||||
]
|
||||
|
||||
# set transfer direction
|
||||
if gpu_to_cpu:
|
||||
@@ -124,8 +125,9 @@ def test_transfer(
|
||||
|
||||
# build dst -> src mapping
|
||||
dst_to_src = {}
|
||||
for src_block, dst_block in zip(src_blocks_in_gpu_block_size,
|
||||
dst_blocks_in_gpu_block_size):
|
||||
for src_block, dst_block in zip(
|
||||
src_blocks_in_gpu_block_size, dst_blocks_in_gpu_block_size
|
||||
):
|
||||
dst_to_src[dst_block] = src_block
|
||||
|
||||
# build transfer specs
|
||||
@@ -157,8 +159,11 @@ def test_transfer(
|
||||
for dst_block in range(dst_size_in_gpu_blocks):
|
||||
src_block_candidate = dst_to_src.get(dst_block)
|
||||
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
|
||||
src_kv_caches, dst_kv_caches, orig_dst_caches,
|
||||
handler.kv_dim_before_num_blocks):
|
||||
src_kv_caches,
|
||||
dst_kv_caches,
|
||||
orig_dst_caches,
|
||||
handler.kv_dim_before_num_blocks,
|
||||
):
|
||||
if kv_dim:
|
||||
# iterate over key, value
|
||||
for i in range(2):
|
||||
@@ -166,12 +171,14 @@ def test_transfer(
|
||||
expected_value = src_cache[i][src_block_candidate]
|
||||
else:
|
||||
expected_value = orig_dst_cache[i][dst_block]
|
||||
torch.testing.assert_close(dst_cache[i][dst_block].cpu(),
|
||||
expected_value.cpu())
|
||||
torch.testing.assert_close(
|
||||
dst_cache[i][dst_block].cpu(), expected_value.cpu()
|
||||
)
|
||||
else:
|
||||
if src_block_candidate is not None:
|
||||
expected_value = src_cache[src_block_candidate]
|
||||
else:
|
||||
expected_value = orig_dst_cache[dst_block]
|
||||
torch.testing.assert_close(dst_cache[dst_block].cpu(),
|
||||
expected_value.cpu())
|
||||
torch.testing.assert_close(
|
||||
dst_cache[dst_block].cpu(), expected_value.cpu()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user