diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 0a73e2a78..71f5d4b2b 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -190,7 +190,6 @@ def _make_fake_nixl_pkg():
 # Copy of FakeNixlWrapper implementation for Ray workers
 import uuid
 from collections import defaultdict
-from typing import Optional
 
 {fake_nixl_source}
 
@@ -1143,3 +1142,145 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
     # After abort, the worker should not keep tracking it as "in-batch"
     assert req.request_id not in connector.connector_worker._reqs_to_process
     #### Model Runner end ####
+
+
+class FailingNixlWrapper(FakeNixlWrapper):
+    """Mock NixlWrapper that fails on specific operations."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fail_handshake = False
+        self.fail_transfer_setup = False
+        self.fail_send_notif = False
+
+    def add_remote_agent(self, agent_metadata: bytes) -> str:
+        if self.fail_handshake:
+            from zmq.error import Again
+
+            raise Again("Simulated timeout failure")
+        return super().add_remote_agent(agent_metadata)
+
+    def make_prepped_xfer(
+        self,
+        xfer_type: str,
+        local_xfer_side_handle: int,
+        local_block_descs_ids: list[int],
+        remote_xfer_side_handle: int,
+        remote_block_descs_ids: list[int],
+        notif_msg: bytes | None = None,
+    ) -> int:
+        if self.fail_transfer_setup:
+            # classic RuntimeError to simulate failure
+            raise RuntimeError("BAD STATUS")
+        return super().make_prepped_xfer(
+            xfer_type,
+            local_xfer_side_handle,
+            local_block_descs_ids,
+            remote_xfer_side_handle,
+            remote_block_descs_ids,
+            notif_msg,
+        )
+
+    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
+        if self.fail_send_notif:
+            raise RuntimeError("Simulated send_notif failure")
+        return super().send_notif(agent_name, notif_msg)
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FailingNixlWrapper,
+)
+def test_handshake_failure_returns_finished(dist_init):
+    """Test that handshake failures mark blocks invalid and return via get_finished."""
+    vllm_config = create_vllm_config()
+
+    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+    connector.connector_worker = FakeNixlConnectorWorker(
+        vllm_config, connector.engine_id, hand_shake_latency=0.1
+    )
+    connector.connector_worker.nixl_wrapper.fail_handshake = True
+
+    request_id = "test_handshake_fail"
+    metadata = NixlConnectorMetadata()
+    metadata.add_new_req(
+        request_id=request_id,
+        local_block_ids=[1, 2, 3],
+        kv_transfer_params={
+            "remote_block_ids": [4, 5, 6],
+            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+            "remote_host": "localhost",
+            "remote_port": 1234,
+            "remote_tp_size": 1,
+        },
+    )
+    connector.bind_connector_metadata(metadata)
+
+    dummy_ctx = ForwardContext(
+        no_compile_layers={},
+        attn_metadata={},
+        virtual_engine=0,
+    )
+    connector.start_load_kv(dummy_ctx)
+
+    # Wait for handshake to fail
+    time.sleep(0.3)
+
+    # Check that blocks were marked invalid
+    invalid_blocks = connector.get_block_ids_with_load_errors()
+    assert invalid_blocks == {1, 2, 3}
+
+    # Check that request appears in get_finished
+    _, done_recving = connector.get_finished(finished_req_ids=set())
+    assert request_id in done_recving
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FailingNixlWrapper,
+)
+def test_transfer_setup_failure_returns_finished(dist_init):
+    """Test that transfer setup failures mark blocks invalid
+    and return via get_finished."""
+    vllm_config = create_vllm_config()
+
+    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+    connector.connector_worker = FakeNixlConnectorWorker(
+        vllm_config, connector.engine_id, hand_shake_latency=0
+    )
+    connector.connector_worker.nixl_wrapper.fail_transfer_setup = True
+
+    request_id = "test_transfer_fail"
+    metadata = NixlConnectorMetadata()
+    metadata.add_new_req(
+        request_id=request_id,
+        local_block_ids=[7, 8, 9],
+        kv_transfer_params={
+            "remote_block_ids": [10, 11, 12],
+            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+            "remote_host": "localhost",
+            "remote_port": 1234,
+            "remote_tp_size": 1,
+        },
+    )
+    connector.bind_connector_metadata(metadata)
+
+    dummy_ctx = ForwardContext(
+        no_compile_layers={},
+        attn_metadata={},
+        virtual_engine=0,
+    )
+    connector.start_load_kv(dummy_ctx)
+
+    # Wait for handshake to complete and process ready_requests
+    connector.bind_connector_metadata(NixlConnectorMetadata())
+    time.sleep(0.1)
+    connector.start_load_kv(dummy_ctx)
+
+    # check that blocks were marked invalid
+    invalid_blocks = connector.get_block_ids_with_load_errors()
+    assert invalid_blocks == {7, 8, 9}
+
+    # ensure request appears in get_finished
+    _, done_recving = connector.get_finished(finished_req_ids=set())
+    assert request_id in done_recving
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index a8730bf78..c1b4e50e7 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -68,6 +68,7 @@ except ImportError:
     NixlWrapper = None
     nixlXferTelemetry = None
 
+
 try:
     from nixl._api import nixl_agent_config
 except ImportError:
@@ -234,6 +235,11 @@ class NixlConnector(KVConnectorBase_V1):
         assert self.connector_worker is not None
         return self.connector_worker.get_finished()
 
+    def get_block_ids_with_load_errors(self) -> set[int]:
+        """Get block IDs that failed to load via NIXL."""
+        assert self.connector_worker is not None
+        return self.connector_worker.get_block_ids_with_load_errors()
+
     def get_kv_connector_stats(self) -> KVConnectorStats | None:
         assert self.connector_worker is not None
         return self.connector_worker.get_kv_connector_stats()
@@ -614,6 +620,11 @@ class NixlConnectorWorker:
         # Set of requests that have been part of a batch, regardless of status.
         self._reqs_to_process: set[ReqId] = set()
 
+        # invalid blocks from failed NIXL operations
+        self._invalid_block_ids: set[int] = set()
+        # requests that skipped transfer (handshake or transfer failures)
+        self._failed_recv_reqs: set[ReqId] = set()
+
         # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: threading.Thread | None = None
         # Background thread for initializing new NIXL handshakes.
@@ -713,6 +724,8 @@ class NixlConnectorWorker:
 
         # Send query for the request.
         with zmq_ctx(zmq.REQ, path) as sock:
+            # Set receive timeout to 5 seconds to avoid hanging on dead server
+            sock.setsockopt(zmq.RCVTIMEO, 5000)  # milliseconds
             sock.send(GET_META_MSG)
             metadata_bytes = sock.recv()
             decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
@@ -795,10 +808,20 @@ class NixlConnectorWorker:
 
             fut.add_done_callback(done_callback)
 
-        # TODO: handle failure state of future in the
-        # callback, we want to fail the request in this case.
-        def request_ready(_f: Future[Any], entry=(req_id, meta)):
-            self._ready_requests.put(entry)
+        # check handshake success before proceeding with request
+        def request_ready(f: Future[Any], entry=(req_id, meta)):
+            try:
+                # check if handshake succeeded
+                f.result()
+                self._ready_requests.put(entry)
+            except Exception:
+                # handshake failed - mark blocks as invalid
+                logger.exception(
+                    "Handshake failed for request %s, marking blocks as invalid", req_id
+                )
+                if req_meta := self._recving_metadata.get(req_id):
+                    self._invalid_block_ids.update(req_meta.local_block_ids)
+                self._failed_recv_reqs.add(req_id)
 
         fut.add_done_callback(request_ready)
 
@@ -1205,6 +1228,11 @@ class NixlConnectorWorker:
         """
         done_sending = self._get_new_notifs()
         done_recving = self._pop_done_transfers(self._recving_transfers)
