[Bugfix] Missing NIXL metadata for handshake initialization if instance spans multi-node (#26338)
Signed-off-by: Guan Luo <gluo@nvidia.com> Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com> Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
@@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
|
|||||||
- Default: 5600
|
- Default: 5600
|
||||||
- **Required for both prefiller and decoder instances**
|
- **Required for both prefiller and decoder instances**
|
||||||
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
|
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
|
||||||
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node).
|
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use port 5600, 5601 on that node).
|
||||||
- Used for the initial NIXL handshake between the prefiller and the decoder
|
- Used for the initial NIXL handshake between the prefiller and the decoder
|
||||||
|
|
||||||
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
|
- `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
|||||||
NixlAgentMetadata,
|
NixlAgentMetadata,
|
||||||
NixlConnector,
|
NixlConnector,
|
||||||
NixlConnectorMetadata,
|
NixlConnectorMetadata,
|
||||||
|
NixlConnectorScheduler,
|
||||||
NixlConnectorWorker,
|
NixlConnectorWorker,
|
||||||
NixlKVConnectorStats,
|
NixlKVConnectorStats,
|
||||||
)
|
)
|
||||||
@@ -283,6 +284,92 @@ def test_prompt_less_than_block_size():
|
|||||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||||
|
FakeNixlWrapper,
|
||||||
|
)
|
||||||
|
def test_kv_transfer_handshake(dist_init):
|
||||||
|
"""Unit test for basic NixlConnector interface functionality."""
|
||||||
|
|
||||||
|
# Test setup, we creates a scheduler that contains a NixlConnector
|
||||||
|
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
|
||||||
|
# all workers of the instance.
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
# in case the test runs on non-GPU machine
|
||||||
|
vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
|
||||||
|
scheduler = create_scheduler(vllm_config)
|
||||||
|
|
||||||
|
# Create two NixlConnector of role WORKER, one is the worker of
|
||||||
|
# the scheduler (prefill), the other is a worker of decode instance.
|
||||||
|
|
||||||
|
# Prefill connector will register KV cache to populate proper handshake
|
||||||
|
# metadata.
|
||||||
|
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||||
|
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
||||||
|
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
|
||||||
|
)
|
||||||
|
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||||
|
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||||
|
kv_caches = {
|
||||||
|
"layer0": shared_tensor,
|
||||||
|
"layer1": unique_tensor,
|
||||||
|
"layer2": shared_tensor,
|
||||||
|
}
|
||||||
|
prefill_connector.register_kv_caches(kv_caches)
|
||||||
|
|
||||||
|
# Simulate EngineCore initialization that would
|
||||||
|
# gather connector metadata from all workers, the scheduler connector
|
||||||
|
# expects metadata to be in dict[int, KVConnectorHandshakeMetadata],
|
||||||
|
# where the first key is the dp_rank, the second key is the tp_rank.
|
||||||
|
metadata = {0: prefill_connector.get_handshake_metadata()}
|
||||||
|
scheduler_connector = scheduler.get_kv_connector()
|
||||||
|
scheduler_connector.set_xfer_handshake_metadata(metadata)
|
||||||
|
|
||||||
|
# Simulate a request that finishes prefill, which returns
|
||||||
|
# corresponding NixlConnectorMetadata for decode instance.
|
||||||
|
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||||
|
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||||
|
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||||
|
|
||||||
|
request = create_request(
|
||||||
|
request_id=1,
|
||||||
|
block_size=BLOCK_SIZE,
|
||||||
|
num_tokens=NUM_TOKENS,
|
||||||
|
do_remote_decode=True,
|
||||||
|
)
|
||||||
|
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||||
|
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
|
||||||
|
request, [0, 1, 2]
|
||||||
|
)
|
||||||
|
assert delay
|
||||||
|
|
||||||
|
# Decode connector will be able to create handshake with the prefill connector.
|
||||||
|
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||||
|
|
||||||
|
# Here we are testing the retrieval of NIXLAgentMetadata.
|
||||||
|
# Knowing the implementation detail, we override the add_remote_agent
|
||||||
|
# to validate the metadata received is the same as the one in prefill_connector.
|
||||||
|
with patch.object(
|
||||||
|
decode_connector.connector_worker, "add_remote_agent"
|
||||||
|
) as mock_add_remote_agent:
|
||||||
|
mock_add_remote_agent.return_type = "remote_agent"
|
||||||
|
|
||||||
|
decode_connector.connector_worker._nixl_handshake(
|
||||||
|
kv_connector_metadata["remote_host"],
|
||||||
|
kv_connector_metadata["remote_port"],
|
||||||
|
kv_connector_metadata["tp_size"],
|
||||||
|
kv_connector_metadata["remote_engine_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
received_metadata = mock_add_remote_agent.call_args.args
|
||||||
|
assert received_metadata[1] == 0 # remote_tp_rank
|
||||||
|
assert received_metadata[2] == 1 # remote_tp_size
|
||||||
|
assert metadata[0] == received_metadata[0]
|
||||||
|
|
||||||
|
# Need to shutdown the background thread to release NIXL side channel port
|
||||||
|
scheduler_connector.shutdown()
|
||||||
|
|
||||||
|
|
||||||
class FakeNixlConnectorWorker(NixlConnectorWorker):
|
class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||||
REMOTE_ENGINE_ID = "remote_engine"
|
REMOTE_ENGINE_ID = "remote_engine"
|
||||||
|
|
||||||
@@ -313,6 +400,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
|||||||
engine_id=self.REMOTE_ENGINE_ID,
|
engine_id=self.REMOTE_ENGINE_ID,
|
||||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||||
kv_caches_base_addr=[0],
|
kv_caches_base_addr=[0],
|
||||||
|
device_id=0,
|
||||||
num_blocks=1,
|
num_blocks=1,
|
||||||
block_lens=self.block_len_per_layer,
|
block_lens=self.block_len_per_layer,
|
||||||
attn_backend_name=self.backend_name,
|
attn_backend_name=self.backend_name,
|
||||||
@@ -559,6 +647,7 @@ class TestNixlHandshake:
|
|||||||
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||||
kv_caches_base_addr=[0],
|
kv_caches_base_addr=[0],
|
||||||
|
device_id=0,
|
||||||
num_blocks=1,
|
num_blocks=1,
|
||||||
block_lens=worker.block_len_per_layer,
|
block_lens=worker.block_len_per_layer,
|
||||||
attn_backend_name=worker.backend_name,
|
attn_backend_name=worker.backend_name,
|
||||||
@@ -611,6 +700,7 @@ class TestNixlHandshake:
|
|||||||
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||||
kv_caches_base_addr=[0],
|
kv_caches_base_addr=[0],
|
||||||
|
device_id=0,
|
||||||
num_blocks=1,
|
num_blocks=1,
|
||||||
# prefill TP=1, decode TP=2, remote block_lens is double to local
|
# prefill TP=1, decode TP=2, remote block_lens is double to local
|
||||||
block_lens=[i * 2 for i in worker.block_len_per_layer],
|
block_lens=[i * 2 for i in worker.block_len_per_layer],
|
||||||
@@ -1005,6 +1095,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
|
|||||||
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
|
_ = llm.generate([f"What is the capital of France? {padding}"], sampling_params)
|
||||||
# Request-0 times out and is cleared!
|
# Request-0 times out and is cleared!
|
||||||
assert "0" not in req_to_blocks
|
assert "0" not in req_to_blocks
|
||||||
|
# Need to shutdown the background thread to release NIXL side channel port
|
||||||
|
llm.llm_engine.engine_core.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def test_register_kv_caches(dist_init):
|
def test_register_kv_caches(dist_init):
|
||||||
@@ -1177,13 +1269,15 @@ def test_shutdown_cleans_up_resources(dist_init):
|
|||||||
"""Test that shutdown() properly cleans up all resources."""
|
"""Test that shutdown() properly cleans up all resources."""
|
||||||
vllm_config = create_vllm_config()
|
vllm_config = create_vllm_config()
|
||||||
|
|
||||||
|
scheduler = NixlConnectorScheduler(
|
||||||
|
vllm_config, vllm_config.kv_transfer_config.engine_id
|
||||||
|
)
|
||||||
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
|
worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id)
|
||||||
nixl_wrapper = worker.nixl_wrapper
|
nixl_wrapper = worker.nixl_wrapper
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
|
patch.object(worker, "_handshake_initiation_executor") as mock_exec,
|
||||||
patch.object(worker, "_nixl_handshake_listener_t") as mock_listener,
|
patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener,
|
||||||
patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event,
|
|
||||||
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
|
patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer,
|
||||||
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
|
patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist,
|
||||||
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
|
patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent,
|
||||||
@@ -1204,8 +1298,12 @@ def test_shutdown_cleans_up_resources(dist_init):
|
|||||||
worker.shutdown()
|
worker.shutdown()
|
||||||
|
|
||||||
mock_exec.shutdown.assert_called_with(wait=False)
|
mock_exec.shutdown.assert_called_with(wait=False)
|
||||||
mock_event.set.assert_called_once()
|
|
||||||
mock_listener.join.assert_called_once_with(timeout=1.0)
|
# Same sequence on scheduler.shutdown()
|
||||||
|
scheduler.shutdown()
|
||||||
|
scheduler.shutdown()
|
||||||
|
scheduler.shutdown()
|
||||||
|
mock_listener.join.assert_called_once()
|
||||||
|
|
||||||
mock_rel_xfer.assert_called_once_with(123)
|
mock_rel_xfer.assert_called_once_with(123)
|
||||||
assert mock_rel_dlist.call_count == 2
|
assert mock_rel_dlist.call_count == 2
|
||||||
|
|||||||
@@ -122,6 +122,15 @@ class KVConnectorRole(enum.Enum):
|
|||||||
WORKER = 1
|
WORKER = 1
|
||||||
|
|
||||||
|
|
||||||
|
class KVConnectorHandshakeMetadata(ABC): # noqa: B024
|
||||||
|
"""
|
||||||
|
Metadata used for out of band connector handshake between
|
||||||
|
P/D workers. This needs to serializeable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class KVConnectorMetadata(ABC): # noqa: B024
|
class KVConnectorMetadata(ABC): # noqa: B024
|
||||||
"""
|
"""
|
||||||
Abstract Metadata used to communicate between the
|
Abstract Metadata used to communicate between the
|
||||||
@@ -320,6 +329,18 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||||
|
"""
|
||||||
|
Get the KVConnector handshake metadata for this connector.
|
||||||
|
This metadata is used for out-of-band connector handshake
|
||||||
|
between P/D workers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KVConnectorHandshakeMetadata: the handshake metadata.
|
||||||
|
None if no handshake metadata is available.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Scheduler-side methods
|
# Scheduler-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
@@ -477,6 +498,17 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def set_xfer_handshake_metadata(
|
||||||
|
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set the KV connector handshake metadata for this connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (KVConnectorHandshakeMetadata): the handshake metadata to set.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_prom_metrics(
|
def build_prom_metrics(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
CopyBlocksOp,
|
CopyBlocksOp,
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
|
KVConnectorHandshakeMetadata,
|
||||||
KVConnectorMetadata,
|
KVConnectorMetadata,
|
||||||
KVConnectorRole,
|
KVConnectorRole,
|
||||||
)
|
)
|
||||||
@@ -93,15 +94,12 @@ _NIXL_SUPPORTED_DEVICE = {
|
|||||||
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
|
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
|
||||||
|
|
||||||
|
|
||||||
class NixlAgentMetadata(
|
@dataclass
|
||||||
msgspec.Struct,
|
class NixlAgentMetadata(KVConnectorHandshakeMetadata):
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
|
||||||
# required for @cached_property.
|
|
||||||
dict=True,
|
|
||||||
):
|
|
||||||
engine_id: str
|
engine_id: str
|
||||||
agent_metadata: bytes
|
agent_metadata: bytes
|
||||||
kv_caches_base_addr: list[int]
|
kv_caches_base_addr: list[int]
|
||||||
|
device_id: int
|
||||||
num_blocks: int
|
num_blocks: int
|
||||||
block_lens: list[int]
|
block_lens: list[int]
|
||||||
attn_backend_name: str
|
attn_backend_name: str
|
||||||
@@ -223,6 +221,18 @@ class NixlConnector(KVConnectorBase_V1):
|
|||||||
assert self.connector_scheduler is not None
|
assert self.connector_scheduler is not None
|
||||||
return self.connector_scheduler.request_finished(request, block_ids)
|
return self.connector_scheduler.request_finished(request, block_ids)
|
||||||
|
|
||||||
|
def set_xfer_handshake_metadata(
|
||||||
|
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set the KV connector handshake metadata for this connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (dict): the handshake metadata to set.
|
||||||
|
"""
|
||||||
|
assert self.connector_scheduler is not None
|
||||||
|
self.connector_scheduler.set_xfer_handshake_metadata(metadata)
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
# Worker Side Methods
|
# Worker Side Methods
|
||||||
############################################################
|
############################################################
|
||||||
@@ -299,6 +309,21 @@ class NixlConnector(KVConnectorBase_V1):
|
|||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
if self.connector_worker is not None:
|
if self.connector_worker is not None:
|
||||||
self.connector_worker.shutdown()
|
self.connector_worker.shutdown()
|
||||||
|
if self.connector_scheduler is not None:
|
||||||
|
self.connector_scheduler.shutdown()
|
||||||
|
|
||||||
|
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||||
|
"""
|
||||||
|
Get the KVConnector handshake metadata for this connector.
|
||||||
|
This metadata is used for out-of-band connector handshake
|
||||||
|
between P/D workers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KVConnectorHandshakeMetadata: the handshake metadata.
|
||||||
|
None if no handshake metadata is available.
|
||||||
|
"""
|
||||||
|
assert self.connector_worker is not None
|
||||||
|
return self.connector_worker.xfer_handshake_metadata
|
||||||
|
|
||||||
|
|
||||||
class NixlConnectorScheduler:
|
class NixlConnectorScheduler:
|
||||||
@@ -312,12 +337,16 @@ class NixlConnectorScheduler:
|
|||||||
self.side_channel_port = (
|
self.side_channel_port = (
|
||||||
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
|
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
|
||||||
+ vllm_config.parallel_config.data_parallel_rank
|
+ vllm_config.parallel_config.data_parallel_rank
|
||||||
* vllm_config.parallel_config.tensor_parallel_size
|
|
||||||
)
|
)
|
||||||
assert vllm_config.kv_transfer_config is not None
|
assert vllm_config.kv_transfer_config is not None
|
||||||
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
|
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
|
||||||
logger.info("Initializing NIXL Scheduler %s", engine_id)
|
logger.info("Initializing NIXL Scheduler %s", engine_id)
|
||||||
|
|
||||||
|
# Background thread for handling new handshake requests.
|
||||||
|
self._nixl_handshake_listener_t: threading.Thread | None = None
|
||||||
|
self._encoded_xfer_handshake_metadata: dict[int, Any] = {}
|
||||||
|
self._stop_event = threading.Event()
|
||||||
|
|
||||||
# Requests that need to start recv/send.
|
# Requests that need to start recv/send.
|
||||||
# New requests are added by update_state_after_alloc in
|
# New requests are added by update_state_after_alloc in
|
||||||
# the scheduler. Used to make metadata passed to Worker.
|
# the scheduler. Used to make metadata passed to Worker.
|
||||||
@@ -330,6 +359,89 @@ class NixlConnectorScheduler:
|
|||||||
# remote prefill or aborted.
|
# remote prefill or aborted.
|
||||||
self._reqs_not_processed: set[ReqId] = set()
|
self._reqs_not_processed: set[ReqId] = set()
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
self._stop_event.set()
|
||||||
|
if self._nixl_handshake_listener_t is not None:
|
||||||
|
self._nixl_handshake_listener_t.join()
|
||||||
|
self._nixl_handshake_listener_t = None
|
||||||
|
|
||||||
|
def set_xfer_handshake_metadata(
|
||||||
|
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Set the KV connector handshake metadata for this connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata (dict): the handshake metadata to set.
|
||||||
|
"""
|
||||||
|
encoded_data: dict[int, bytes] = {}
|
||||||
|
encoder = msgspec.msgpack.Encoder()
|
||||||
|
for tp_rank, rank_metadata in metadata.items():
|
||||||
|
if not isinstance(rank_metadata, NixlAgentMetadata):
|
||||||
|
raise ValueError(
|
||||||
|
"NixlConnectorScheduler expects NixlAgentMetadata for "
|
||||||
|
"handshake metadata."
|
||||||
|
)
|
||||||
|
encoded_data[tp_rank] = encoder.encode(rank_metadata)
|
||||||
|
logger.debug(
|
||||||
|
"Tp rank %d: encoded NixlAgentMetadata size: %s bytes",
|
||||||
|
tp_rank,
|
||||||
|
str(len(encoded_data[tp_rank])),
|
||||||
|
)
|
||||||
|
self._encoded_xfer_handshake_metadata = encoded_data
|
||||||
|
|
||||||
|
# Only start the listener when we have metadata to serve.
|
||||||
|
if self._nixl_handshake_listener_t is None:
|
||||||
|
ready_event = threading.Event()
|
||||||
|
self._nixl_handshake_listener_t = threading.Thread(
|
||||||
|
target=self._nixl_handshake_listener,
|
||||||
|
args=(
|
||||||
|
encoded_data,
|
||||||
|
ready_event,
|
||||||
|
self._stop_event,
|
||||||
|
self.side_channel_port,
|
||||||
|
),
|
||||||
|
daemon=True,
|
||||||
|
name="nixl_handshake_listener",
|
||||||
|
)
|
||||||
|
self._nixl_handshake_listener_t.start()
|
||||||
|
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _nixl_handshake_listener(
|
||||||
|
encoded_data: dict[int, Any],
|
||||||
|
ready_event: threading.Event,
|
||||||
|
stop_event: threading.Event,
|
||||||
|
port: int,
|
||||||
|
):
|
||||||
|
"""Background thread for getting new NIXL handshakes."""
|
||||||
|
# NOTE(rob): this is a simple implementation. We will move
|
||||||
|
# to a better approach via HTTP endpoint soon.
|
||||||
|
|
||||||
|
# Listen for new requests for metadata.
|
||||||
|
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
||||||
|
path = make_zmq_path("tcp", host, port)
|
||||||
|
logger.debug("Starting listening on path: %s", path)
|
||||||
|
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||||
|
sock.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
|
ready_event.set()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
identity, _, msg = sock.recv_multipart()
|
||||||
|
except zmq.Again:
|
||||||
|
if stop_event.is_set():
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
# Decode the message which contains (GET_META_MSG, rank)
|
||||||
|
msg, target_tp_rank = msgspec.msgpack.decode(msg)
|
||||||
|
logger.debug(
|
||||||
|
"Received message for tp rank %s",
|
||||||
|
target_tp_rank,
|
||||||
|
)
|
||||||
|
if msg != GET_META_MSG:
|
||||||
|
logger.warning("Connection listener got unexpected message %s", msg)
|
||||||
|
sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
def get_num_new_matched_tokens(
|
||||||
self, request: "Request", num_computed_tokens: int
|
self, request: "Request", num_computed_tokens: int
|
||||||
) -> tuple[int, bool]:
|
) -> tuple[int, bool]:
|
||||||
@@ -537,8 +649,6 @@ class NixlConnectorScheduler:
|
|||||||
class NixlConnectorWorker:
|
class NixlConnectorWorker:
|
||||||
"""Implementation of Worker side methods"""
|
"""Implementation of Worker side methods"""
|
||||||
|
|
||||||
_POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TpKVTopology:
|
class TpKVTopology:
|
||||||
"""
|
"""
|
||||||
@@ -651,16 +761,6 @@ class NixlConnectorWorker:
|
|||||||
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
||||||
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
|
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
|
||||||
|
|
||||||
# NIXL handshake port.
|
|
||||||
# NOTE(rob): Within a DP group, each DP rank gets its own
|
|
||||||
# base port (which is sent in the KVTransferParams).
|
|
||||||
# Each TP rank listens/queries on the base_port + tp_rank.
|
|
||||||
self.side_channel_port: int = (
|
|
||||||
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
|
|
||||||
+ vllm_config.parallel_config.data_parallel_rank
|
|
||||||
* vllm_config.parallel_config.tensor_parallel_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Metadata.
|
# Metadata.
|
||||||
self.engine_id: EngineId = engine_id
|
self.engine_id: EngineId = engine_id
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
@@ -706,6 +806,7 @@ class NixlConnectorWorker:
|
|||||||
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
||||||
# rank will still only pull from a single remote TP worker.
|
# rank will still only pull from a single remote TP worker.
|
||||||
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
|
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
|
||||||
|
self.device_id: int = 0
|
||||||
|
|
||||||
# Number of NIXL regions. Currently one region per cache
|
# Number of NIXL regions. Currently one region per cache
|
||||||
# (so 1 per layer for MLA, otherwise 2 per layer)
|
# (so 1 per layer for MLA, otherwise 2 per layer)
|
||||||
@@ -736,9 +837,8 @@ class NixlConnectorWorker:
|
|||||||
# requests that skipped transfer (handshake or transfer failures)
|
# requests that skipped transfer (handshake or transfer failures)
|
||||||
self._failed_recv_reqs: set[ReqId] = set()
|
self._failed_recv_reqs: set[ReqId] = set()
|
||||||
|
|
||||||
# Background thread for handling new handshake requests.
|
# Handshake metadata of this worker for NIXL transfers.
|
||||||
self._nixl_handshake_listener_t: threading.Thread | None = None
|
self.xfer_handshake_metadata: NixlAgentMetadata | None = None
|
||||||
self._nixl_handshake_listener_stop_event: threading.Event | None = None
|
|
||||||
# Background thread for initializing new NIXL handshakes.
|
# Background thread for initializing new NIXL handshakes.
|
||||||
self._handshake_initiation_executor = ThreadPoolExecutor(
|
self._handshake_initiation_executor = ThreadPoolExecutor(
|
||||||
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
|
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
|
||||||
@@ -790,42 +890,6 @@ class NixlConnectorWorker:
|
|||||||
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _nixl_handshake_listener(
|
|
||||||
metadata: NixlAgentMetadata,
|
|
||||||
ready_event: threading.Event,
|
|
||||||
stop_event: threading.Event,
|
|
||||||
base_port: int,
|
|
||||||
tp_rank: int,
|
|
||||||
):
|
|
||||||
"""Background thread for getting new NIXL handshakes."""
|
|
||||||
# NOTE(rob): this is a simple implementation. We will move
|
|
||||||
# to a better approach via HTTP endpoint soon.
|
|
||||||
|
|
||||||
encoder = msgspec.msgpack.Encoder()
|
|
||||||
encoded_data = encoder.encode(metadata)
|
|
||||||
size_in_bytes = len(encoded_data)
|
|
||||||
logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes))
|
|
||||||
|
|
||||||
# Listen for new requests for metadata.
|
|
||||||
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
|
|
||||||
path = make_zmq_path("tcp", host, base_port + tp_rank)
|
|
||||||
logger.debug("Starting listening on path: %s", path)
|
|
||||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
|
||||||
ready_event.set()
|
|
||||||
poller = zmq.Poller()
|
|
||||||
poller.register(sock, zmq.POLLIN)
|
|
||||||
while not stop_event.is_set():
|
|
||||||
events = dict(
|
|
||||||
poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000)
|
|
||||||
)
|
|
||||||
if sock not in events:
|
|
||||||
continue
|
|
||||||
identity, _, msg = sock.recv_multipart()
|
|
||||||
if msg != GET_META_MSG:
|
|
||||||
logger.warning("Connection listener got unexpected message %s", msg)
|
|
||||||
sock.send_multipart((identity, b"", encoded_data))
|
|
||||||
|
|
||||||
def _nixl_handshake(
|
def _nixl_handshake(
|
||||||
self,
|
self,
|
||||||
host: str,
|
host: str,
|
||||||
@@ -844,16 +908,17 @@ class NixlConnectorWorker:
|
|||||||
# Handshake only with the remote TP rank that current local rank will
|
# Handshake only with the remote TP rank that current local rank will
|
||||||
# pull from. With homogeneous TP it happens to be the same rank_i.
|
# pull from. With homogeneous TP it happens to be the same rank_i.
|
||||||
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
|
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
|
||||||
path = make_zmq_path("tcp", host, port + p_remote_rank)
|
path = make_zmq_path("tcp", host, port)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
|
"Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send query for the request.
|
# Send query for the request.
|
||||||
with zmq_ctx(zmq.REQ, path) as sock:
|
with zmq_ctx(zmq.REQ, path) as sock:
|
||||||
|
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
|
||||||
# Set receive timeout to 5 seconds to avoid hanging on dead server
|
# Set receive timeout to 5 seconds to avoid hanging on dead server
|
||||||
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
|
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
|
||||||
sock.send(GET_META_MSG)
|
sock.send(msg)
|
||||||
metadata_bytes = sock.recv()
|
metadata_bytes = sock.recv()
|
||||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||||
metadata = decoder.decode(metadata_bytes)
|
metadata = decoder.decode(metadata_bytes)
|
||||||
@@ -1042,6 +1107,10 @@ class NixlConnectorWorker:
|
|||||||
assert tensor_size_bytes == curr_tensor_size_bytes, (
|
assert tensor_size_bytes == curr_tensor_size_bytes, (
|
||||||
"All kv cache tensors must have the same size"
|
"All kv cache tensors must have the same size"
|
||||||
)
|
)
|
||||||
|
# Need to make sure the device ID is non-negative for NIXL,
|
||||||
|
# Torch uses -1 to indicate CPU tensors while NIXL uses explicit
|
||||||
|
# memory type.
|
||||||
|
self.device_id = max(cache.get_device(), 0)
|
||||||
caches_data.append(
|
caches_data.append(
|
||||||
(base_addr, curr_tensor_size_bytes, self.device_id, "")
|
(base_addr, curr_tensor_size_bytes, self.device_id, "")
|
||||||
)
|
)
|
||||||
@@ -1139,10 +1208,11 @@ class NixlConnectorWorker:
|
|||||||
assert len(self.block_window_per_layer) == self.num_layers
|
assert len(self.block_window_per_layer) == self.num_layers
|
||||||
|
|
||||||
# After KV Caches registered, listen for new connections.
|
# After KV Caches registered, listen for new connections.
|
||||||
metadata = NixlAgentMetadata(
|
self.xfer_handshake_metadata = NixlAgentMetadata(
|
||||||
engine_id=self.engine_id,
|
engine_id=self.engine_id,
|
||||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||||
|
device_id=self.device_id,
|
||||||
num_blocks=self.num_blocks,
|
num_blocks=self.num_blocks,
|
||||||
block_lens=self.block_len_per_layer,
|
block_lens=self.block_len_per_layer,
|
||||||
attn_backend_name=self.backend_name,
|
attn_backend_name=self.backend_name,
|
||||||
@@ -1150,22 +1220,6 @@ class NixlConnectorWorker:
|
|||||||
if not self.use_host_buffer
|
if not self.use_host_buffer
|
||||||
else self.host_buffer_kv_cache_layout,
|
else self.host_buffer_kv_cache_layout,
|
||||||
)
|
)
|
||||||
ready_event, stop_event = threading.Event(), threading.Event()
|
|
||||||
self._nixl_handshake_listener_t = threading.Thread(
|
|
||||||
target=self._nixl_handshake_listener,
|
|
||||||
args=(
|
|
||||||
metadata,
|
|
||||||
ready_event,
|
|
||||||
stop_event,
|
|
||||||
self.side_channel_port,
|
|
||||||
self.tp_rank,
|
|
||||||
),
|
|
||||||
daemon=True,
|
|
||||||
name="nixl_handshake_listener",
|
|
||||||
)
|
|
||||||
self._nixl_handshake_listener_t.start()
|
|
||||||
self._nixl_handshake_listener_stop_event = stop_event
|
|
||||||
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
|
||||||
|
|
||||||
def add_remote_agent(
|
def add_remote_agent(
|
||||||
self,
|
self,
|
||||||
@@ -1267,7 +1321,7 @@ class NixlConnectorWorker:
|
|||||||
# self.block_len == remote_block_len//tp_ratio bytes.
|
# self.block_len == remote_block_len//tp_ratio bytes.
|
||||||
addr = base_addr + block_offset + rank_offset
|
addr = base_addr + block_offset + rank_offset
|
||||||
# (addr, len, device id)
|
# (addr, len, device id)
|
||||||
blocks_data.append((addr, kv_block_len, remote_tp_rank))
|
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
|
||||||
|
|
||||||
if self._use_flashinfer:
|
if self._use_flashinfer:
|
||||||
# With FlashInfer index V separately to allow head splitting.
|
# With FlashInfer index V separately to allow head splitting.
|
||||||
@@ -1275,7 +1329,9 @@ class NixlConnectorWorker:
|
|||||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||||
addr = base_addr + block_offset + rank_offset
|
addr = base_addr + block_offset + rank_offset
|
||||||
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
|
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
|
||||||
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
|
blocks_data.append(
|
||||||
|
(v_addr, kv_block_len, nixl_agent_meta.device_id)
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
|
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
|
||||||
@@ -1843,14 +1899,6 @@ class NixlConnectorWorker:
|
|||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Shutdown the connector worker."""
|
"""Shutdown the connector worker."""
|
||||||
self._handshake_initiation_executor.shutdown(wait=False)
|
self._handshake_initiation_executor.shutdown(wait=False)
|
||||||
if self._nixl_handshake_listener_stop_event is not None:
|
|
||||||
self._nixl_handshake_listener_stop_event.set()
|
|
||||||
self._nixl_handshake_listener_stop_event = None
|
|
||||||
if self._nixl_handshake_listener_t is not None:
|
|
||||||
# Generous timeout to allow the thread to exit
|
|
||||||
self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10)
|
|
||||||
assert not self._nixl_handshake_listener_t.is_alive()
|
|
||||||
self._nixl_handshake_listener_t = None
|
|
||||||
for handles in self._recving_transfers.values():
|
for handles in self._recving_transfers.values():
|
||||||
for handle, _ in handles:
|
for handle, _ in handles:
|
||||||
self.nixl_wrapper.release_xfer_handle(handle)
|
self.nixl_wrapper.release_xfer_handle(handle)
|
||||||
|
|||||||
@@ -163,6 +163,27 @@ class EngineCore:
|
|||||||
vllm_config, mm_registry
|
vllm_config, mm_registry
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If a KV connector is initialized for scheduler, we want to collect
|
||||||
|
# handshake metadata from all workers so the connector in the scheduler
|
||||||
|
# will have the full context
|
||||||
|
kv_connector = self.scheduler.get_kv_connector()
|
||||||
|
if kv_connector is not None:
|
||||||
|
# Collect and store KV connector xfer metadata from workers
|
||||||
|
# (after KV cache registration)
|
||||||
|
xfer_handshake_metadata = (
|
||||||
|
self.model_executor.get_kv_connector_handshake_metadata()
|
||||||
|
)
|
||||||
|
|
||||||
|
if xfer_handshake_metadata:
|
||||||
|
# xfer_handshake_metadata is list of dicts from workers
|
||||||
|
# Each dict already has structure {tp_rank: metadata}
|
||||||
|
# Merge all worker dicts into a single dict
|
||||||
|
content: dict[int, Any] = {}
|
||||||
|
for worker_dict in xfer_handshake_metadata:
|
||||||
|
if worker_dict is not None:
|
||||||
|
content.update(worker_dict)
|
||||||
|
kv_connector.set_xfer_handshake_metadata(content)
|
||||||
|
|
||||||
# Setup batch queue for pipeline parallelism.
|
# Setup batch queue for pipeline parallelism.
|
||||||
# Batch queue for scheduled batches. This enables us to asynchronously
|
# Batch queue for scheduled batches. This enables us to asynchronously
|
||||||
# schedule and execute batches, and is required by pipeline parallelism
|
# schedule and execute batches, and is required by pipeline parallelism
|
||||||
@@ -178,7 +199,7 @@ class EngineCore:
|
|||||||
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
|
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
|
||||||
if (
|
if (
|
||||||
self.vllm_config.cache_config.enable_prefix_caching
|
self.vllm_config.cache_config.enable_prefix_caching
|
||||||
or self.scheduler.get_kv_connector() is not None
|
or kv_connector is not None
|
||||||
):
|
):
|
||||||
caching_hash_fn = get_hash_fn_by_name(
|
caching_hash_fn = get_hash_fn_by_name(
|
||||||
vllm_config.cache_config.prefix_caching_hash_algo
|
vllm_config.cache_config.prefix_caching_hash_algo
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ from typing import TYPE_CHECKING, Literal, TypeVar, overload
|
|||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
|
KVConnectorHandshakeMetadata,
|
||||||
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
@@ -177,6 +180,11 @@ class Executor(ABC):
|
|||||||
):
|
):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_kv_connector_handshake_metadata(
|
||||||
|
self,
|
||||||
|
) -> list[dict[int, KVConnectorHandshakeMetadata]]:
|
||||||
|
return self.collective_rpc("get_kv_connector_handshake_metadata")
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -19,7 +19,11 @@ from vllm.distributed import (
|
|||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
set_custom_all_reduce,
|
set_custom_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
from vllm.distributed.kv_transfer import (
|
||||||
|
ensure_kv_transfer_initialized,
|
||||||
|
get_kv_transfer_group,
|
||||||
|
has_kv_transfer_group,
|
||||||
|
)
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
@@ -348,6 +352,21 @@ class Worker(WorkerBase):
|
|||||||
|
|
||||||
return int(self.available_kv_cache_memory_bytes)
|
return int(self.available_kv_cache_memory_bytes)
|
||||||
|
|
||||||
|
def get_kv_connector_handshake_metadata(self) -> dict | None:
|
||||||
|
"""Get KV connector metadata from this worker if available."""
|
||||||
|
|
||||||
|
if not has_kv_transfer_group():
|
||||||
|
return None
|
||||||
|
|
||||||
|
connector = get_kv_transfer_group()
|
||||||
|
# Return None for connectors that don't need to exchange handshake
|
||||||
|
# metadata across workers.
|
||||||
|
if (metadata := connector.get_handshake_metadata()) is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tp_rank = get_tp_group().rank_in_group
|
||||||
|
return {tp_rank: metadata}
|
||||||
|
|
||||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||||
return self.model_runner.get_kv_cache_spec()
|
return self.model_runner.get_kv_cache_spec()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user