[UX] Use kv_offloading_backend=native by default (#32421)

Signed-off-by: mgoin <mgoin64@gmail.com>
(cherry picked from commit 1be5a73571)
This commit is contained in:
Michael Goin
2026-01-15 13:55:11 -05:00
committed by Kevin H. Luu
parent 6ac0fcf416
commit 0e31fc7996
4 changed files with 28 additions and 15 deletions

View File

@@ -19,7 +19,8 @@ pytestmark = pytest.mark.cpu_test
("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
# size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
(None, None, 1, 1, None, None),
# When kv_offloading_size is None, offloading is disabled (backend is ignored)
("native", None, 1, 1, None, None),
],
)
def test_kv_connector(
@@ -62,3 +63,19 @@ def test_kv_connector(
assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
# Existing config should be replaced
assert "existing_key" not in kv_connector_extra_config
def test_kv_offloading_size_only_uses_native_default():
    """Setting only kv_offloading_size should activate native CPU offloading.

    Since kv_offloading_backend defaults to "native", a config that supplies
    just the buffer size must be translated into a KVTransferConfig using the
    OffloadingConnector in the "kv_both" role, with the size (GiB) converted
    to bytes in the connector's extra config.
    """
    config = VllmConfig(
        cache_config=CacheConfig(
            # Only the size is specified; kv_offloading_backend is left at
            # its default value of "native".
            kv_offloading_size=4.0,
        ),
    )
    transfer_cfg = config.kv_transfer_config
    extra_cfg = transfer_cfg.kv_connector_extra_config
    assert transfer_cfg.kv_connector == "OffloadingConnector"
    assert transfer_cfg.kv_role == "kv_both"
    # 4.0 GiB expressed in bytes.
    assert extra_cfg["cpu_bytes_to_use"] == 4.0 * (1 << 30)

View File

@@ -152,13 +152,13 @@ class CacheConfig:
kv_offloading_size: float | None = None
"""Size of the KV cache offloading buffer in GiB. When TP > 1, this is
the total buffer size summed across all TP ranks. By default, this is set
to None, which means no KV offloading is enabled. When set with
kv_offloading_backend, vLLM will enable KV cache offloading to CPU"""
to None, which means no KV offloading is enabled. When set, vLLM will
enable KV cache offloading to CPU using the kv_offloading_backend."""
kv_offloading_backend: KVOffloadingBackend | None = None
kv_offloading_backend: KVOffloadingBackend = "native"
"""The backend to use for KV cache offloading. Supported backends include
'native' (vLLM native CPU offloading), 'lmcache' This option must be used
together with kv_offloading_size."""
'native' (vLLM native CPU offloading), 'lmcache'.
KV offloading is only activated when kv_offloading_size is set."""
def compute_hash(self) -> str:
"""

View File

@@ -498,17 +498,15 @@ class VllmConfig:
Right now, this function reads the offloading settings from
CacheConfig and configures the KVTransferConfig accordingly.
"""
if (kv_offloading_backend := self.cache_config.kv_offloading_backend) is None:
# KV offloading is only activated when kv_offloading_size is set.
if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
return
kv_offloading_backend = self.cache_config.kv_offloading_backend
# If no KVTransferConfig is provided, create a default one.
if self.kv_transfer_config is None:
self.kv_transfer_config = KVTransferConfig()
if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
raise ValueError(
"You must set kv_offloading_size when kv_offloading_backend is set."
)
num_kv_ranks = (
self.parallel_config.tensor_parallel_size
* self.parallel_config.pipeline_parallel_size

View File

@@ -578,9 +578,7 @@ class EngineArgs:
optimization_level: OptimizationLevel = VllmConfig.optimization_level
kv_offloading_size: float | None = CacheConfig.kv_offloading_size
kv_offloading_backend: KVOffloadingBackend | None = (
CacheConfig.kv_offloading_backend
)
kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
tokens_only: bool = False
def __post_init__(self):