[NIXL][Bugfix] Failure logging overhaul + early metadata free on failure (#32031)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -1705,6 +1705,8 @@ class FailingNixlWrapper(FakeNixlWrapper):
|
||||
self.fail_handshake = False
|
||||
self.fail_transfer_setup = False
|
||||
self.fail_send_notif = False
|
||||
self.fail_transfer_state = False # Returns "ERR" state
|
||||
self.fail_transfer_exception = False # Raises exception in check_xfer_state
|
||||
|
||||
def add_remote_agent(self, agent_metadata: bytes) -> str:
|
||||
if self.fail_handshake:
|
||||
@@ -1739,6 +1741,150 @@ class FailingNixlWrapper(FakeNixlWrapper):
|
||||
raise RuntimeError("Simulated send_notif failure")
|
||||
return super().send_notif(agent_name, notif_msg)
|
||||
|
||||
def check_xfer_state(self, handle: int) -> str:
|
||||
if self.fail_transfer_exception:
|
||||
raise RuntimeError("Simulated check_xfer_state exception")
|
||||
if self.fail_transfer_state:
|
||||
return "ERR" # Bad transfer state
|
||||
return super().check_xfer_state(handle)
|
||||
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FailingNixlWrapper,
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"failure_type,wrapper_config,needs_get_finished",
|
||||
[
|
||||
("transfer_setup_failed", {"fail_transfer_setup": True}, False),
|
||||
("handshake_failed", {"fail_handshake": True}, False),
|
||||
("notification_failed", {"fail_send_notif": True}, False),
|
||||
("transfer_failed", {"fail_transfer_state": True}, True),
|
||||
("transfer_exception", {"fail_transfer_exception": True}, True),
|
||||
],
|
||||
)
|
||||
def test_transfer_failure_logging(
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
failure_type,
|
||||
wrapper_config,
|
||||
needs_get_finished,
|
||||
):
|
||||
"""Test that transfer failures are logged with structured context.
|
||||
|
||||
Run with `pytest -sv` to see the log output.
|
||||
|
||||
Covers failure types:
|
||||
- transfer_setup_failed: make_prepped_xfer fails
|
||||
- handshake_failed: add_remote_agent fails during request handshake
|
||||
- notification_failed: send_notif fails
|
||||
- transfer_failed: check_xfer_state returns bad state (e.g., "ERR")
|
||||
- transfer_exception: check_xfer_state raises exception
|
||||
"""
|
||||
import logging
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0.0
|
||||
)
|
||||
|
||||
# Configure FailingNixlWrapper to fail in the specified way
|
||||
for key, value in wrapper_config.items():
|
||||
setattr(connector.connector_worker.nixl_wrapper, key, value)
|
||||
|
||||
request_id = f"test_{failure_type}_req"
|
||||
|
||||
# For notification_failed, we need empty local blocks
|
||||
# (full cache hit path to trigger send_notif)
|
||||
local_blocks = [] if failure_type == "notification_failed" else [10, 11, 12]
|
||||
remote_blocks = [20, 21, 22]
|
||||
|
||||
metadata = NixlConnectorMetadata()
|
||||
metadata.add_new_req_to_recv(
|
||||
request_id=request_id,
|
||||
local_block_ids=local_blocks,
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": remote_blocks,
|
||||
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
"remote_request_id": f"prefill-{request_id}",
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 1234,
|
||||
"remote_tp_size": 1,
|
||||
},
|
||||
)
|
||||
connector.bind_connector_metadata(metadata)
|
||||
|
||||
dummy_ctx = ForwardContext(
|
||||
no_compile_layers={},
|
||||
attn_metadata={},
|
||||
virtual_engine=0,
|
||||
)
|
||||
|
||||
# Capture logs from the nixl_connector logger specifically
|
||||
# vLLM loggers have propagate=False, so we need to capture directly
|
||||
nixl_logger = logging.getLogger(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
|
||||
)
|
||||
captured_logs: list[logging.LogRecord] = []
|
||||
|
||||
class LogCapture(logging.Handler):
|
||||
def emit(self, record):
|
||||
captured_logs.append(record)
|
||||
|
||||
handler = LogCapture()
|
||||
handler.setLevel(logging.ERROR)
|
||||
nixl_logger.addHandler(handler)
|
||||
|
||||
try:
|
||||
connector.start_load_kv(dummy_ctx)
|
||||
# Process the ready_requests queue (for async handshake)
|
||||
connector.bind_connector_metadata(NixlConnectorMetadata())
|
||||
# Wait for async handshake to complete
|
||||
time.sleep(0.2)
|
||||
connector.start_load_kv(dummy_ctx)
|
||||
|
||||
# For transfer_failed/transfer_exception, the error happens in
|
||||
# get_finished() when checking transfer state
|
||||
if needs_get_finished:
|
||||
connector.get_finished(finished_req_ids=set())
|
||||
finally:
|
||||
nixl_logger.removeHandler(handler)
|
||||
|
||||
# Print logs for manual comparison between commits
|
||||
error_logs = [r for r in captured_logs if r.levelno >= logging.ERROR]
|
||||
print("\n" + "=" * 60)
|
||||
print(f"CAPTURED ERROR LOGS for {failure_type}:")
|
||||
print("=" * 60)
|
||||
for i, record in enumerate(error_logs):
|
||||
print(f"\n--- Log {i + 1} ---")
|
||||
print(f"Message: {record.message}")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
assert len(error_logs) >= 1, f"Expected at least one error log for {failure_type}"
|
||||
|
||||
# Verify structured logging output (new format)
|
||||
# Check that at least one log matches the expected format
|
||||
all_messages = [r.message for r in error_logs]
|
||||
combined_logs = "\n".join(all_messages)
|
||||
|
||||
assert any("NIXL transfer failure" in msg for msg in all_messages), (
|
||||
f"Expected structured log format with 'NIXL transfer failure' prefix "
|
||||
f"for {failure_type}. Got: {all_messages}"
|
||||
)
|
||||
assert any("failure_type" in msg for msg in all_messages), (
|
||||
f"Expected 'failure_type' in logs. Got: {all_messages}"
|
||||
)
|
||||
assert any("Context:" in msg for msg in all_messages), (
|
||||
f"Expected 'Context:' in logs. Got: {all_messages}"
|
||||
)
|
||||
# Check that the expected failure_type appears in at least one log
|
||||
# Note: handshake_failed also triggers handshake_setup_failed
|
||||
assert failure_type in combined_logs or (
|
||||
failure_type == "handshake_failed" and "handshake_setup_failed" in combined_logs
|
||||
), f"Expected '{failure_type}' in logs. Got: {all_messages}"
|
||||
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
|
||||
@@ -1152,6 +1152,50 @@ class NixlConnectorWorker:
|
||||
assert self.use_host_buffer
|
||||
self.copy_blocks = copy_operation
|
||||
|
||||
def _log_failure(
|
||||
self,
|
||||
failure_type: str,
|
||||
req_id: str | None,
|
||||
msg: str = "",
|
||||
error: Exception | None = None,
|
||||
meta: ReqMeta | None = None,
|
||||
**extra_context,
|
||||
):
|
||||
"""Log transfer failure with structured context for easier debugging."""
|
||||
context: dict[str, Any] = {
|
||||
"failure_type": failure_type,
|
||||
"request_id": req_id,
|
||||
"engine_id": self.engine_id,
|
||||
}
|
||||
if meta is None and req_id is not None:
|
||||
# Try to get metadata from in progress transfers when not provided
|
||||
meta = self._recving_metadata.get(req_id)
|
||||
|
||||
if meta and meta.remote:
|
||||
context.update(
|
||||
{
|
||||
"remote_engine_id": meta.remote.engine_id,
|
||||
"remote_request_id": meta.remote.request_id,
|
||||
"remote_host": meta.remote.host,
|
||||
"remote_port": meta.remote.port,
|
||||
"num_local_blocks": len(meta.local_block_ids),
|
||||
"num_remote_blocks": len(meta.remote.block_ids),
|
||||
"local_block_ids_sample": meta.local_block_ids[:10],
|
||||
}
|
||||
)
|
||||
|
||||
context.update(extra_context)
|
||||
if msg:
|
||||
failure_type = f"{failure_type}. {msg}"
|
||||
|
||||
logger.error(
|
||||
"NIXL transfer failure: %s | Context: %s",
|
||||
failure_type,
|
||||
context,
|
||||
exc_info=error is not None,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def _background_nixl_handshake(
|
||||
self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta
|
||||
):
|
||||
@@ -1173,8 +1217,13 @@ class NixlConnectorWorker:
|
||||
del self._handshake_futures[eid]
|
||||
try:
|
||||
self._remote_agents[eid] = f.result()
|
||||
except Exception:
|
||||
logger.exception("Handshake with %s failed", eid)
|
||||
except Exception as e:
|
||||
self._log_failure(
|
||||
failure_type="handshake_setup_failed",
|
||||
req_id=None,
|
||||
error=e,
|
||||
remote_engine_id=eid,
|
||||
)
|
||||
|
||||
fut.add_done_callback(done_callback)
|
||||
|
||||
@@ -1184,10 +1233,13 @@ class NixlConnectorWorker:
|
||||
# check if handshake succeeded
|
||||
f.result()
|
||||
self._ready_requests.put(entry)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# handshake failed - mark blocks as invalid
|
||||
logger.exception(
|
||||
"Handshake failed for request %s, marking blocks as invalid", req_id
|
||||
self._log_failure(
|
||||
failure_type="handshake_failed",
|
||||
req_id=req_id,
|
||||
error=e,
|
||||
meta=meta,
|
||||
)
|
||||
if req_meta := self._recving_metadata.get(req_id):
|
||||
self._invalid_block_ids.update(req_meta.local_block_ids)
|
||||
@@ -1941,18 +1993,19 @@ class NixlConnectorWorker:
|
||||
in_progress.append(handle)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
"NIXL transfer failed for request %s with state "
|
||||
"%s. Marking blocks as invalid.",
|
||||
req_id,
|
||||
xfer_state,
|
||||
self._log_failure(
|
||||
failure_type="transfer_failed",
|
||||
msg="Marking blocks as invalid",
|
||||
req_id=req_id,
|
||||
xfer_state=xfer_state,
|
||||
)
|
||||
self._handle_failed_transfer(req_id, handle)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"NIXL transfer exception for request %s. "
|
||||
"Marking blocks as invalid.",
|
||||
req_id,
|
||||
except Exception as e:
|
||||
self._log_failure(
|
||||
failure_type="transfer_exception",
|
||||
msg="Marking blocks as invalid",
|
||||
req_id=req_id,
|
||||
error=e,
|
||||
)
|
||||
self._handle_failed_transfer(req_id, handle)
|
||||
|
||||
@@ -1973,9 +2026,9 @@ class NixlConnectorWorker:
|
||||
req_id: The request ID.
|
||||
handle: The transfer handle.
|
||||
"""
|
||||
if meta := self._recving_metadata.pop(req_id, None):
|
||||
# Use .get() here as the metadata cleanup is handled by get_finished()
|
||||
if meta := self._recving_metadata.get(req_id):
|
||||
self._invalid_block_ids.update(meta.local_block_ids)
|
||||
self._recving_metadata.pop(req_id, None)
|
||||
self.nixl_wrapper.release_xfer_handle(handle)
|
||||
self.xfer_stats.record_failed_transfer()
|
||||
|
||||
@@ -2150,12 +2203,16 @@ class NixlConnectorWorker:
|
||||
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
||||
try:
|
||||
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"NIXL send_notif failed for request %s: "
|
||||
"P worker blocks will be freed after timeout. "
|
||||
except Exception as e:
|
||||
self._log_failure(
|
||||
failure_type="notification_failed",
|
||||
msg="P worker blocks will be freed after timeout. "
|
||||
"This may indicate network issues.",
|
||||
request_id,
|
||||
req_id=request_id,
|
||||
error=e,
|
||||
dst_engine_id=dst_engine_id,
|
||||
remote_rank=remote_rank,
|
||||
remote_agent_name=agent_name,
|
||||
)
|
||||
self.xfer_stats.record_failed_notification()
|
||||
return
|
||||
@@ -2240,13 +2297,16 @@ class NixlConnectorWorker:
|
||||
|
||||
# Use handle to check completion in future step().
|
||||
self._recving_transfers[request_id].append(handle)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"NIXL transfer setup/initiation failed for request %s. "
|
||||
"Marking blocks as invalid.",
|
||||
request_id,
|
||||
)
|
||||
except Exception as e:
|
||||
# mark all (logical) blocks for this request as invalid
|
||||
self._log_failure(
|
||||
failure_type="transfer_setup_failed",
|
||||
req_id=request_id,
|
||||
msg="Marking blocks as invalid",
|
||||
error=e,
|
||||
dst_engine_id=dst_engine_id,
|
||||
remote_rank=remote_rank,
|
||||
)
|
||||
if meta := self._recving_metadata.get(request_id):
|
||||
self._invalid_block_ids.update(meta.local_block_ids)
|
||||
self.xfer_stats.record_failed_transfer()
|
||||
|
||||
Reference in New Issue
Block a user