Attempt to resolve #34284: "Metrics Tracing (2GPU)" fails with a segmentation fault. Signed-off-by: emricksini-h <emrick.birivoutin@hcompany.ai>
151 lines
5.3 KiB
Python
151 lines
5.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import threading
|
|
from collections.abc import Callable, Generator, Iterable
|
|
from concurrent import futures
|
|
from typing import Any, Literal
|
|
|
|
import grpc
|
|
import pytest
|
|
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
|
|
ExportTraceServiceRequest,
|
|
ExportTraceServiceResponse,
|
|
)
|
|
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
|
|
TraceServiceServicer,
|
|
add_TraceServiceServicer_to_server,
|
|
)
|
|
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
|
|
|
|
# "host:port" the fake trace server is bound to by the `trace_service`
# fixture and that `trace_server_address` hands out to tests.
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"

# Names of the `AnyValue` fields that `decode_value` knows how to decode.
FieldName = Literal[
    "bool_value",
    "string_value",
    "int_value",
    "double_value",
    "array_value",
]
|
|
|
|
|
|
def decode_value(value: AnyValue):
    """Convert an OpenTelemetry ``AnyValue`` message into a plain Python value.

    The populated field is located with ``HasField``.  Scalar fields
    (bool, string, int, double) are returned as-is; ``array_value`` is
    decoded recursively into a list.

    Raises:
        ValueError: if none of the supported fields is set on ``value``.
    """
    # Probed in the same order as the supported-field list: scalars first.
    scalar_fields: tuple[FieldName, ...] = (
        "bool_value",
        "string_value",
        "int_value",
        "double_value",
    )
    for field_name in scalar_fields:
        if value.HasField(field_name):
            return getattr(value, field_name)

    if value.HasField("array_value"):
        decoded_items = []
        for element in value.array_value.values:
            decoded_items.append(decode_value(element))
        return decoded_items

    raise ValueError(f"Couldn't decode value: {value}")
|
|
|
|
|
|
def decode_attributes(attributes: Iterable[KeyValue]) -> dict[str, Any]:
    """Build a ``{key: decoded value}`` dict from OpenTelemetry attributes.

    Each ``KeyValue`` contributes its ``key`` mapped to ``decode_value`` of
    its ``value``.  If a key occurs more than once, the last one wins.
    """
    decoded: dict[str, Any] = {}
    for attribute in attributes:
        decoded[attribute.key] = decode_value(attribute.value)
    return decoded
