[Bugfix] Missing NIXL metadata for handshake initialization if instance spans multi-node (#26338)

Signed-off-by: Guan Luo <gluo@nvidia.com>
Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
GuanLuo
2025-11-01 01:16:00 +08:00
committed by GitHub
parent 7e06c40e63
commit d6517be3cd
7 changed files with 321 additions and 95 deletions

View File

@@ -19,7 +19,11 @@ from vllm.distributed import (
init_distributed_environment,
set_custom_all_reduce,
)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.parallel_state import (
get_pp_group,
get_tp_group,
@@ -348,6 +352,21 @@ class Worker(WorkerBase):
return int(self.available_kv_cache_memory_bytes)
def get_kv_connector_handshake_metadata(self) -> dict | None:
"""Get KV connector metadata from this worker if available."""
if not has_kv_transfer_group():
return None
connector = get_kv_transfer_group()
# Return None for connectors that don't need to exchange handshake
# metadata across workers.
if (metadata := connector.get_handshake_metadata()) is None:
return None
tp_rank = get_tp_group().rank_in_group
return {tp_rank: metadata}
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()