[Hybrid]: Decouple Kernel Block Size from KV Page Size (#24486)
Signed-off-by: lizhiyuan <uniartisan2017@gmail.com> Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
This commit is contained in:
@@ -68,6 +68,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
|
||||
pin_memory=runner.pin_memory,
|
||||
vocab_size=runner.model_config.get_vocab_size(),
|
||||
block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
|
||||
kernel_block_sizes=[
|
||||
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
|
||||
],
|
||||
)
|
||||
runner.initialize_attn_backend(kv_cache_config)
|
||||
|
||||
@@ -817,42 +820,231 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
|
||||
|
||||
# assert we are using FlashInfer
|
||||
assert attn_shape[0] == num_blocks
|
||||
assert attn_shape[0] % num_blocks == 0
|
||||
block_split_ratio = attn_shape[0] // num_blocks
|
||||
|
||||
# use small blocks for testing to avoid memory issues
|
||||
test_block_size = min(2, len(blocks0), len(blocks1))
|
||||
|
||||
# use non-overlapping blocks to avoid data contamination
|
||||
# Split kernel blocks: first half for attention, second half for mamba
|
||||
mid_point = num_blocks // 2
|
||||
|
||||
# attention uses kernel blocks from first half (mapped to logical blocks)
|
||||
kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
|
||||
|
||||
# mamba uses kernel blocks from second half
|
||||
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
|
||||
|
||||
# create small constant tensors for testing with corrected shapes
|
||||
# attention: [block_size, ...] starting from dimension 2
|
||||
attn_constant_shape = attn_shape[2:]
|
||||
conv_constant_shape = conv_shape[1:]
|
||||
ssm_constant_shape = ssm_shape[1:]
|
||||
|
||||
attn_blocks_constant = torch.full(
|
||||
(len(blocks0), *attn_shape[1:]), device=DEVICE, fill_value=3.33
|
||||
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
|
||||
)
|
||||
conv_blocks_constant = torch.full(
|
||||
(len(blocks1), *conv_shape[1:]), device=DEVICE, fill_value=6.66
|
||||
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
|
||||
)
|
||||
ssm_blocks_constant = torch.full(
|
||||
(len(blocks1), *ssm_shape[1:]), device=DEVICE, fill_value=9.99
|
||||
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
|
||||
)
|
||||
|
||||
# fill all attention blocks with constant
|
||||
for layer in [layer_0, layer_1]:
|
||||
vllm_ctx[layer].kv_cache[0][blocks0, :] = (
|
||||
attn_blocks_constant.detach().clone()
|
||||
)
|
||||
# Fill attention blocks with constants using kv block indices
|
||||
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
|
||||
|
||||
# fill all mamba blocks with constant
|
||||
for layer in [layer_0, layer_1]:
|
||||
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
|
||||
for i, kernel_block in enumerate(kernel_blocks_for_attention):
|
||||
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
|
||||
|
||||
# fill mamba blocks with constants using kernel block indices
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
vllm_ctx[layer].kv_cache[0][0][blocks1, :] = (
|
||||
conv_blocks_constant.detach().clone()
|
||||
)
|
||||
vllm_ctx[layer].kv_cache[0][1][blocks1, :] = (
|
||||
ssm_blocks_constant.detach().clone()
|
||||
)
|
||||
# mamba: kv_cache[0][component][kernel_block_idx, ...]
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
|
||||
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
|
||||
|
||||
# verify attention and mamba contents are correct
|
||||
for layer in [layer_0, layer_1]:
|
||||
assert torch.equal(
|
||||
vllm_ctx[layer].kv_cache[0][blocks0, :], attn_blocks_constant
|
||||
)
|
||||
for i, kernel_block in enumerate(kernel_blocks_for_attention):
|
||||
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
|
||||
expected = attn_blocks_constant[i]
|
||||
|
||||
# Check K and V separately
|
||||
assert torch.equal(actual_kv[0], expected)
|
||||
assert torch.equal(actual_kv[1], expected)
|
||||
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
assert torch.equal(
|
||||
vllm_ctx[layer].kv_cache[0][0][blocks1, :], conv_blocks_constant
|
||||
)
|
||||
assert torch.equal(
|
||||
vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant
|
||||
)
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
|
||||
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
|
||||
expected_conv = conv_blocks_constant[i]
|
||||
expected_ssm = ssm_blocks_constant[i]
|
||||
|
||||
assert torch.equal(actual_conv, expected_conv)
|
||||
assert torch.equal(actual_ssm, expected_ssm)
|
||||
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
for i, kv_block in enumerate(kv_blocks_for_mamba):
|
||||
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
|
||||
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
|
||||
expected_conv = conv_blocks_constant[i]
|
||||
expected_ssm = ssm_blocks_constant[i]
|
||||
assert torch.equal(actual_conv, expected_conv)
|
||||
assert torch.equal(actual_ssm, expected_ssm)
|
||||
|
||||
|
||||
def test_hybrid_block_table_initialization():
    """Test hybrid block table with different kernel and kvcache_manager block
    sizes."""
    from vllm.v1.worker.block_table import BlockTable

    # Configuration under test: the kvcache_manager hands out 32-token
    # blocks while the attention kernel operates on 16-token blocks.
    manager_block_size = 32
    kernel_block_sizes = [16]

    table = BlockTable(
        block_size=manager_block_size,
        max_num_reqs=10,
        max_num_blocks_per_req=20,
        max_num_batched_tokens=512,
        pin_memory=False,
        device=torch.device(DEVICE),
        kernel_block_size=kernel_block_sizes[0],
    )

    # Mismatched sizes must flip the table into hybrid mode, with the
    # effective block size taken from the kernel side and the split ratio
    # derived from the first kernel entry.
    assert table.use_hybrid_blocks is True
    assert table.block_size == kernel_block_sizes[0]
    assert table.blocks_per_kv_block == manager_block_size // kernel_block_sizes[0]

    # One kvcache_manager block should expand into multiple kernel blocks.
    # Append a row, then compare the stored row against the mapping the
    # implementation itself produces for those manager blocks.
    manager_blocks = [0, 1, 2]
    row = 0
    table.append_row(manager_blocks, row)
    expected_kernel_blocks = table._map_to_kernel_blocks(np.array(manager_blocks))

    assert table.num_blocks_per_row[row] == len(expected_kernel_blocks)
    assert np.array_equal(
        table.block_table.np[row, : len(expected_kernel_blocks)],
        expected_kernel_blocks,
    )
|
||||
|
||||
|
||||
def test_input_batch_with_kernel_block_sizes():
    """Test InputBatch initialization with kernel_block_sizes parameter."""
    device = torch.device(DEVICE)

    # Two KV-cache groups whose manager block sizes differ from their
    # kernel block sizes, so each group should run in hybrid mode.
    block_sizes = [32, 64]
    kernel_block_sizes = [16, 32]

    batch = InputBatch(
        max_num_reqs=10,
        max_model_len=512,
        max_num_batched_tokens=512,
        device=device,
        pin_memory=False,
        vocab_size=50272,
        block_sizes=block_sizes,
        kernel_block_sizes=kernel_block_sizes,
    )

    # One block table must exist per KV-cache group.
    assert len(batch.block_table.block_tables) == len(block_sizes)

    for table, (kv_size, kernel_size) in zip(
        batch.block_table.block_tables, zip(block_sizes, kernel_block_sizes)
    ):
        # Hybrid mode engages exactly when the two sizes differ; either
        # way the table's effective block size is the kernel one.
        assert table.use_hybrid_blocks is (kv_size != kernel_size)
        assert table.block_size == kernel_size
|
||||
|
||||
|
||||
def test_hybrid_cache_integration(model_runner, dist_init):
    """Test hybrid cache architecture integration with GPUModelRunner."""
    # Build a fresh config whose kvcache_manager block size (32) differs
    # from the kernel block size (16) used by the attention backend.
    vllm_config = get_vllm_config()
    vllm_config.cache_config.block_size = 32

    model_config = vllm_config.model_config
    vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
        model_config.get_num_kv_heads(vllm_config.parallel_config),
        model_config.get_head_size(),
        0.1,
    )

    runner = GPUModelRunner(vllm_config, DEVICE)

    # The KV-cache spec is declared with the kernel block size directly.
    attn_spec = FullAttentionSpec(
        block_size=16,
        num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config),
        head_size=runner.model_config.get_head_size(),
        dtype=runner.kv_cache_dtype,
    )
    tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
    kv_cache_config = KVCacheConfig(
        num_blocks=NUM_BLOCKS,
        kv_cache_tensors=[
            KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
        ],
        kv_cache_groups=[
            KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
        ],
    )
    runner.kv_cache_config = kv_cache_config

    # Rebuild the input batch so its block tables are created with the
    # kernel block size rather than the manager block size.
    runner.input_batch = InputBatch(
        max_num_reqs=runner.max_num_reqs,
        max_model_len=runner.max_model_len,
        max_num_batched_tokens=runner.max_num_tokens,
        device=runner.device,
        pin_memory=runner.pin_memory,
        vocab_size=runner.model_config.get_vocab_size(),
        block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
        kernel_block_sizes=[16],
    )

    runner.initialize_attn_backend(kv_cache_config)

    # The single block table should report the spec's block size.
    table = runner.input_batch.block_table.block_tables[0]
    assert table.block_size == (
        kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
    )

    # Scheduling a request and updating runner state must work with
    # hybrid blocks end to end.
    req_id = "hybrid_req_0"
    scheduler_output = _schedule_new_request(req_id)

    runner._update_states(scheduler_output)
    assert _is_req_scheduled(runner, req_id)
    assert _is_req_state_block_table_match(runner, req_id)
|
||||
|
||||
Reference in New Issue
Block a user