[Hybrid]: Decouple Kernel Block Size from KV Page Size (#24486)

Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
Zhiyuan Li
2025-10-09 14:43:39 +08:00
committed by GitHub
parent d17f0fbf30
commit d24cf322e1
18 changed files with 573 additions and 55 deletions
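
The change lets the KV cache manager keep allocating pages at its own block size while attention kernels index the same memory at a (possibly smaller) kernel block size, so one manager block maps onto a contiguous run of kernel blocks. A minimal sketch of the invariant with the sizes used in the tests below (illustrative arithmetic only, not vLLM internals):

manager_block_size = 32   # cache_config.block_size / kv_cache_spec.block_size
kernel_block_size = 16    # block size the attention kernel actually operates on
assert manager_block_size % kernel_block_size == 0
blocks_per_kv_block = manager_block_size // kernel_block_size  # 2
# a pool of num_blocks manager pages is therefore addressed by the kernel as
# num_blocks * blocks_per_kv_block kernel blocks, which is why the test below
# asserts attn_shape[0] % num_blocks == 0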

@@ -68,6 +68,9 @@ def initialize_kv_cache(runner: GPUModelRunner):
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
kernel_block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
)
runner.initialize_attn_backend(kv_cache_config)
@@ -817,42 +820,231 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
# assert we are using FlashInfer
assert attn_shape[0] % num_blocks == 0
block_split_ratio = attn_shape[0] // num_blocks
# use small blocks for testing to avoid memory issues
test_block_size = min(2, len(blocks0), len(blocks1))
# use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba
mid_point = num_blocks // 2
# attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
# mamba uses kernel blocks from second half
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
# create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2
attn_constant_shape = attn_shape[2:]
conv_constant_shape = conv_shape[1:]
ssm_constant_shape = ssm_shape[1:]
attn_blocks_constant = torch.full(
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
)
conv_blocks_constant = torch.full(
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
)
ssm_blocks_constant = torch.full(
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
)
# Fill attention blocks with constants using kv block indices
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
for layer in [layer_0, layer_1]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
# fill mamba blocks with constants using kernel block indices
for layer in [layer_2, layer_3, layer_4, layer_5]:
# mamba: kv_cache[0][component][kernel_block_idx, ...]
for i, kv_block in enumerate(kv_blocks_for_mamba):
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
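# Worked example of the index mapping above (assuming block_split_ratio == 2):
# kv_blocks_for_attention == [0, 1] maps to kernel_blocks_for_attention == [0, 2],
# i.e. the first kernel block backing each KV-manager block; the remaining
# kernel blocks of a manager block follow it contiguously.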
# verify attention and mamba contents are correct
for layer in [layer_0, layer_1]:
for i, kernel_block in enumerate(kernel_blocks_for_attention):
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
expected = attn_blocks_constant[i]
# Check K and V separately
assert torch.equal(actual_kv[0], expected)
assert torch.equal(actual_kv[1], expected)
for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm)
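
def _demo_kv_broadcast_fill():
    """Illustrative sketch (not part of the original tests): the fill above
    assigns a [block_size, ...]-shaped constant into a [2, block_size, ...]
    cache slot, so the values broadcast over the leading K/V axis and both
    planes end up identical, which is why the verification checks
    actual_kv[0] and actual_kv[1] separately."""
    kv_slot = torch.zeros(2, 4, 3)         # (k/v, block_size, head_dim) -- toy shapes
    constant = torch.full((4, 3), 3.33)
    kv_slot[:] = constant                  # mirrors kv_cache[0][kernel_block, :] = ...
    assert torch.equal(kv_slot[0], constant)
    assert torch.equal(kv_slot[1], constant)
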
def test_hybrid_block_table_initialization():
"""Test hybrid block table with different kernel and kvcache_manager block
sizes."""
from vllm.v1.worker.block_table import BlockTable
# Test configuration: kvcache_manager block size = 32,
# kernel block size = 16
block_size = 32
kernel_block_sizes = [16]
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512
block_table = BlockTable(
block_size=block_size,
max_num_reqs=max_num_reqs,
max_num_blocks_per_req=max_num_blocks_per_req,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0],
)
# Verify hybrid block configuration
assert block_table.use_hybrid_blocks is True
assert block_table.block_size == kernel_block_sizes[0]
assert block_table.blocks_per_kv_block == (
block_size // kernel_block_sizes[0]
)  # ratio of KV-manager block size to kernel block size
# Test block table conversion logic
# One kvcache_manager block should map to multiple kernel blocks
kvcache_manager_blocks = [0, 1, 2]
# Verify that kvcache_manager blocks can be converted to kernel blocks
# and that block table operations work correctly.
req_index = 0
block_table.append_row(kvcache_manager_blocks, req_index)
# Get expected kernel blocks from the implementation for verification.
expected_kernel_blocks = block_table._map_to_kernel_blocks(
np.array(kvcache_manager_blocks)
)
# Verify block table state
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)
assert np.array_equal(
block_table.block_table.np[req_index, : len(expected_kernel_blocks)],
expected_kernel_blocks,
)
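
def _demo_map_to_kernel_blocks():
    """Illustrative sketch (an assumption about the mapping, not the actual
    BlockTable._map_to_kernel_blocks code): each KV-manager block id expands
    into blocks_per_kv_block consecutive kernel block ids, so appending
    manager blocks [0, 1, 2] with a ratio of 2 yields kernel blocks
    [0, 1, 2, 3, 4, 5]."""
    def map_to_kernel_blocks(manager_blocks, blocks_per_kv_block):
        # manager block b covers kernel blocks [b * ratio, (b + 1) * ratio)
        offsets = np.arange(blocks_per_kv_block)
        return (manager_blocks[:, None] * blocks_per_kv_block + offsets).reshape(-1)

    assert map_to_kernel_blocks(np.array([0, 1, 2]), 2).tolist() == [0, 1, 2, 3, 4, 5]
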
def test_input_batch_with_kernel_block_sizes():
"""Test InputBatch initialization with kernel_block_sizes parameter."""
max_num_reqs = 10
max_model_len = 512
max_num_batched_tokens = 512
device = torch.device(DEVICE)
pin_memory = False
vocab_size = 50272
# Test with different kernel block sizes
block_sizes = [32, 64]
kernel_block_sizes = [16, 32]
input_batch = InputBatch(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
device=device,
pin_memory=pin_memory,
vocab_size=vocab_size,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
)
# Verify that block tables were created with kernel block sizes
assert len(input_batch.block_table.block_tables) == len(block_sizes)
for i, (kv_size, kernel_size) in enumerate(zip(block_sizes, kernel_block_sizes)):
block_table = input_batch.block_table.block_tables[i]
if kv_size != kernel_size:
assert block_table.use_hybrid_blocks is True
assert block_table.block_size == kernel_size
else:
assert block_table.use_hybrid_blocks is False
assert block_table.block_size == kernel_size
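
def _demo_hybrid_block_pairing():
    """Illustrative sketch (not the InputBatch implementation): block_sizes and
    kernel_block_sizes are paired per KV cache group; hybrid mode is only
    engaged for a group when the two sizes differ, and the manager size must
    be an integer multiple of the kernel size."""
    for kv_size, kernel_size in zip([32, 64], [16, 32]):
        assert kv_size % kernel_size == 0
        use_hybrid = kv_size != kernel_size
        assert use_hybrid and kv_size // kernel_size == 2
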
def test_hybrid_cache_integration(model_runner, dist_init):
"""Test hybrid cache architecture integration with GPUModelRunner."""
# Create a new model runner with hybrid cache configuration
vllm_config = get_vllm_config()
# Configure hybrid cache with different kvcache_manager block size
vllm_config.cache_config.block_size = 32
model_config = vllm_config.model_config
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
head_size = model_config.get_head_size()
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
num_heads, head_size, 0.1
)
runner = GPUModelRunner(vllm_config, DEVICE)
# Initialize KV cache with configuration
attn_spec = FullAttentionSpec(
block_size=16, # Use kernel block size directly
num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
)
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
kv_cache_config = KVCacheConfig(
num_blocks=NUM_BLOCKS,
kv_cache_tensors=[
KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
],
kv_cache_groups=[
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
],
)
runner.kv_cache_config = kv_cache_config
# Initialize input batch with kernel block sizes
runner.input_batch = InputBatch(
max_num_reqs=runner.max_num_reqs,
max_model_len=runner.max_model_len,
max_num_batched_tokens=runner.max_num_tokens,
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
kernel_block_sizes=[16],  # use the kernel block size directly
)
runner.initialize_attn_backend(kv_cache_config)
# Verify hybrid block table configuration
block_table = runner.input_batch.block_table.block_tables[0]
assert block_table.block_size == (
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
)
# Test request processing with hybrid blocks
req_id = "hybrid_req_0"
scheduler_output = _schedule_new_request(req_id)
# Update states should work with hybrid blocks
runner._update_states(scheduler_output)
assert _is_req_scheduled(runner, req_id)
assert _is_req_state_block_table_match(runner, req_id)
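
def _demo_page_size_math():
    """Back-of-envelope sketch (the formula is an assumption about
    FullAttentionSpec.page_size_bytes, and the head count / dtype width are
    illustrative, not read from the config): for a paged K/V layout the page
    size is roughly 2 * block_size * num_kv_heads * head_size * dtype_bytes,
    and the KV tensor above is sized as page_size_bytes * NUM_BLOCKS."""
    block_size, num_kv_heads, head_size, dtype_bytes = 16, 8, 128, 2
    page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_bytes
    num_blocks = 100  # placeholder; the test uses its own NUM_BLOCKS constant
    assert page_size_bytes == 65536
    assert page_size_bytes * num_blocks == 6_553_600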