|
|
|
|
|
|
class FakeTraceService(TraceServiceServicer):
    """A fake gRPC trace service for testing OpenTelemetry trace exports.

    Every request passed to :meth:`Export` is appended to ``requests`` and
    ``evt`` is set.  Tests can block on ``evt.wait()`` or poll with
    :meth:`wait_for_spans`.  ``_lock`` guards ``requests`` and keeps ``evt``
    consistent with it, since exports and test reads may run on different
    threads.
    """

    def __init__(self):
        self.requests: list[ExportTraceServiceRequest] = []
        self.evt = threading.Event()
        self._lock = threading.Lock()

    def Export(self, request, context):
        """Record ``request`` and signal that at least one export arrived."""
        with self._lock:
            self.requests.append(request)
            # Set the event while still holding the lock so a concurrent
            # `clear()` cannot reset it after this request was stored.
            self.evt.set()
        return ExportTraceServiceResponse()

    @property
    def request(self) -> ExportTraceServiceRequest | None:
        """Returns the first request received (for backward compatibility)."""
        with self._lock:
            return self.requests[0] if self.requests else None

    def get_all_spans(self) -> list[dict]:
        """Returns all spans from all received requests as decoded dicts.

        Each dict has the keys ``name``, ``attributes``, ``trace_id``,
        ``span_id``, ``parent_span_id`` (``None`` when empty),
        ``start_time_unix_nano`` and ``end_time_unix_nano``.  IDs are
        rendered as hex strings.
        """
        # Snapshot under the lock, decode outside it so `Export` is not
        # blocked while attributes are being decoded.
        with self._lock:
            received = list(self.requests)

        return [
            {
                "name": span.name,
                "attributes": decode_attributes(span.attributes),
                "trace_id": span.trace_id.hex(),
                "span_id": span.span_id.hex(),
                "parent_span_id": span.parent_span_id.hex()
                if span.parent_span_id
                else None,
                "start_time_unix_nano": span.start_time_unix_nano,
                "end_time_unix_nano": span.end_time_unix_nano,
            }
            for request in received
            for resource_span in request.resource_spans
            for scope_span in resource_span.scope_spans
            for span in scope_span.spans
        ]

    def _span_count(self) -> int:
        """Count received spans without decoding their attributes."""
        with self._lock:
            return sum(
                len(scope_span.spans)
                for request in self.requests
                for resource_span in request.resource_spans
                for scope_span in resource_span.scope_spans
            )

    def wait_for_spans(self, count: int = 1, timeout: float = 10) -> bool:
        """Wait until at least `count` spans have been received.

        Polls roughly every 0.1 s.  The span count is always checked at
        least once, so a zero or negative ``timeout`` performs a single
        non-blocking check.

        Returns:
            True if ``count`` spans were seen before the timeout elapsed,
            False otherwise.
        """
        import time

        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while True:
            if self._span_count() >= count:
                return True
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            # Never sleep past the deadline.
            time.sleep(min(0.1, remaining))

    def clear(self):
        """Clear all received requests and reset the arrival event."""
        with self._lock:
            self.requests.clear()
            # Reset under the same lock as `Export` so the event always
            # reflects whether `requests` is non-empty.
            self.evt.clear()
|
|
|
|
|
|
def _wait_for_server_ready(address: str, timeout: float = 5.0) -> bool:
|
|
"""Wait for the gRPC server to be ready to accept connections."""
|
|
import socket
|
|
import time
|
|
|
|
host, port = address.rsplit(":", 1)
|
|
deadline = time.monotonic() + timeout
|
|
while time.monotonic() < deadline:
|
|
try:
|
|
with socket.create_connection((host, int(port)), timeout=0.5):
|
|
return True
|
|
except (OSError, ConnectionRefusedError):
|
|
time.sleep(0.1)
|
|
return False
|
|
|
|
|
|
@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
    """Fixture to set up a fake gRPC trace service.

    Starts an insecure gRPC server on ``FAKE_TRACE_SERVER_ADDRESS`` with a
    ``FakeTraceService`` registered, waits until it accepts connections and
    yields the service.  On teardown the server is stopped and the fixture
    waits for shutdown to finish.

    Raises:
        RuntimeError: if the port cannot be bound or the server does not
            become reachable in time.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
    service = FakeTraceService()
    add_TraceServiceServicer_to_server(service, server)

    # Some grpcio versions report a failed bind by returning 0 instead of
    # raising.  Without this check the readiness probe could succeed
    # against an unrelated process already listening on the port.
    bound_port = server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
    if bound_port == 0:
        raise RuntimeError(
            f"Fake trace server failed to start on {FAKE_TRACE_SERVER_ADDRESS}"
        )
    server.start()

    try:
        # Wait for the server to be ready to accept connections
        if not _wait_for_server_ready(FAKE_TRACE_SERVER_ADDRESS):
            raise RuntimeError(
                f"Fake trace server failed to start on {FAKE_TRACE_SERVER_ADDRESS}"
            )

        yield service
    finally:
        # `stop()` only initiates shutdown and returns an event; wait on it
        # (bounded) so the port is released before another server binds it.
        server.stop(grace=None).wait(timeout=5)
|
|
|
|
|
|
@pytest.fixture
def trace_server_address() -> str:
    """Provide the ``host:port`` string of the fake trace server."""
    address: str = FAKE_TRACE_SERVER_ADDRESS
    return address
|