[Misc] Set default kv_buffer_device in a better way (#36862)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-03-13 19:07:33 +00:00
committed by GitHub
parent b3ce711b93
commit 5a3f1eb62f
2 changed files with 9 additions and 11 deletions

View File

@@ -13,6 +13,12 @@ KVConsumer = Literal["kv_consumer", "kv_both"]
KVRole = Literal[KVProducer, KVConsumer] KVRole = Literal[KVProducer, KVConsumer]
def kv_buffer_device_default_factory() -> str:
    """Return the device type reported by the current platform.

    Serves as the ``default_factory`` for
    ``KVTransferConfig.kv_buffer_device``, so the default is computed
    when a config instance is created rather than at class definition.
    """
    # The import happens at call time, not at module import time
    # (presumably to avoid an import cycle or eager platform detection —
    # TODO confirm).
    from vllm.platforms import current_platform

    device_type = current_platform.device_type
    return device_type
@config @config
class KVTransferConfig: class KVTransferConfig:
"""Configuration for distributed KV cache transfer.""" """Configuration for distributed KV cache transfer."""
@@ -24,7 +30,7 @@ class KVTransferConfig:
engine_id: str | None = None engine_id: str | None = None
"""The engine id for KV transfers.""" """The engine id for KV transfers."""
kv_buffer_device: str | None = None kv_buffer_device: str = field(default_factory=kv_buffer_device_default_factory)
"""The device used by kv connector to buffer the KV cache. Choices are """The device used by kv connector to buffer the KV cache. Choices are
'cuda', 'cpu' and 'xpu'.""" 'cuda', 'cpu' and 'xpu'."""
@@ -100,11 +106,6 @@ class KVTransferConfig:
f"is set, supported roles are {get_args(KVRole)}" f"is set, supported roles are {get_args(KVRole)}"
) )
if self.kv_buffer_device is None:
from vllm.platforms import current_platform
self.kv_buffer_device = current_platform.device_type
@property @property
def is_kv_transfer_instance(self) -> bool: def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and self.kv_role in get_args(KVRole) return self.kv_connector is not None and self.kv_role in get_args(KVRole)

View File

@@ -998,10 +998,7 @@ class NixlConnectorWorker:
# KV Caches and nixl tracking data. # KV Caches and nixl tracking data.
self.device_type = current_platform.device_type self.device_type = current_platform.device_type
kv_buffer_device = vllm_config.kv_transfer_config.kv_buffer_device self.kv_buffer_device: str = vllm_config.kv_transfer_config.kv_buffer_device
if kv_buffer_device is None:
raise ValueError("kv_buffer_device must be set for NixlConnector")
self.kv_buffer_device: str = kv_buffer_device
if self.device_type not in _NIXL_SUPPORTED_DEVICE: if self.device_type not in _NIXL_SUPPORTED_DEVICE:
raise RuntimeError(f"{self.device_type} is not supported.") raise RuntimeError(f"{self.device_type} is not supported.")
elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[self.device_type]: elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[self.device_type]: