[NIXL] use Host buffer to support TP_ratio > 1 for XPU (#27140)
Signed-off-by: Chendi Xue <chendi.xue@intel.com> Signed-off-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
@@ -144,6 +144,8 @@ class XPUPlatform(Platform):
|
||||
# check and update parallel config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
vllm_config.kv_transfer_config.enable_permute_local_kv = True
|
||||
|
||||
if parallel_config.distributed_executor_backend is None:
|
||||
if parallel_config.world_size > 1:
|
||||
@@ -245,6 +247,10 @@ class XPUPlatform(Platform):
|
||||
) -> None:
|
||||
"""Copy blocks from src_cache to dst_cache on XPU."""
|
||||
_src_cache = src_cache[:, src_block_indices]
|
||||
if _src_cache.shape[2:] != dst_cache.shape[2:]:
|
||||
# To support TP_ratio, HOST KV might be initialized with HND
|
||||
# while XPU device KV is with NHD
|
||||
_src_cache = _src_cache.permute(0, 1, 3, 2, 4)
|
||||
dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)
|
||||
|
||||
@classmethod
|
||||
@@ -257,4 +263,8 @@ class XPUPlatform(Platform):
|
||||
) -> None:
|
||||
"""Copy blocks from XPU to host (CPU)."""
|
||||
_src_cache = src_cache[:, src_block_indices]
|
||||
if _src_cache.shape[2:] != dst_cache.shape[2:]:
|
||||
# XPU device KV is with NHD while HOST KV
|
||||
# might be initialized with HND for TP_ratio support
|
||||
_src_cache = _src_cache.permute(0, 1, 3, 2, 4)
|
||||
dst_cache[:, dst_block_indices] = _src_cache.cpu()
|
||||
|
||||
Reference in New Issue
Block a user