[NIXL][Bugfix] Failure logging overhaul + early metadata free on failure (#32031)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2026-01-12 21:38:49 +01:00
committed by GitHub
parent ca81811bfe
commit f8bd8394e3
2 changed files with 234 additions and 28 deletions

View File

@@ -1705,6 +1705,8 @@ class FailingNixlWrapper(FakeNixlWrapper):
self.fail_handshake = False
self.fail_transfer_setup = False
self.fail_send_notif = False
self.fail_transfer_state = False # Returns "ERR" state
self.fail_transfer_exception = False # Raises exception in check_xfer_state
def add_remote_agent(self, agent_metadata: bytes) -> str:
if self.fail_handshake:
@@ -1739,6 +1741,150 @@ class FailingNixlWrapper(FakeNixlWrapper):
raise RuntimeError("Simulated send_notif failure")
return super().send_notif(agent_name, notif_msg)
def check_xfer_state(self, handle: int) -> str:
    """Report the state of a transfer, optionally simulating a failure.

    The two failure switches on the instance are consulted in order:
    ``fail_transfer_exception`` makes the call raise, and otherwise
    ``fail_transfer_state`` makes it report the bad state ``"ERR"``.
    With neither switch set, the call is delegated to the parent class.

    Args:
        handle: Transfer handle, forwarded unchanged to the parent class.

    Returns:
        ``"ERR"`` when a bad state is being simulated, otherwise whatever
        the parent implementation returns.

    Raises:
        RuntimeError: If ``fail_transfer_exception`` is set.
    """
    # The exception switch wins when both switches are enabled.
    if self.fail_transfer_exception:
        raise RuntimeError("Simulated check_xfer_state exception")
    return "ERR" if self.fail_transfer_state else super().check_xfer_state(handle)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FailingNixlWrapper,
)
@pytest.mark.parametrize(
    "failure_type,wrapper_config,needs_get_finished",
    [
        ("transfer_setup_failed", {"fail_transfer_setup": True}, False),
        ("handshake_failed", {"fail_handshake": True}, False),
        ("notification_failed", {"fail_send_notif": True}, False),
        ("transfer_failed", {"fail_transfer_state": True}, True),
        ("transfer_exception", {"fail_transfer_exception": True}, True),
    ],
)
def test_transfer_failure_logging(
    default_vllm_config,
    dist_init,
    failure_type,
    wrapper_config,
    needs_get_finished,
):
    """Test that transfer failures are logged with structured context.

    Run with `pytest -sv` to see the log output.

    Covers failure types:
    - transfer_setup_failed: make_prepped_xfer fails
    - handshake_failed: add_remote_agent fails during request handshake
    - notification_failed: send_notif fails
    - transfer_failed: check_xfer_state returns bad state (e.g., "ERR")
    - transfer_exception: check_xfer_state raises exception

    Args:
        default_vllm_config: Fixture, requested for its setup only.
        dist_init: Fixture, requested for its setup only.
        failure_type: Identifier expected to show up in the error logs.
        wrapper_config: Attribute name -> value pairs set on the worker's
            nixl wrapper to switch on the simulated failure.
        needs_get_finished: Whether get_finished() must be called for the
            failure to surface.
    """
    import logging

    vllm_config = create_vllm_config()
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0.0
    )

    # Configure FailingNixlWrapper to fail in the specified way
    for key, value in wrapper_config.items():
        setattr(connector.connector_worker.nixl_wrapper, key, value)

    request_id = f"test_{failure_type}_req"

    # For notification_failed, we need empty local blocks
    # (full cache hit path to trigger send_notif)
    local_blocks = [] if failure_type == "notification_failed" else [10, 11, 12]
    remote_blocks = [20, 21, 22]

    metadata = NixlConnectorMetadata()
    metadata.add_new_req_to_recv(
        request_id=request_id,
        local_block_ids=local_blocks,
        kv_transfer_params={
            "remote_block_ids": remote_blocks,
            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            "remote_request_id": f"prefill-{request_id}",
            "remote_host": "localhost",
            "remote_port": 1234,
            "remote_tp_size": 1,
        },
    )
    connector.bind_connector_metadata(metadata)

    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )

    # Capture logs from the nixl_connector logger specifically
    # vLLM loggers have propagate=False, so we need to capture directly
    nixl_logger = logging.getLogger(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
    )
    captured_logs: list[logging.LogRecord] = []

    class LogCapture(logging.Handler):
        def emit(self, record):
            captured_logs.append(record)

    handler = LogCapture()
    handler.setLevel(logging.ERROR)
    nixl_logger.addHandler(handler)

    try:
        connector.start_load_kv(dummy_ctx)
        # Process the ready_requests queue (for async handshake)
        connector.bind_connector_metadata(NixlConnectorMetadata())
        # Wait for async handshake to complete
        time.sleep(0.2)
        connector.start_load_kv(dummy_ctx)

        # For transfer_failed/transfer_exception, the error happens in
        # get_finished() when checking transfer state
        if needs_get_finished:
            connector.get_finished(finished_req_ids=set())
    finally:
        nixl_logger.removeHandler(handler)

    error_logs = [r for r in captured_logs if r.levelno >= logging.ERROR]
    # Render messages with getMessage(): the `message` attribute only exists
    # once some Formatter has processed the record, and the capture handler
    # above never formats, so reading `record.message` would depend on
    # whatever other handlers happen to be installed.
    all_messages = [r.getMessage() for r in error_logs]

    # Print logs for manual comparison between commits
    print("\n" + "=" * 60)
    print(f"CAPTURED ERROR LOGS for {failure_type}:")
    print("=" * 60)
    for i, message in enumerate(all_messages):
        print(f"\n--- Log {i + 1} ---")
        print(f"Message: {message}")
    print("=" * 60 + "\n")

    assert len(error_logs) >= 1, f"Expected at least one error log for {failure_type}"

    # Verify structured logging output (new format)
    # Check that at least one log matches the expected format
    combined_logs = "\n".join(all_messages)
    assert any("NIXL transfer failure" in msg for msg in all_messages), (
        f"Expected structured log format with 'NIXL transfer failure' prefix "
        f"for {failure_type}. Got: {all_messages}"
    )
    assert any("failure_type" in msg for msg in all_messages), (
        f"Expected 'failure_type' in logs. Got: {all_messages}"
    )
    assert any("Context:" in msg for msg in all_messages), (
        f"Expected 'Context:' in logs. Got: {all_messages}"
    )

    # Check that the expected failure_type appears in at least one log
    # Note: handshake_failed also triggers handshake_setup_failed
    assert failure_type in combined_logs or (
        failure_type == "handshake_failed" and "handshake_setup_failed" in combined_logs
    ), f"Expected '{failure_type}' in logs. Got: {all_messages}"
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",