+
+        # add requests that skipped transfer to done_recving
+        done_recving.update(self._failed_recv_reqs)
+        self._failed_recv_reqs.clear()
+
         if len(done_sending) > 0 or len(done_recving) > 0:
             logger.debug(
                 "Rank %s, get_finished: %s requests done sending "
@@ -1214,10 +1242,10 @@ class NixlConnectorWorker:
                 len(done_recving),
             )
 
-        if self.use_host_buffer:
-            for req_id in done_recving:
-                meta = self._recving_metadata.pop(req_id)
-                assert meta, f"{req_id} not found in recving_metadata list"
+        # clean up metadata for completed requests
+        for req_id in done_recving:
+            meta = self._recving_metadata.pop(req_id, None)
+            if self.use_host_buffer and meta:
                 self.sync_recved_kv_to_device(req_id, meta)
 
         # Handle timeout to avoid stranding blocks on remote.
@@ -1296,7 +1324,19 @@ class NixlConnectorWorker:
                     in_progress = True
                     continue
                 else:
-                    raise RuntimeError("Transfer failed with state %s", xfer_state)
+                    # transfer failed - mark blocks as invalid
+                    logger.error(
+                        "NIXL transfer failed for request %s with state %s. "
+                        "Marking blocks as invalid.",
+                        req_id,
+                        xfer_state,
+                    )
+                    # mark all blocks for this request as invalid; the single
+                    # pop below also discards the failed request's metadata
+                    if meta := self._recving_metadata.pop(req_id, None):
+                        self._invalid_block_ids.update(meta.local_block_ids)
+                    self.nixl_wrapper.release_xfer_handle(handle)
+                    self.xfer_stats.record_failed_transfer()
             if not in_progress:
                 done_req_ids.add(req_id)
                 del transfers[req_id]
