Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -12,20 +12,23 @@ from typing import Callable, Generator, Literal
import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
ExportTraceServiceResponse)
ExportTraceServiceResponse,
)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
TraceServiceServicer, add_TraceServiceServicer_to_server)
TraceServiceServicer,
add_TraceServiceServicer_to_server,
)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_TRACES_INSECURE)
from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE
from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
'array_value']
FieldName = Literal[
"bool_value", "string_value", "int_value", "double_value", "array_value"
]
def decode_value(value: AnyValue):
@@ -34,8 +37,9 @@ def decode_value(value: AnyValue):
"string_value": (lambda v: v.string_value),
"int_value": (lambda v: v.int_value),
"double_value": (lambda v: v.double_value),
"array_value":
(lambda v: [decode_value(item) for item in v.array_value.values]),
"array_value": (
lambda v: [decode_value(item) for item in v.array_value.values]
),
}
for field, decoder in field_decoders.items():
if value.HasField(field):
@@ -48,7 +52,6 @@ def decode_attributes(attributes: Iterable[KeyValue]):
class FakeTraceService(TraceServiceServicer):
def __init__(self):
self.request = None
self.evt = threading.Event()
@@ -86,10 +89,12 @@ def test_traces(
max_tokens=256,
)
model = "facebook/opt-125m"
llm = LLM(model=model,
otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
gpu_memory_utilization=0.3,
disable_log_stats=False)
llm = LLM(
model=model,
otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
gpu_memory_utilization=0.3,
disable_log_stats=False,
)
prompts = ["This is a short prompt"]
outputs = llm.generate(prompts, sampling_params=sampling_params)
print(f"test_traces outputs is : {outputs}")
@@ -98,40 +103,48 @@ def test_traces(
if not trace_service.evt.wait(timeout):
raise TimeoutError(
f"The fake trace service didn't receive a trace within "
f"the {timeout} seconds timeout")
f"the {timeout} seconds timeout"
)
request = trace_service.request
assert len(request.resource_spans) == 1, (
f"Expected 1 resource span, "
f"but got {len(request.resource_spans)}")
f"Expected 1 resource span, but got {len(request.resource_spans)}"
)
assert len(request.resource_spans[0].scope_spans) == 1, (
f"Expected 1 scope span, "
f"but got {len(request.resource_spans[0].scope_spans)}")
f"but got {len(request.resource_spans[0].scope_spans)}"
)
assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
f"Expected 1 span, "
f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")
f"but got {len(request.resource_spans[0].scope_spans[0].spans)}"
)
attributes = decode_attributes(
request.resource_spans[0].scope_spans[0].spans[0].attributes)
request.resource_spans[0].scope_spans[0].spans[0].attributes
)
# assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
) == sampling_params.temperature
assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
) == sampling_params.max_tokens
assert attributes.get(
SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
assert attributes.get(
SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
outputs[0].prompt_token_ids)
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
assert (
attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE)
== sampling_params.temperature
)
assert (
attributes.get(SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
)
assert (
attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS)
== sampling_params.max_tokens
)
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
outputs[0].prompt_token_ids
)
completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
assert attributes.get(
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
assert (
attributes.get(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS)
== completion_tokens
)
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) > 0
assert attributes.get(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) > 0