[XPU] Set consistent default KV cache layout (#24745)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-09-15 12:09:34 +02:00
committed by GitHub
parent bc0f6059a2
commit 2e41f5abca
3 changed files with 23 additions and 16 deletions

View File

@@ -56,9 +56,9 @@ except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
# Supported xPUs and types of kv transfer buffer.
# {xPU: tuple of supported kv buffer types}
_NIXL_SUPPORTED_XPUS = {
# Supported platforms and types of kv transfer buffer.
# {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_DEVICE = {
"cuda": ("cuda", ),
"tpu": ("cpu", ),
"xpu": ("cpu", ),
@@ -458,9 +458,9 @@ class NixlConnectorWorker:
self.device_type = current_platform.device_type
self.kv_buffer_device: str = \
vllm_config.kv_transfer_config.kv_buffer_device
if self.device_type not in _NIXL_SUPPORTED_XPUS:
if self.device_type not in _NIXL_SUPPORTED_DEVICE:
raise RuntimeError(f"{self.device_type} is not supported.")
elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[
elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[
self.device_type]:
raise RuntimeError(
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
@@ -468,7 +468,7 @@ class NixlConnectorWorker:
self.device_kv_caches: dict[str, torch.Tensor] = {}
# cpu kv buffer for xfer
# used when xPU memory can not be registered under nixl
# used when device memory can not be registered under nixl
self.host_xfer_buffers: dict[str, torch.Tensor] = {}
self.use_host_buffer = self.kv_buffer_device == "cpu"
if self.kv_buffer_device == "cuda":
@@ -927,6 +927,9 @@ class NixlConnectorWorker:
if tp_ratio > 1:
# Heterogeneous TP expects same kv_cache_layout.
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
if self.device_type == "xpu":
raise ValueError(
"Heterogeneous TP is not supported on XPU")
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "