[Core] Cross-attention KV caching and memory-management (towards eventual encoder/decoder model support) (#4837)
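This commit extends the BlockSpaceManagerV1 test suite to exercise encoder/decoder sequence groups: allocating cross-attention KV-cache blocks alongside decoder blocks, swapping both block tables between GPU and CPU, freeing and resetting them, and raising NotImplementedError when encoder/decoder support is combined with sliding-window attention or prefix caching. Existing assertions are also tightened to compare can_allocate() against AllocStatus.OK explicitly.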
@@ -6,13 +6,15 @@ import pytest

 from vllm import SamplingParams
 from vllm.block import PhysicalTokenBlock
+from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
+                                   STR_NOT_IMPL_ENC_DEC_SWA)
 from vllm.core.block_manager_v1 import (BlockSpaceManagerV1,
                                         UncachedBlockAllocator)
 from vllm.core.interfaces import AllocStatus
 from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus
 from vllm.utils import Device

-from .utils import create_dummy_prompt
+from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder


 def test_block_allocator_allocate():
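The tests added below rely on a new helper, create_dummy_prompt_encoder_decoder, imported above next to the existing create_dummy_prompt. Its implementation is not part of this diff; the following is a minimal sketch of what it plausibly looks like, assuming the Sequence(seq_id, prompt, prompt_token_ids, block_size) constructor of this vLLM era and a SequenceGroup that tracks the encoder prompt via an encoder_seq argument. All names and signatures here beyond those visible in the diff are assumptions.

import time
from typing import Optional, Tuple

from vllm import SamplingParams
from vllm.sequence import Sequence, SequenceGroup


# Hypothetical sketch of the test helper -- not the commit's actual code.
def create_dummy_prompt_encoder_decoder(
        request_id: str,
        decoder_prompt_length: int,
        encoder_prompt_length: int,
        block_size: Optional[int] = None,
) -> Tuple[Sequence, Sequence, SequenceGroup]:
    block_size = block_size or decoder_prompt_length

    # Dummy decoder prompt: token ids 0..n-1, prompt text "0 1 ... n-1".
    decoder_tokens = list(range(decoder_prompt_length))
    decoder_prompt = Sequence(int(request_id),
                              " ".join(map(str, decoder_tokens)),
                              decoder_tokens, block_size)

    # Dummy encoder prompt, built the same way.
    encoder_tokens = list(range(encoder_prompt_length))
    encoder_prompt = Sequence(int(request_id),
                              " ".join(map(str, encoder_tokens)),
                              encoder_tokens, block_size)

    # Assumption: the group carries the encoder prompt via `encoder_seq`
    # so the block manager can size cross-attention KV blocks from it.
    seq_group = SequenceGroup(request_id=request_id,
                              seqs=[decoder_prompt],
                              sampling_params=SamplingParams(),
                              arrival_time=time.time(),
                              encoder_seq=encoder_prompt)
    return decoder_prompt, encoder_prompt, seq_group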
@@ -73,7 +75,7 @@ def test_allocate():
     # Allocate same sequence group to all available gpu blocks.
     for i in range(num_gpu_blocks):
         _, seq_group = create_dummy_prompt(str(i), block_size)
-        assert block_manager.can_allocate(seq_group)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
         block_manager.allocate(seq_group)
     assert block_manager.can_allocate(seq_group) != AllocStatus.OK

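The old assertion here was misleadingly weak: can_allocate() returns an AllocStatus enum, and every enum member is truthy, so the bare assert could never fail even for a group that cannot be allocated. Comparing against AllocStatus.OK makes the check strict. For reference, AllocStatus in vllm/core/interfaces.py is an enum along these lines:

import enum


class AllocStatus(enum.Enum):
    # OK: the sequence group can be allocated now.
    OK = enum.auto()
    # LATER: not now, but possible once other blocks are freed.
    LATER = enum.auto()
    # NEVER: the group can never fit in GPU KV-cache memory.
    NEVER = enum.auto()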
@@ -85,11 +87,107 @@ def test_allocate():
                                         watermark=1 / num_gpu_blocks)
     for i in range(num_gpu_blocks - 1):
         _, seq_group = create_dummy_prompt(str(i), block_size)
-        assert block_manager.can_allocate(seq_group)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
         block_manager.allocate(seq_group)
     assert block_manager.can_allocate(seq_group) != AllocStatus.OK


+def test_allocate_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_req_per_seq_group = 2
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    # Allocate same sequence group to all available gpu blocks.
+    for i in range(num_gpu_blocks // block_req_per_seq_group):
+        _, _, seq_group = create_dummy_prompt_encoder_decoder(
+            str(i),
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
+        block_manager.allocate(seq_group)
+    assert block_manager.can_allocate(seq_group) != AllocStatus.OK
+
+    # Allocate same sequence group to all available gpu blocks.
+    # Use watermark to reserve one gpu block.
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=1 / num_gpu_blocks)
+    for i in range((num_gpu_blocks - 1) // block_req_per_seq_group):
+        _, _, seq_group = create_dummy_prompt_encoder_decoder(
+            str(i),
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+        assert block_manager.can_allocate(seq_group) == AllocStatus.OK
+        block_manager.allocate(seq_group)
+    assert block_manager.can_allocate(seq_group) != AllocStatus.OK
+
+
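test_allocate_encoder_decoder mirrors test_allocate with one twist: an encoder/decoder group needs KV-cache blocks for the decoder sequence plus a separate cross-attention block table sized by the encoder prompt, hence block_req_per_seq_group = 2. With block_size = 4 and 4-token prompts, each group costs 1 decoder block + 1 cross block, so the 4-block GPU pool holds exactly 2 groups (or 1 once the watermark reserves a block). A rough sketch of the accounting can_allocate() has to perform, with attribute and helper names that are illustrative assumptions rather than the actual vLLM implementation:

from vllm.core.interfaces import AllocStatus


# Hypothetical sketch of encoder/decoder-aware allocation accounting.
def can_allocate_sketch(block_manager, seq_group):
    def blocks_for(num_tokens):
        # Ceiling division: a partially filled block still occupies a slot.
        return -(-num_tokens // block_manager.block_size)

    # Blocks for the decoder sequence's self-attention KV cache.
    num_required = blocks_for(seq_group.get_seqs()[0].get_len())
    if seq_group.is_encoder_decoder():
        # Plus cross-attention KV blocks, sized by the encoder prompt and
        # allocated once per sequence group.
        num_required += blocks_for(seq_group.get_encoder_seq().get_len())

    num_free = block_manager.gpu_allocator.get_num_free_blocks()
    if num_required > block_manager.num_total_gpu_blocks:
        return AllocStatus.NEVER
    if num_free - num_required >= block_manager.watermark_blocks:
        return AllocStatus.OK
    return AllocStatus.LATER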
+def test_allocate_encoder_decoder_fails_with_swa():
+    # SWA short for sliding window attention
+
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0,
+                                        sliding_window=5)  # swa
+
+    # Allocate same sequence group to all available gpu blocks.
+    _, _, seq_group = create_dummy_prompt_encoder_decoder(
+        "0",
+        decoder_prompt_length=block_size,
+        encoder_prompt_length=block_size)
+
+    # Assert that can_allocate() fails due to SWA
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.can_allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
+
+    # Assert that allocate() fails due to SWA
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA
+
+
+def test_allocate_encoder_decoder_fails_with_prefix_caching():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0,
+                                        enable_caching=True)  # Prefix cache
+
+    # Allocate same sequence group to all available gpu blocks.
+    _, _, seq_group = create_dummy_prompt_encoder_decoder(
+        "0",
+        decoder_prompt_length=block_size,
+        encoder_prompt_length=block_size)
+
+    # Assert that can_allocate() fails due to prefix caching
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.can_allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
+
+    # Assert that allocate() fails due to prefix caching
+    with pytest.raises(NotImplementedError) as exc_info:
+        block_manager.allocate(seq_group)
+
+    assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE
+
+
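Both failure tests assert the exact message constants imported from vllm.core.block.utils, so the guard and the tests cannot drift apart. The guard itself is not shown in this hunk; presumably it lives next to the constants and is invoked from both can_allocate() and allocate(). A hedged sketch of such a check, where the helper name and attribute lookups are assumptions and only the constant names come from the diff:

from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
                                   STR_NOT_IMPL_ENC_DEC_SWA)


# Illustrative guard; name and body are assumptions, not the commit's code.
def _check_enc_dec_unsupported_features(block_mgr, seq_group):
    # Only encoder/decoder sequence groups trip these guards.
    if not seq_group.is_encoder_decoder():
        return
    if getattr(block_mgr, "block_sliding_window", None) is not None:
        # Sliding window attention + cross-attention KV caching: unsupported.
        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA)
    if getattr(block_mgr, "enable_caching", False):
        # Prefix caching of encoder/decoder KV blocks: unsupported.
        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE)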
 def test_append_slot_single_seq():
     block_size = 4
     num_cpu_blocks = 4
@@ -244,6 +342,62 @@ def test_swap():
     assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)


+def test_swap_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    decoder_prompt, encoder_prompt, seq_group = \
+        create_dummy_prompt_encoder_decoder(
+            "1",
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+    decoder_prompt.status = SequenceStatus.WAITING
+    encoder_prompt.status = SequenceStatus.WAITING
+    block_manager.allocate(seq_group)
+
+    # Emulate a forward pass by appending a single token.
+    # The block manager then knows how many unprocessed
+    # tokens will be written in the next forward pass.
+    token_id = 0
+    decoder_prompt.status = SequenceStatus.RUNNING
+    decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
+
+    # Swap encoder/decoder seq group from GPU -> CPU.
+    decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt)
+    cross_gpu_blocks = block_manager.get_cross_block_table(seq_group)
+    gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks
+    assert block_manager.can_swap_out(seq_group)
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_out(seq_group)
+    assert [x[0] for x in mapping] == gpu_blocks
+    #assert list(mapping.keys()) == gpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
+    assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
+    decoder_prompt.status = SequenceStatus.SWAPPED
+
+    # Swap encoder/decoder seq group from CPU -> GPU.
+    decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt)
+    cross_cpu_blocks = block_manager.get_cross_block_table(seq_group)
+    cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks
+    assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_in(seq_group)
+    assert [x[0] for x in mapping] == cpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
+    assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
+
+
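Two details in test_swap_encoder_decoder are worth noting. First, the expected GPU block list is the concatenation of the decoder block table and the group's cross-attention block table, so a single swap_out() call must migrate both. Second, the assertion on mapping (together with the commented-out predecessor it replaces) suggests the mapping format changed from a dict to a list of (source_block, destination_block) pairs. A minimal illustration of consuming that format, reusing the test's own names:

# Capture the group's GPU blocks *before* swapping; swap_out() rewrites
# the block tables to point at CPU blocks.
gpu_blocks = (block_manager.get_block_table(decoder_prompt) +
              block_manager.get_cross_block_table(seq_group))
mapping = block_manager.swap_out(seq_group)
# The mapping is a list of (source_gpu_block, destination_cpu_block)
# pairs, covering decoder and cross-attention blocks alike.
assert [src for src, _dst in mapping] == gpu_blocks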
 def test_free():
     block_size = 4
     num_cpu_blocks = 4
@@ -268,6 +422,41 @@ def test_free():
         block_manager.get_block_table(prompt)


+def test_free_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    decoder_prompt, encoder_prompt, seq_group = \
+        create_dummy_prompt_encoder_decoder(
+            "1",
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+    block_manager.allocate(seq_group)
+
+    # Free allocated seq.
+    decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt))
+    encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group))
+    prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks
+    before_blocks = block_manager.get_num_free_gpu_blocks()
+    block_manager.free(decoder_prompt)
+    block_manager.free_cross(seq_group)
+    after_blocks = block_manager.get_num_free_gpu_blocks()
+    assert after_blocks == before_blocks + prompt_blocks
+
+    # Block table for freed decoder seq is deleted.
+    with pytest.raises(KeyError):
+        block_manager.get_block_table(decoder_prompt)
+
+    # Block table for freed encoder seq is deleted.
+    with pytest.raises(KeyError):
+        block_manager.get_block_table(encoder_prompt)
+
+
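free() releases the decoder sequence's block table, while the new free_cross() releases the cross-attention table, which belongs to the sequence group rather than to any one sequence (free_cross and get_cross_block_table both take the group). A hedged sketch of the split, assuming per-seq-id and per-request-id table dicts on the manager; every attribute name here is an assumption:

# Hypothetical sketch of free()/free_cross(); attribute names are assumed.
def free_sketch(self, seq):
    if seq.seq_id not in self.block_tables:
        return  # Already freed, or never allocated.
    self._free_block_table(self.block_tables[seq.seq_id])
    del self.block_tables[seq.seq_id]


def free_cross_sketch(self, seq_group):
    request_id = seq_group.request_id
    if request_id not in self.cross_block_tables:
        return  # Already freed, or never allocated.
    self._free_block_table(self.cross_block_tables[request_id])
    del self.cross_block_tables[request_id]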
 def test_reset():
     block_size = 4
     num_cpu_blocks = 4
@@ -289,6 +478,31 @@ def test_reset():
     assert block_manager.get_num_free_gpu_blocks() == original_blocks


+def test_reset_encoder_decoder():
+    block_size = 4
+    num_cpu_blocks = 4
+    num_gpu_blocks = 4
+    block_req_per_seq_group = 2
+    block_manager = BlockSpaceManagerV1(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0)
+
+    # Allocate same seq group on all available gpu blocks.
+    original_blocks = block_manager.get_num_free_gpu_blocks()
+    for i in range(num_gpu_blocks // block_req_per_seq_group):
+        _, _, seq_group = create_dummy_prompt_encoder_decoder(
+            f"{i}",
+            decoder_prompt_length=block_size,
+            encoder_prompt_length=block_size)
+        block_manager.allocate(seq_group)
+    assert block_manager.get_num_free_gpu_blocks() == 0
+
+    # Resetting block manager frees all allocated blocks.
+    block_manager.reset()
+    assert block_manager.get_num_free_gpu_blocks() == original_blocks
+
+
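test_reset_encoder_decoder confirms that reset() returns every allocated block to the free pool, cross-attention tables included: at 2 blocks per group, two groups exhaust the 4-block pool before the reset. An encoder/decoder-aware reset() plausibly just clears both table dicts; a hedged sketch under the same assumed attribute names as above:

# Hypothetical sketch of an encoder/decoder-aware reset().
def reset_sketch(self):
    # Free per-sequence (decoder) block tables.
    for block_table in self.block_tables.values():
        self._free_block_table(block_table)
    self.block_tables.clear()
    # Free per-group cross-attention block tables as well.
    for block_table in self.cross_block_tables.values():
        self._free_block_table(block_table)
    self.cross_block_tables.clear()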
 def test_sliding_window_multi_seq():
     """
     Tests that memory allocation and deallocation is handled