[Core][BugFix] Fix PP KV cache sharding memory validation (#33698)
Signed-off-by: junuxyz <216036880+junuxyz@users.noreply.github.com>
This commit is contained in:
@@ -1046,6 +1046,99 @@ def test_get_kv_cache_configs_multiple_workers():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"asymmetric_memory",
|
||||
[False, True],
|
||||
ids=["symmetric", "asymmetric"],
|
||||
)
|
||||
def test_get_kv_cache_configs_pp_sharding(asymmetric_memory):
|
||||
model_config = ModelConfig(max_model_len=512)
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
|
||||
ref_kv_cache_spec = new_kv_cache_spec()
|
||||
pp_kv_cache_specs = [
|
||||
{"layer1": ref_kv_cache_spec},
|
||||
{"layer2": ref_kv_cache_spec},
|
||||
]
|
||||
|
||||
expected_num_blocks = model_config.max_model_len // ref_kv_cache_spec.block_size + 1
|
||||
avail_memory = ref_kv_cache_spec.page_size_bytes * expected_num_blocks
|
||||
|
||||
# With per-worker validation, each worker only needs memory for its own
|
||||
# layers. Worker 2 having more memory shouldn't affect worker 1's config.
|
||||
available_memory = (
|
||||
[avail_memory, avail_memory * 2] if asymmetric_memory else [avail_memory] * 2
|
||||
)
|
||||
|
||||
kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config,
|
||||
pp_kv_cache_specs,
|
||||
available_memory,
|
||||
)
|
||||
|
||||
assert kv_cache_configs == [
|
||||
KVCacheConfig(
|
||||
num_blocks=expected_num_blocks,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(
|
||||
size=ref_kv_cache_spec.page_size_bytes * expected_num_blocks,
|
||||
shared_by=["layer1"],
|
||||
),
|
||||
],
|
||||
kv_cache_groups=[KVCacheGroupSpec(["layer1"], ref_kv_cache_spec)],
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=expected_num_blocks,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(
|
||||
size=ref_kv_cache_spec.page_size_bytes * expected_num_blocks,
|
||||
shared_by=["layer2"],
|
||||
),
|
||||
],
|
||||
kv_cache_groups=[KVCacheGroupSpec(["layer2"], ref_kv_cache_spec)],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def test_project_kv_cache_groups_to_worker():
|
||||
spec_a = new_kv_cache_spec()
|
||||
spec_b = new_kv_cache_spec(num_kv_heads=4)
|
||||
|
||||
global_groups = [
|
||||
KVCacheGroupSpec(["layer1", "layer2", "layer3"], spec_a),
|
||||
]
|
||||
worker_spec = {"layer1": spec_a, "layer2": spec_a}
|
||||
projected = kv_cache_utils._project_kv_cache_groups_to_worker(
|
||||
global_groups, worker_spec
|
||||
)
|
||||
assert len(projected) == 1
|
||||
assert projected[0].layer_names == ["layer1", "layer2"]
|
||||
assert projected[0].kv_cache_spec is spec_a
|
||||
|
||||
projected = kv_cache_utils._project_kv_cache_groups_to_worker(
|
||||
global_groups, {"layer4": spec_a}
|
||||
)
|
||||
assert len(projected) == 1
|
||||
assert projected[0].layer_names == []
|
||||
assert projected[0].kv_cache_spec is spec_a
|
||||
|
||||
uniform_spec = UniformTypeKVCacheSpecs(
|
||||
block_size=16,
|
||||
kv_cache_specs={"layer1": spec_a, "layer2": spec_b, "layer3": spec_a},
|
||||
)
|
||||
global_groups_uniform = [
|
||||
KVCacheGroupSpec(["layer1", "layer2", "layer3"], uniform_spec),
|
||||
]
|
||||
projected = kv_cache_utils._project_kv_cache_groups_to_worker(
|
||||
global_groups_uniform, {"layer1": spec_a, "layer3": spec_a}
|
||||
)
|
||||
assert len(projected) == 1
|
||||
assert projected[0].layer_names == ["layer1", "layer3"]
|
||||
proj_spec = projected[0].kv_cache_spec
|
||||
assert isinstance(proj_spec, UniformTypeKVCacheSpecs)
|
||||
assert set(proj_spec.kv_cache_specs.keys()) == {"layer1", "layer3"}
|
||||
|
||||
|
||||
def test_merge_kv_cache_spec():
|
||||
same_layer_specs = [
|
||||
new_kv_cache_spec(num_kv_heads=32),
|
||||
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterable, Iterator, Sequence
|
||||
from dataclasses import dataclass, replace
|
||||
from functools import partial
|
||||
from typing import Any, NewType, TypeAlias, overload
|
||||
|
||||
from vllm import envs
|
||||
@@ -1390,7 +1391,7 @@ def _estimate_max_model_len_from_groups(
|
||||
|
||||
def _auto_fit_max_model_len(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_groups: list[KVCacheGroupSpec],
|
||||
projected_groups_per_worker: list[list[KVCacheGroupSpec]],
|
||||
available_memory: list[int],
|
||||
) -> None:
|
||||
"""
|
||||
@@ -1401,14 +1402,13 @@ def _auto_fit_max_model_len(
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig (will be modified in-place)
|
||||
kv_cache_groups: The global KV cache groups (from get_kv_cache_groups).
|
||||
This correctly accounts for padding in hybrid models.
|
||||
projected_groups_per_worker: KV cache groups projected to each worker.
|
||||
available_memory: Memory available for KV cache in bytes for each
|
||||
worker.
|
||||
"""
|
||||
original_max = vllm_config.model_config.max_model_len
|
||||
|
||||
if not kv_cache_groups:
|
||||
if all(not groups for groups in projected_groups_per_worker):
|
||||
# All workers have empty specs (attention-free model)
|
||||
logger.info_once(
|
||||
"Auto-fit max_model_len: attention-free model, "
|
||||
@@ -1418,11 +1418,16 @@ def _auto_fit_max_model_len(
|
||||
)
|
||||
return
|
||||
|
||||
# Use minimum available memory across all workers
|
||||
min_available_memory = min(available_memory)
|
||||
auto_fit_max = _estimate_max_model_len_from_groups(
|
||||
vllm_config, kv_cache_groups, min_available_memory
|
||||
)
|
||||
# Find the max_model_len that fits across all workers.
|
||||
auto_fit_max = original_max
|
||||
limiting_worker_mem = available_memory[0]
|
||||
for groups, avail_mem in zip(projected_groups_per_worker, available_memory):
|
||||
if not groups:
|
||||
continue
|
||||
worker_max = _estimate_max_model_len_from_groups(vllm_config, groups, avail_mem)
|
||||
if worker_max < auto_fit_max:
|
||||
auto_fit_max = worker_max
|
||||
limiting_worker_mem = avail_mem
|
||||
|
||||
if auto_fit_max <= 0:
|
||||
raise ValueError(
|
||||
@@ -1446,11 +1451,47 @@ def _auto_fit_max_model_len(
|
||||
"available GPU memory (%s GiB available for KV cache)",
|
||||
original_max,
|
||||
auto_fit_max,
|
||||
format_gib(min_available_memory),
|
||||
format_gib(limiting_worker_mem),
|
||||
scope="local",
|
||||
)
|
||||
|
||||
|
||||
def _project_kv_cache_groups_to_worker(
|
||||
global_kv_cache_groups: list[KVCacheGroupSpec],
|
||||
worker_spec: dict[str, KVCacheSpec],
|
||||
) -> list[KVCacheGroupSpec]:
|
||||
"""
|
||||
Projects global KV cache groups onto a single worker's assigned layers.
|
||||
|
||||
In pipeline parallelism, each worker only owns a subset of layers. This
|
||||
function filters the global groups to include only layers present on the
|
||||
given worker, adjusting UniformTypeKVCacheSpecs accordingly.
|
||||
|
||||
Args:
|
||||
global_kv_cache_groups: The global KV cache groups for the whole model.
|
||||
worker_spec: The KV cache spec of each layer on this worker.
|
||||
|
||||
Returns:
|
||||
The projected KV cache groups containing only this worker's layers.
|
||||
"""
|
||||
projected_groups: list[KVCacheGroupSpec] = []
|
||||
for group in global_kv_cache_groups:
|
||||
worker_layer_names = [
|
||||
layer_name for layer_name in group.layer_names if layer_name in worker_spec
|
||||
]
|
||||
group_spec = group.kv_cache_spec
|
||||
if worker_layer_names and isinstance(group_spec, UniformTypeKVCacheSpecs):
|
||||
group_spec = UniformTypeKVCacheSpecs(
|
||||
block_size=group_spec.block_size,
|
||||
kv_cache_specs={
|
||||
layer_name: group_spec.kv_cache_specs[layer_name]
|
||||
for layer_name in worker_layer_names
|
||||
},
|
||||
)
|
||||
projected_groups.append(KVCacheGroupSpec(worker_layer_names, group_spec))
|
||||
return projected_groups
|
||||
|
||||
|
||||
def get_kv_cache_configs(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_specs: list[dict[str, KVCacheSpec]],
|
||||
@@ -1468,7 +1509,8 @@ def get_kv_cache_configs(
|
||||
the whole model.
|
||||
2. Generate the KV cache groups based on the layer ratio of the whole model.
|
||||
This also handles spec unification for hybrid models.
|
||||
3. Handle auto-fit max_model_len and memory checks using the unified specs.
|
||||
3. Handle auto-fit max_model_len and memory checks using per-worker
|
||||
projected groups to account for PP sharding.
|
||||
4. Generate the KV cache configs for each worker based on the KV cache
|
||||
grouping strategy. (This is reasonable because the layer ratio of
|
||||
different PP stages are similar.)
|
||||
@@ -1506,44 +1548,38 @@ def get_kv_cache_configs(
|
||||
|
||||
# If original_max_model_len was -1, automatically
|
||||
# determine the maximum model length that fits in available GPU memory.
|
||||
# We use the global groups here to correctly account for padding.
|
||||
if vllm_config.model_config.original_max_model_len == -1:
|
||||
_auto_fit_max_model_len(vllm_config, global_kv_cache_groups, available_memory)
|
||||
# We use per-worker projected groups to account for PP sharding.
|
||||
projected_groups_per_worker = [
|
||||
_project_kv_cache_groups_to_worker(global_kv_cache_groups, worker_spec)
|
||||
for worker_spec in kv_cache_specs
|
||||
]
|
||||
|
||||
# Check if the available memory is enough (using min across all workers).
|
||||
# We use the global groups to correctly account for padding.
|
||||
if global_kv_cache_groups:
|
||||
if vllm_config.model_config.original_max_model_len == -1:
|
||||
_auto_fit_max_model_len(
|
||||
vllm_config, projected_groups_per_worker, available_memory
|
||||
)
|
||||
|
||||
# Check if the available memory is enough per worker.
|
||||
for groups, avail_mem in zip(projected_groups_per_worker, available_memory):
|
||||
if not groups:
|
||||
continue
|
||||
_check_enough_kv_cache_memory(
|
||||
min(available_memory),
|
||||
lambda: _max_memory_usage_bytes_from_groups(
|
||||
vllm_config, global_kv_cache_groups
|
||||
),
|
||||
avail_mem,
|
||||
partial(_max_memory_usage_bytes_from_groups, vllm_config, groups),
|
||||
vllm_config.model_config.max_model_len,
|
||||
lambda am: _estimate_max_model_len_from_groups(
|
||||
vllm_config, global_kv_cache_groups, am
|
||||
),
|
||||
partial(_estimate_max_model_len_from_groups, vllm_config, groups),
|
||||
)
|
||||
|
||||
kv_cache_configs: list[KVCacheConfig] = []
|
||||
for kv_cache_spec_one_worker, available_memory_one_worker in zip(
|
||||
kv_cache_specs, available_memory
|
||||
for projected_groups, kv_cache_spec_one_worker, available_memory_one_worker in zip(
|
||||
projected_groups_per_worker, kv_cache_specs, available_memory
|
||||
):
|
||||
kv_cache_groups_one_worker: list[KVCacheGroupSpec] = []
|
||||
for group in global_kv_cache_groups:
|
||||
group_layer_names_one_worker = [
|
||||
layer_name
|
||||
for layer_name in group.layer_names
|
||||
if layer_name in kv_cache_spec_one_worker
|
||||
]
|
||||
kv_cache_groups_one_worker.append(
|
||||
KVCacheGroupSpec(group_layer_names_one_worker, group.kv_cache_spec)
|
||||
)
|
||||
assert sum(
|
||||
len(group.layer_names) for group in kv_cache_groups_one_worker
|
||||
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
|
||||
assert sum(len(group.layer_names) for group in projected_groups) == len(
|
||||
kv_cache_spec_one_worker
|
||||
), "Some layers are not assigned to any group."
|
||||
kv_cache_configs.append(
|
||||
get_kv_cache_config_from_groups(
|
||||
vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
|
||||
vllm_config, projected_groups, available_memory_one_worker
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user