CPU KV Offloading: Use more CUDA streams (#29013)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -9,7 +9,7 @@ import torch
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
|
||||
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers
|
||||
|
||||
BACKENDS_TO_TEST = [FlashAttentionBackend]
|
||||
|
||||
@@ -82,7 +82,7 @@ def test_transfer(
|
||||
|
||||
# create handler
|
||||
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
|
||||
handler = CpuGpuOffloadingHandler(
|
||||
handlers = CpuGpuOffloadingHandlers(
|
||||
attn_backends=attn_backends,
|
||||
gpu_block_size=gpu_block_size,
|
||||
cpu_block_size=cpu_block_size,
|
||||
@@ -112,8 +112,7 @@ def test_transfer(
|
||||
|
||||
# set transfer direction
|
||||
if gpu_to_cpu:
|
||||
src_kv_caches = handler.gpu_tensors
|
||||
dst_kv_caches = handler.cpu_tensors
|
||||
handler = handlers.gpu_to_cpu_handler
|
||||
src_spec_class = GPULoadStoreSpec
|
||||
dst_spec_class = CPULoadStoreSpec
|
||||
src_blocks = gpu_blocks
|
||||
@@ -122,8 +121,7 @@ def test_transfer(
|
||||
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
|
||||
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
|
||||
else:
|
||||
src_kv_caches = handler.cpu_tensors
|
||||
dst_kv_caches = handler.gpu_tensors
|
||||
handler = handlers.cpu_to_gpu_handler
|
||||
src_spec_class = CPULoadStoreSpec
|
||||
dst_spec_class = GPULoadStoreSpec
|
||||
src_blocks = cpu_blocks
|
||||
@@ -144,12 +142,12 @@ def test_transfer(
|
||||
dst_spec = dst_spec_class(dst_blocks)
|
||||
|
||||
# clone src and dst tensors before transfer
|
||||
orig_src_caches = [x.clone() for x in src_kv_caches]
|
||||
orig_dst_caches = [x.clone() for x in dst_kv_caches]
|
||||
orig_src_caches = [x.clone() for x in handler.src_tensors]
|
||||
orig_dst_caches = [x.clone() for x in handler.dst_tensors]
|
||||
|
||||
# call transfer function
|
||||
assert handler.transfer_async(1, (src_spec, dst_spec))
|
||||
assert set(handler.transfer_events.keys()) == {1}
|
||||
assert set({x[0] for x in handler._transfers}) == {1}
|
||||
|
||||
# wait for transfer to complete
|
||||
end_time = time.time() + 10
|
||||
@@ -161,15 +159,15 @@ def test_transfer(
|
||||
time.sleep(0.1)
|
||||
|
||||
# verify src tensors did not change
|
||||
for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches):
|
||||
for orig_tensor, tensor in zip(orig_src_caches, handler.src_tensors):
|
||||
assert torch.equal(orig_tensor, tensor)
|
||||
|
||||
# verify dst tensors
|
||||
for dst_block in range(dst_size_in_gpu_blocks):
|
||||
src_block_candidate = dst_to_src.get(dst_block)
|
||||
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
|
||||
src_kv_caches,
|
||||
dst_kv_caches,
|
||||
handler.src_tensors,
|
||||
handler.dst_tensors,
|
||||
orig_dst_caches,
|
||||
handler.kv_dim_before_num_blocks,
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user