View File

@@ -1152,6 +1152,50 @@ class NixlConnectorWorker:
assert self.use_host_buffer
self.copy_blocks = copy_operation
def _log_failure(
    self,
    failure_type: str,
    req_id: str | None,
    msg: str = "",
    error: Exception | None = None,
    meta: ReqMeta | None = None,
    **extra_context,
) -> None:
    """Log transfer failure with structured context for easier debugging.

    Emits a single ERROR record shaped as
    ``NIXL transfer failure: <failure_type>[. <msg>] | Context: {...}``.

    Args:
        failure_type: Short identifier of the failure category; it is put
            in the headline and stored under ``"failure_type"`` in the
            context.
        req_id: ID of the affected request, or ``None`` when the failure is
            not tied to a single request.
        msg: Optional free-form text appended to the headline.
        error: Exception that caused the failure. When given, its traceback
            is attached to the log record.
        meta: Metadata of the affected request. When omitted and ``req_id``
            is given, it is looked up in ``self._recving_metadata``.
        **extra_context: Additional key/value pairs merged into the context
            last, so they override any base or metadata-derived key of the
            same name.
    """
    context: dict[str, Any] = {
        "failure_type": failure_type,
        "request_id": req_id,
        "engine_id": self.engine_id,
    }

    if meta is None and req_id is not None:
        # Try to get metadata from in progress transfers when not provided
        meta = self._recving_metadata.get(req_id)

    if meta and meta.remote:
        remote = meta.remote
        context.update(
            {
                "remote_engine_id": remote.engine_id,
                "remote_request_id": remote.request_id,
                "remote_host": remote.host,
                "remote_port": remote.port,
                "num_local_blocks": len(meta.local_block_ids),
                "num_remote_blocks": len(remote.block_ids),
                # Only the first few IDs, to keep the log line bounded.
                "local_block_ids_sample": meta.local_block_ids[:10],
            }
        )

    context.update(extra_context)

    # Build the headline separately instead of rebinding the parameter.
    headline = f"{failure_type}. {msg}" if msg else failure_type

    logger.error(
        "NIXL transfer failure: %s | Context: %s",
        headline,
        context,
        # Pass the exception instance itself rather than True: True makes
        # logging read sys.exc_info(), which is only populated while an
        # `except` block is active, whereas an instance always carries its
        # own traceback. None disables traceback output.
        exc_info=error,
        # Attribute the record to the caller of this helper.
        stacklevel=2,
    )
