[NIXL] Fix after virtual block_size for host_buffer with heter kv_layout (#29122)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
@@ -1042,10 +1042,12 @@ class NixlConnectorWorker:
|
|||||||
NOT directly supported by NIXL (e.g., tpu)
|
NOT directly supported by NIXL (e.g., tpu)
|
||||||
"""
|
"""
|
||||||
xfer_buffers: dict[str, torch.Tensor] = {}
|
xfer_buffers: dict[str, torch.Tensor] = {}
|
||||||
|
inv_order = [0, 1, 3, 2, 4]
|
||||||
try:
|
try:
|
||||||
for layer_name, kv_cache in kv_caches.items():
|
for layer_name, kv_cache in kv_caches.items():
|
||||||
kv_shape = kv_cache.shape
|
kv_shape = kv_cache.shape
|
||||||
kv_dtype = kv_cache.dtype
|
kv_dtype = kv_cache.dtype
|
||||||
|
permute_shape = False
|
||||||
if (
|
if (
|
||||||
self.kv_cache_layout == "NHD"
|
self.kv_cache_layout == "NHD"
|
||||||
and self.vllm_config.kv_transfer_config is not None
|
and self.vllm_config.kv_transfer_config is not None
|
||||||
@@ -1059,10 +1061,20 @@ class NixlConnectorWorker:
|
|||||||
# Since NHD will not support Decode/Prefill TP_ratio > 1,
|
# Since NHD will not support Decode/Prefill TP_ratio > 1,
|
||||||
# we can leverage host_buffer for permute
|
# we can leverage host_buffer for permute
|
||||||
self.host_buffer_kv_cache_layout = "HND"
|
self.host_buffer_kv_cache_layout = "HND"
|
||||||
kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
|
kv_shape = (
|
||||||
|
tuple(kv_shape[i] for i in inv_order)
|
||||||
|
if not self.use_mla
|
||||||
|
else kv_shape
|
||||||
|
)
|
||||||
|
permute_shape = not self.use_mla
|
||||||
|
|
||||||
xfer_buffers[layer_name] = torch.empty(
|
xfer_buffers[layer_name] = torch.empty(
|
||||||
kv_shape, dtype=kv_dtype, device="cpu"
|
kv_shape, dtype=kv_dtype, device="cpu"
|
||||||
)
|
)
|
||||||
|
if permute_shape:
|
||||||
|
xfer_buffers[layer_name] = xfer_buffers[layer_name].permute(
|
||||||
|
inv_order
|
||||||
|
)
|
||||||
except MemoryError as e:
|
except MemoryError as e:
|
||||||
logger.error("NIXLConnectorWorker gets %s.", e)
|
logger.error("NIXLConnectorWorker gets %s.", e)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -251,10 +251,6 @@ class XPUPlatform(Platform):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Copy blocks from src_cache to dst_cache on XPU."""
|
"""Copy blocks from src_cache to dst_cache on XPU."""
|
||||||
_src_cache = src_cache[:, src_block_indices]
|
_src_cache = src_cache[:, src_block_indices]
|
||||||
if _src_cache.shape[2:] != dst_cache.shape[2:]:
|
|
||||||
# To support TP_ratio, HOST KV might be initiated with HND
|
|
||||||
# while XPU device KV is with NHD
|
|
||||||
_src_cache = _src_cache.permute(0, 1, 3, 2, 4)
|
|
||||||
dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
|
dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -267,8 +263,4 @@ class XPUPlatform(Platform):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Copy blocks from XPU to host (CPU)."""
|
"""Copy blocks from XPU to host (CPU)."""
|
||||||
_src_cache = src_cache[:, src_block_indices]
|
_src_cache = src_cache[:, src_block_indices]
|
||||||
if _src_cache.shape[2:] != dst_cache.shape[2:]:
|
|
||||||
# XPU device KV is with NHD while HOST KV
|
|
||||||
# might be initiated with HND for TP_ratio support
|
|
||||||
_src_cache = _src_cache.permute(0, 1, 3, 2, 4)
|
|
||||||
dst_cache[:, dst_block_indices] = _src_cache.cpu()
|
dst_cache[:, dst_block_indices] = _src_cache.cpu()
|
||||||
|
|||||||
Reference in New Issue
Block a user