@@ -1317,8 +1357,8 @@ class NixlConnectorWorker:
                 len(meta.local_block_ids),
                 len(meta.remote_block_ids),
             )
-            if self.use_host_buffer:
-                self._recving_metadata[req_id] = meta
+            # always store metadata for failure recovery
+            self._recving_metadata[req_id] = meta
             if remote_engine_id not in self._remote_agents:
                 # Initiate handshake with remote engine to exchange metadata.
                 with self._handshake_lock:
@@ -1394,7 +1434,16 @@ class NixlConnectorWorker:
         if num_local_blocks == 0:
             remote_rank = self.tp_rank // tp_ratio
             agent_name = self._remote_agents[dst_engine_id][remote_rank]
-            self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
+            try:
+                self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
+            except Exception:
+                logger.exception(
+                    "NIXL send_notif failed for request %s: "
+                    "P worker blocks will be freed after timeout. "
+                    "This may indicate network issues.",
+                    request_id,
+                )
+                self.xfer_stats.record_failed_notification()
             return
 
         # Partial prefix cache hit: just read uncomputed blocks.
@@ -1456,20 +1505,35 @@ class NixlConnectorWorker:
         assert len(local_block_descs_ids) == len(remote_block_descs_ids)
 
         # Prepare transfer with Nixl.
-        handle = self.nixl_wrapper.make_prepped_xfer(
-            "READ",
-            local_xfer_side_handle,
-            local_block_descs_ids,
-            remote_xfer_side_handle,
-            remote_block_descs_ids,
-            notif_msg=notif_id,
-        )
+        handle = None
+        try:
+            handle = self.nixl_wrapper.make_prepped_xfer(
+                "READ",
+                local_xfer_side_handle,
+                local_block_descs_ids,
+                remote_xfer_side_handle,
+                remote_block_descs_ids,
+                notif_msg=notif_id,
+            )
 
-        # Begin async xfer.
-        self.nixl_wrapper.transfer(handle)
+            # Begin async xfer.
+            self.nixl_wrapper.transfer(handle)
 
-        # Use handle to check completion in future step().
-        self._recving_transfers[request_id].append((handle, time.perf_counter()))
+            # Use handle to check completion in future step().
+            self._recving_transfers[request_id].append((handle, time.perf_counter()))
+        except Exception:
+            logger.exception(
+                "NIXL transfer setup/initiation failed for request %s. "
+                "Marking blocks as invalid.",
+                request_id,
+            )
+            # mark all blocks for this request as invalid
+            if meta := self._recving_metadata.get(request_id):
+                self._invalid_block_ids.update(meta.local_block_ids)
+            self.xfer_stats.record_failed_transfer()
+            if handle is not None:
+                self.nixl_wrapper.release_xfer_handle(handle)
+            self._failed_recv_reqs.add(request_id)
 
     def _get_block_descs_ids(
         self, engine_id: str, block_ids: list[int], layer_idx: int | None = None
@@ -1527,6 +1591,17 @@ class NixlConnectorWorker:
             return self.xfer_stats.clone_and_reset()
         return None
 
+    def get_block_ids_with_load_errors(self) -> set[int]:
+        """
+        Return and clear the set of block IDs that failed to load.
+
+        This is called by the scheduler to identify blocks that need
+        to be retried after a NIXL transfer failure.
+        """
+        result = self._invalid_block_ids
+        self._invalid_block_ids = set()
+        return result
+
     def shutdown(self):
         """Shutdown the connector worker."""
         self._handshake_initiation_executor.shutdown(wait=False)
@@ -1586,6 +1661,8 @@ class NixlKVConnectorStats(KVConnectorStats):
             "post_duration": [],
             "bytes_transferred": [],
             "num_descriptors": [],
+            "num_failed_transfers": [],
+            "num_failed_notifications": [],
         }
 
     def record_transfer(self, res: nixlXferTelemetry):
@@ -1595,6 +1672,14 @@ class NixlKVConnectorStats(KVConnectorStats):
         self.data["bytes_transferred"].append(res.totalBytes)
         self.data["num_descriptors"].append(res.descCount)
 
+    def record_failed_transfer(self):
+        """Record a failed NIXL transfer operation."""
+        self.data["num_failed_transfers"].append(1.0)
+
+    def record_failed_notification(self):
+        """Record a failed NIXL notification (send_notif)."""
+        self.data["num_failed_notifications"].append(1.0)
+
     def clone_and_reset(self) -> "NixlKVConnectorStats":
         old = copy.copy(self)
         self.reset()
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 168084177..cbbdf48c6 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1487,7 +1487,7 @@ class Scheduler(SchedulerInterface):
         total_tokens_to_reschedule += num_tokens_to_reschedule
 
         # Mark requests with async KV load failures; they will be rescheduled
-        # once loading completes
+        # once loading completes.
         self.failed_recving_kv_req_ids |= async_affected_req_ids
 
         # --- Handle sync KV loads (running requests) ---