def _background_nixl_handshake(
self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta
):
@@ -1173,8 +1217,13 @@ class NixlConnectorWorker:
del self._handshake_futures[eid]
try:
self._remote_agents[eid] = f.result()
except Exception:
logger.exception("Handshake with %s failed", eid)
except Exception as e:
self._log_failure(
failure_type="handshake_setup_failed",
req_id=None,
error=e,
remote_engine_id=eid,
)
fut.add_done_callback(done_callback)
@@ -1184,10 +1233,13 @@ class NixlConnectorWorker:
# check if handshake succeeded
f.result()
self._ready_requests.put(entry)
except Exception:
except Exception as e:
# handshake failed - mark blocks as invalid
logger.exception(
"Handshake failed for request %s, marking blocks as invalid", req_id
self._log_failure(
failure_type="handshake_failed",
req_id=req_id,
error=e,
meta=meta,
)
if req_meta := self._recving_metadata.get(req_id):
self._invalid_block_ids.update(req_meta.local_block_ids)
@@ -1941,18 +1993,19 @@ class NixlConnectorWorker:
in_progress.append(handle)
continue
else:
logger.error(
"NIXL transfer failed for request %s with state "
"%s. Marking blocks as invalid.",
req_id,
xfer_state,
self._log_failure(
failure_type="transfer_failed",
msg="Marking blocks as invalid",
req_id=req_id,
xfer_state=xfer_state,
)
self._handle_failed_transfer(req_id, handle)
except Exception:
logger.exception(
"NIXL transfer exception for request %s. "
"Marking blocks as invalid.",
req_id,
except Exception as e:
self._log_failure(
failure_type="transfer_exception",
msg="Marking blocks as invalid",
req_id=req_id,
error=e,
)
self._handle_failed_transfer(req_id, handle)
@@ -1973,9 +2026,9 @@ class NixlConnectorWorker:
req_id: The request ID.
handle: The transfer handle.
"""
if meta := self._recving_metadata.pop(req_id, None):
# Use .get() here as the metadata cleanup is handled by get_finished()
if meta := self._recving_metadata.get(req_id):
self._invalid_block_ids.update(meta.local_block_ids)
self._recving_metadata.pop(req_id, None)
self.nixl_wrapper.release_xfer_handle(handle)
self.xfer_stats.record_failed_transfer()
@@ -2150,12 +2203,16 @@ class NixlConnectorWorker:
agent_name = self._remote_agents[dst_engine_id][remote_rank]
try:
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
except Exception:
logger.exception(
"NIXL send_notif failed for request %s: "
"P worker blocks will be freed after timeout. "
except Exception as e:
self._log_failure(
failure_type="notification_failed",
msg="P worker blocks will be freed after timeout. "
"This may indicate network issues.",
request_id,
req_id=request_id,
error=e,
dst_engine_id=dst_engine_id,
remote_rank=remote_rank,
remote_agent_name=agent_name,
)
self.xfer_stats.record_failed_notification()
return
@@ -2240,13 +2297,16 @@ class NixlConnectorWorker:
# Use handle to check completion in future step().
self._recving_transfers[request_id].append(handle)
except Exception:
logger.exception(
"NIXL transfer setup/initiation failed for request %s. "
"Marking blocks as invalid.",
request_id,
)
except Exception as e:
# mark all (logical) blocks for this request as invalid
self._log_failure(
failure_type="transfer_setup_failed",
req_id=request_id,
msg="Marking blocks as invalid",
error=e,
dst_engine_id=dst_engine_id,
remote_rank=remote_rank,
)
if meta := self._recving_metadata.get(request_id):
self._invalid_block_ids.update(meta.local_block_ids)
self.xfer_stats.record_failed_transfer()