From 5a3f1eb62fb8a5d114001488832f8bd7f93df5b8 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Fri, 13 Mar 2026 19:07:33 +0000
Subject: [PATCH] [Misc] Set default `kv_buffer_device` in a better way (#36862)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 vllm/config/kv_transfer.py                         | 15 ++++++++-------
 .../kv_transfer/kv_connector/v1/nixl_connector.py  |  5 +----
 2 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py
index 172b7a805..b22af99f7 100644
--- a/vllm/config/kv_transfer.py
+++ b/vllm/config/kv_transfer.py
@@ -13,6 +13,12 @@ KVConsumer = Literal["kv_consumer", "kv_both"]
 KVRole = Literal[KVProducer, KVConsumer]
 
 
+def kv_buffer_device_default_factory() -> str:
+    from vllm.platforms import current_platform
+
+    return current_platform.device_type
+
+
 @config
 class KVTransferConfig:
     """Configuration for distributed KV cache transfer."""
@@ -24,9 +30,9 @@ class KVTransferConfig:
     engine_id: str | None = None
     """The engine id for KV transfers."""
 
-    kv_buffer_device: str | None = None
+    kv_buffer_device: str = field(default_factory=kv_buffer_device_default_factory)
     """The device used by kv connector to buffer the KV cache. Choices are
-    'cuda','cpu' and 'xpu'."""
+    'cuda', 'cpu' and 'xpu'."""
 
     kv_buffer_size: float = 1e9
     """The buffer size for TorchDistributedConnector. Measured in number of
@@ -100,11 +106,6 @@ class KVTransferConfig:
                 f"is set, supported roles are {get_args(KVRole)}"
             )
 
-        if self.kv_buffer_device is None:
-            from vllm.platforms import current_platform
-
-            self.kv_buffer_device = current_platform.device_type
-
     @property
     def is_kv_transfer_instance(self) -> bool:
         return self.kv_connector is not None and self.kv_role in get_args(KVRole)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index e6c49d7a0..d381b5270 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -998,10 +998,7 @@ class NixlConnectorWorker:
 
         # KV Caches and nixl tracking data.
         self.device_type = current_platform.device_type
-        kv_buffer_device = vllm_config.kv_transfer_config.kv_buffer_device
-        if kv_buffer_device is None:
-            raise ValueError("kv_buffer_device must be set for NixlConnector")
-        self.kv_buffer_device: str = kv_buffer_device
+        self.kv_buffer_device: str = vllm_config.kv_transfer_config.kv_buffer_device
         if self.device_type not in _NIXL_SUPPORTED_DEVICE:
             raise RuntimeError(f"{self.device_type} is not supported.")
         elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[self.device_type]: