[Core][BugFix] Fix PP KV cache sharding memory validation (#33698)

Signed-off-by: junuxyz <216036880+junuxyz@users.noreply.github.com>
This commit is contained in:
junuxyz
2026-02-11 00:46:24 +09:00
committed by GitHub
parent afdce12c89
commit c5a66d1697
2 changed files with 169 additions and 40 deletions

View File

@@ -1046,6 +1046,99 @@ def test_get_kv_cache_configs_multiple_workers():
)
@pytest.mark.parametrize(
"asymmetric_memory",
[False, True],
ids=["symmetric", "asymmetric"],
)
def test_get_kv_cache_configs_pp_sharding(asymmetric_memory):
model_config = ModelConfig(max_model_len=512)
vllm_config = VllmConfig(model_config=model_config)
ref_kv_cache_spec = new_kv_cache_spec()
pp_kv_cache_specs = [
{"layer1": ref_kv_cache_spec},
{"layer2": ref_kv_cache_spec},
]
expected_num_blocks = model_config.max_model_len // ref_kv_cache_spec.block_size + 1
avail_memory = ref_kv_cache_spec.page_size_bytes * expected_num_blocks
# With per-worker validation, each worker only needs memory for its own
# layers. Worker 2 having more memory shouldn't affect worker 1's config.
available_memory = (
[avail_memory, avail_memory * 2] if asymmetric_memory else [avail_memory] * 2
)
kv_cache_configs = get_kv_cache_configs(
vllm_config,
pp_kv_cache_specs,
available_memory,
)
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=expected_num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=ref_kv_cache_spec.page_size_bytes * expected_num_blocks,
shared_by=["layer1"],
),
],
kv_cache_groups=[KVCacheGroupSpec(["layer1"], ref_kv_cache_spec)],
),
KVCacheConfig(
num_blocks=expected_num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=ref_kv_cache_spec.page_size_bytes * expected_num_blocks,
shared_by=["layer2"],
),
],
kv_cache_groups=[KVCacheGroupSpec(["layer2"], ref_kv_cache_spec)],
),
]
def test_project_kv_cache_groups_to_worker():
spec_a = new_kv_cache_spec()
spec_b = new_kv_cache_spec(num_kv_heads=4)
global_groups = [
KVCacheGroupSpec(["layer1", "layer2", "layer3"], spec_a),
]
worker_spec = {"layer1": spec_a, "layer2": spec_a}
projected = kv_cache_utils._project_kv_cache_groups_to_worker(
global_groups, worker_spec
)
assert len(projected) == 1
assert projected[0].layer_names == ["layer1", "layer2"]
assert projected[0].kv_cache_spec is spec_a
projected = kv_cache_utils._project_kv_cache_groups_to_worker(
global_groups, {"layer4": spec_a}
)
assert len(projected) == 1
assert projected[0].layer_names == []
assert projected[0].kv_cache_spec is spec_a
uniform_spec = UniformTypeKVCacheSpecs(
block_size=16,
kv_cache_specs={"layer1": spec_a, "layer2": spec_b, "layer3": spec_a},
)
global_groups_uniform = [
KVCacheGroupSpec(["layer1", "layer2", "layer3"], uniform_spec),
]
projected = kv_cache_utils._project_kv_cache_groups_to_worker(
global_groups_uniform, {"layer1": spec_a, "layer3": spec_a}
)
assert len(projected) == 1
assert projected[0].layer_names == ["layer1", "layer3"]
proj_spec = projected[0].kv_cache_spec
assert isinstance(proj_spec, UniformTypeKVCacheSpecs)
assert set(proj_spec.kv_cache_specs.keys()) == {"layer1", "layer3"}
def test_merge_kv_cache_spec():
same_layer_specs = [
new_kv_cache_spec(num_kv_heads=32),

View File

@@ -7,6 +7,7 @@ import os
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
from functools import partial
from typing import Any, NewType, TypeAlias, overload
from vllm import envs
@@ -1390,7 +1391,7 @@ def _estimate_max_model_len_from_groups(
def _auto_fit_max_model_len(
vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec],
projected_groups_per_worker: list[list[KVCacheGroupSpec]],
available_memory: list[int],
) -> None:
"""
@@ -1401,14 +1402,13 @@ def _auto_fit_max_model_len(
Args:
vllm_config: The global VllmConfig (will be modified in-place)
kv_cache_groups: The global KV cache groups (from get_kv_cache_groups).
This correctly accounts for padding in hybrid models.
projected_groups_per_worker: KV cache groups projected to each worker.
available_memory: Memory available for KV cache in bytes for each
worker.
"""
original_max = vllm_config.model_config.max_model_len
if not kv_cache_groups:
if all(not groups for groups in projected_groups_per_worker):
# All workers have empty specs (attention-free model)
logger.info_once(
"Auto-fit max_model_len: attention-free model, "
@@ -1418,11 +1418,16 @@ def _auto_fit_max_model_len(
)
return
# Use minimum available memory across all workers
min_available_memory = min(available_memory)
auto_fit_max = _estimate_max_model_len_from_groups(
vllm_config, kv_cache_groups, min_available_memory
)
# Find the max_model_len that fits across all workers.
auto_fit_max = original_max
limiting_worker_mem = available_memory[0]
for groups, avail_mem in zip(projected_groups_per_worker, available_memory):
if not groups:
continue
worker_max = _estimate_max_model_len_from_groups(vllm_config, groups, avail_mem)
if worker_max < auto_fit_max:
auto_fit_max = worker_max
limiting_worker_mem = avail_mem
if auto_fit_max <= 0:
raise ValueError(
@@ -1446,11 +1451,47 @@ def _auto_fit_max_model_len(
"available GPU memory (%s GiB available for KV cache)",
original_max,
auto_fit_max,
format_gib(min_available_memory),
format_gib(limiting_worker_mem),
scope="local",
)
def _project_kv_cache_groups_to_worker(
global_kv_cache_groups: list[KVCacheGroupSpec],
worker_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
"""
Projects global KV cache groups onto a single worker's assigned layers.
In pipeline parallelism, each worker only owns a subset of layers. This
function filters the global groups to include only layers present on the
given worker, adjusting UniformTypeKVCacheSpecs accordingly.
Args:
global_kv_cache_groups: The global KV cache groups for the whole model.
worker_spec: The KV cache spec of each layer on this worker.
Returns:
The projected KV cache groups containing only this worker's layers.
"""
projected_groups: list[KVCacheGroupSpec] = []
for group in global_kv_cache_groups:
worker_layer_names = [
layer_name for layer_name in group.layer_names if layer_name in worker_spec
]
group_spec = group.kv_cache_spec
if worker_layer_names and isinstance(group_spec, UniformTypeKVCacheSpecs):
group_spec = UniformTypeKVCacheSpecs(
block_size=group_spec.block_size,
kv_cache_specs={
layer_name: group_spec.kv_cache_specs[layer_name]
for layer_name in worker_layer_names
},
)
projected_groups.append(KVCacheGroupSpec(worker_layer_names, group_spec))
return projected_groups
def get_kv_cache_configs(
vllm_config: VllmConfig,
kv_cache_specs: list[dict[str, KVCacheSpec]],
@@ -1468,7 +1509,8 @@ def get_kv_cache_configs(
the whole model.
2. Generate the KV cache groups based on the layer ratio of the whole model.
This also handles spec unification for hybrid models.
3. Handle auto-fit max_model_len and memory checks using the unified specs.
3. Handle auto-fit max_model_len and memory checks using per-worker
projected groups to account for PP sharding.
4. Generate the KV cache configs for each worker based on the KV cache
grouping strategy. (This is reasonable because the layer ratio of
different PP stages are similar.)
@@ -1506,44 +1548,38 @@ def get_kv_cache_configs(
# If original_max_model_len was -1, automatically
# determine the maximum model length that fits in available GPU memory.
# We use the global groups here to correctly account for padding.
if vllm_config.model_config.original_max_model_len == -1:
_auto_fit_max_model_len(vllm_config, global_kv_cache_groups, available_memory)
# We use per-worker projected groups to account for PP sharding.
projected_groups_per_worker = [
_project_kv_cache_groups_to_worker(global_kv_cache_groups, worker_spec)
for worker_spec in kv_cache_specs
]
# Check if the available memory is enough (using min across all workers).
# We use the global groups to correctly account for padding.
if global_kv_cache_groups:
if vllm_config.model_config.original_max_model_len == -1:
_auto_fit_max_model_len(
vllm_config, projected_groups_per_worker, available_memory
)
# Check if the available memory is enough per worker.
for groups, avail_mem in zip(projected_groups_per_worker, available_memory):
if not groups:
continue
_check_enough_kv_cache_memory(
min(available_memory),
lambda: _max_memory_usage_bytes_from_groups(
vllm_config, global_kv_cache_groups
),
avail_mem,
partial(_max_memory_usage_bytes_from_groups, vllm_config, groups),
vllm_config.model_config.max_model_len,
lambda am: _estimate_max_model_len_from_groups(
vllm_config, global_kv_cache_groups, am
),
partial(_estimate_max_model_len_from_groups, vllm_config, groups),
)
kv_cache_configs: list[KVCacheConfig] = []
for kv_cache_spec_one_worker, available_memory_one_worker in zip(
kv_cache_specs, available_memory
for projected_groups, kv_cache_spec_one_worker, available_memory_one_worker in zip(
projected_groups_per_worker, kv_cache_specs, available_memory
):
kv_cache_groups_one_worker: list[KVCacheGroupSpec] = []
for group in global_kv_cache_groups:
group_layer_names_one_worker = [
layer_name
for layer_name in group.layer_names
if layer_name in kv_cache_spec_one_worker
]
kv_cache_groups_one_worker.append(
KVCacheGroupSpec(group_layer_names_one_worker, group.kv_cache_spec)
)
assert sum(
len(group.layer_names) for group in kv_cache_groups_one_worker
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
assert sum(len(group.layer_names) for group in projected_groups) == len(
kv_cache_spec_one_worker
), "Some layers are not assigned to any group."
kv_cache_configs.append(
get_kv_cache_config_from_groups(
vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
vllm_config, projected_groups, available_memory_one_worker
)
)