[XPU] Set consistent default KV cache layout (#24745)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -56,9 +56,9 @@ except ImportError:
|
||||
logger.warning("NIXL is not available")
|
||||
NixlWrapper = None
|
||||
|
||||
# Supported xPUs and types of kv transfer buffer.
|
||||
# {xPU: tuple of supported kv buffer types}
|
||||
_NIXL_SUPPORTED_XPUS = {
|
||||
# Supported platforms and types of kv transfer buffer.
|
||||
# {device: tuple of supported kv buffer types}
|
||||
_NIXL_SUPPORTED_DEVICE = {
|
||||
"cuda": ("cuda", ),
|
||||
"tpu": ("cpu", ),
|
||||
"xpu": ("cpu", ),
|
||||
@@ -458,9 +458,9 @@ class NixlConnectorWorker:
|
||||
self.device_type = current_platform.device_type
|
||||
self.kv_buffer_device: str = \
|
||||
vllm_config.kv_transfer_config.kv_buffer_device
|
||||
if self.device_type not in _NIXL_SUPPORTED_XPUS:
|
||||
if self.device_type not in _NIXL_SUPPORTED_DEVICE:
|
||||
raise RuntimeError(f"{self.device_type} is not supported.")
|
||||
elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[
|
||||
elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[
|
||||
self.device_type]:
|
||||
raise RuntimeError(
|
||||
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
|
||||
@@ -468,7 +468,7 @@ class NixlConnectorWorker:
|
||||
self.device_kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
# CPU-side KV buffer used for transfers
|
||||
# used when xPU memory can not be registered under nixl
|
||||
# used when device memory cannot be registered under NIXL
|
||||
self.host_xfer_buffers: dict[str, torch.Tensor] = {}
|
||||
self.use_host_buffer = self.kv_buffer_device == "cpu"
|
||||
if self.kv_buffer_device == "cuda":
|
||||
@@ -927,6 +927,9 @@ class NixlConnectorWorker:
|
||||
if tp_ratio > 1:
|
||||
# Heterogeneous TP expects the remote and local kv_cache_layout to match.
|
||||
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
|
||||
if self.device_type == "xpu":
|
||||
raise ValueError(
|
||||
"Heterogeneous TP is not supported on XPU")
|
||||
|
||||
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
|
||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||
|
||||
Reference in New Issue
Block a user