[Hybrid Allocator] Support Pipeline Parallel (#23974)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-09-14 15:55:17 -07:00
committed by GitHub
parent 90f3f7d73e
commit 8e5cdcda4e
7 changed files with 472 additions and 235 deletions

View File

@@ -18,13 +18,12 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.kv_cache_utils import (
BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
get_kv_cache_configs, get_max_concurrency_for_kv_cache_config,
get_request_block_hasher, hash_block_tokens, init_none_hash,
is_kv_cache_type_uniform, make_block_hash_with_group_id,
unify_kv_cache_configs)
is_kv_cache_type_uniform, make_block_hash_with_group_id)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
@@ -531,102 +530,288 @@ def test_metrics():
assert not metrics.query_queue
def test_unify_kv_cache_configs():
same_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
]
unify_kv_cache_configs(same_kv_cache_config)
assert same_kv_cache_config[0].num_blocks == 10
assert same_kv_cache_config[1].num_blocks == 10
def test_get_kv_cache_configs_multiple_workers():
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config)
need_sort_kv_cache_config = [
ref_kv_cache_spec = new_kv_cache_spec()
same_kv_cache_specs = [{
"layer1": new_kv_cache_spec(),
"layer2": new_kv_cache_spec(),
}, {
"layer1": new_kv_cache_spec(),
"layer2": new_kv_cache_spec(),
}]
# Basic case. All things are the same.
kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
KVCacheConfig(
num_blocks=20,
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
]
unify_kv_cache_configs(need_sort_kv_cache_config)
sorted_kv_cache_groups = [
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"], new_kv_cache_spec(num_kv_heads=4)),
]
assert (
need_sort_kv_cache_config[0].kv_cache_groups == sorted_kv_cache_groups)
assert (
need_sort_kv_cache_config[1].kv_cache_groups == sorted_kv_cache_groups)
diff_kv_cache_config = [
# Different available memory. This is the case for TP.
# Use the smallest memory available.
kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 20
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
KVCacheConfig(
num_blocks=20,
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=8)),
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
]
# Different KV cache specs. This is the case for PP.
different_layer_specs = [{
"layer1": new_kv_cache_spec(),
}, {
"layer2": new_kv_cache_spec(),
"layer3": new_kv_cache_spec(),
}]
# Different workers have different layers.
kv_cache_configs = get_kv_cache_configs(
vllm_config, different_layer_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
shared_by=["layer1"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer3"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer2", "layer3"], new_kv_cache_spec()),
],
),
]
# Some layers are the same, some are different. This is the case for TP+PP
tp_pp_kv_cache_specs = [{
"layer1": new_kv_cache_spec(),
"layer2": new_kv_cache_spec(),
}, {
"layer1": new_kv_cache_spec(),
"layer2": new_kv_cache_spec(),
}, {
"layer3": new_kv_cache_spec(),
}, {
"layer3": new_kv_cache_spec(),
}]
kv_cache_configs = get_kv_cache_configs(
vllm_config, tp_pp_kv_cache_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10,
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
shared_by=["layer3"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer3"], ref_kv_cache_spec),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
shared_by=["layer3"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer3"], ref_kv_cache_spec),
],
),
]
# Different workers have different types of layers. This is the case for
# hybrid models + PP.
different_type_layer_specs = [{
"layer1": new_kv_cache_spec(),
"layer2": new_kv_cache_spec(),
}, {
"layer3": new_sliding_window_spec(),
"layer4": new_sliding_window_spec(),
}]
kv_cache_configs = get_kv_cache_configs(
vllm_config, different_type_layer_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10,
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
KVCacheGroupSpec([], new_sliding_window_spec()),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer3"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer4"]),
],
kv_cache_groups=[
KVCacheGroupSpec([], ref_kv_cache_spec),
KVCacheGroupSpec(["layer3", "layer4"],
new_sliding_window_spec()),
],
),
]
# When divided into multiple KVCacheGroups, need to ensure the number of
# layers per group is similar.
different_type_layer_specs = [{
"layer1": new_kv_cache_spec(),
"layer2": new_sliding_window_spec(),
"layer3": new_sliding_window_spec(),
}, {
"layer4": new_kv_cache_spec(),
"layer5": new_sliding_window_spec(),
"layer6": new_sliding_window_spec(),
}]
kv_cache_configs = get_kv_cache_configs(
vllm_config, different_type_layer_specs, [
ref_kv_cache_spec.page_size_bytes * 10,
ref_kv_cache_spec.page_size_bytes * 10,
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1", "layer2", "layer3"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], ref_kv_cache_spec),
KVCacheGroupSpec(["layer2"], new_sliding_window_spec()),
KVCacheGroupSpec(["layer3"], new_sliding_window_spec()),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer4", "layer5", "layer6"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer4"], ref_kv_cache_spec),
KVCacheGroupSpec(["layer5"], new_sliding_window_spec()),
KVCacheGroupSpec(["layer6"], new_sliding_window_spec()),
],
),
]
# Have conflicting layers. Need to raise an error.
conflicting_layer_specs = [{
"layer1": new_kv_cache_spec(),
}, {
"layer1": new_sliding_window_spec(),
}]
with pytest.raises(AssertionError):
unify_kv_cache_configs(diff_kv_cache_config)
get_kv_cache_configs(vllm_config, conflicting_layer_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10,
])
def test_merge_kv_cache_spec():
@@ -890,7 +1075,7 @@ def test_allocate_with_lookahead():
assert len(blocks.get_block_ids()[0]) == 2
def test_get_kv_cache_config():
def test_get_kv_cache_config_one_worker():
# pass max_model_len to pass check_enough_kv_cache_memory
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config)
@@ -901,8 +1086,10 @@ def test_get_kv_cache_config():
'layer_1': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(),
}
kv_cache_config_full = get_kv_cache_config(
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
kv_cache_config_full = get_kv_cache_configs(
vllm_config, [kv_cache_specs_full],
[mem_per_block_per_layer * 2 * 32])[0]
print(kv_cache_config_full)
assert kv_cache_config_full == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
@@ -920,8 +1107,9 @@ def test_get_kv_cache_config():
'layer_1': new_sliding_window_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_sliding = get_kv_cache_config(
vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32)
kv_cache_config_sliding = get_kv_cache_configs(
vllm_config, [kv_cache_specs_sliding],
[mem_per_block_per_layer * 2 * 32])[0]
assert kv_cache_config_sliding == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
@@ -940,8 +1128,9 @@ def test_get_kv_cache_config():
'layer_1': new_kv_cache_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 2 * 32])[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
@@ -962,8 +1151,9 @@ def test_get_kv_cache_config():
'layer_1': new_kv_cache_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 2 * 32])[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=64,
kv_cache_tensors=[
@@ -985,21 +1175,22 @@ def test_get_kv_cache_config():
'layer_5': new_sliding_window_spec(),
'layer_6': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 2 * 32])[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1", "layer_3", "layer_5"]),
shared_by=["layer_1", "layer_3", "layer_4"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2", "layer_4", "layer_6"]),
shared_by=["layer_2", "layer_5", "layer_6"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer_3", "layer_4"],
KVCacheGroupSpec(["layer_3", "layer_5"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_5", "layer_6"],
KVCacheGroupSpec(["layer_4", "layer_6"],
new_sliding_window_spec()),
],
)
@@ -1017,27 +1208,30 @@ def test_get_kv_cache_config():
'layer_9': new_sliding_window_spec(),
'layer_10': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32)
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 3 * 32])[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
size=mem_per_block_per_layer * 32,
shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]),
shared_by=["layer_1", "layer_4", "layer_5", "layer_6"]),
KVCacheTensor(
size=mem_per_block_per_layer * 32,
shared_by=["layer_2", "layer_7", "layer_8", "layer_9"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2", "layer_5", "layer_8"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_3", "layer_6", "layer_9"]),
shared_by=["layer_3", "layer_10"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"],
new_kv_cache_spec()),
KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"],
KVCacheGroupSpec(["layer_4", "layer_7", "layer_10"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"],
KVCacheGroupSpec(["layer_5", "layer_8"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_6", "layer_9"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()),
],
)
@@ -1047,13 +1241,14 @@ def test_get_kv_cache_config():
'layer_2': new_kv_cache_spec(),
}
with pytest.raises(NotImplementedError):
get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
mem_per_block_per_layer * 2 * 32)
get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 2 * 32])[0]
# Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16
kv_cache_config_override_blocks = get_kv_cache_config(
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
kv_cache_config_override_blocks = get_kv_cache_configs(
vllm_config, [kv_cache_specs_full],
[mem_per_block_per_layer * 2 * 32])[0]
assert kv_cache_config_override_blocks == KVCacheConfig(
num_blocks=16,
kv_cache_tensors=[
@@ -1065,3 +1260,16 @@ def test_get_kv_cache_config():
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
])
def test_get_kv_cache_configs_attention_free():
kv_cache_specs: dict[str, KVCacheSpec] = {}
vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16))
kv_cache_configs = get_kv_cache_configs(vllm_config, [kv_cache_specs], [0])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=1,
kv_cache_tensors=[],
kv_cache_groups=[],
)
]