[Misc] Add OpenTelemetry support (#4687)

This PR adds basic support for OpenTelemetry distributed tracing. It wires an optional observability config through the engine, propagates incoming trace headers with each request, and emits a span per finished request to improve monitoring capabilities.

I've also added a markdown document with screenshots that guides users through using this feature. You can find it here.
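As a quick end-to-end illustration, here is a hedged sketch (not part of this diff) of how a client could propagate its own trace context into a request. It assumes the vLLM OpenAI-compatible server is running locally with the new OTLP tracing option enabled and that it forwards incoming HTTP headers as the engine's trace_headers (wired up in other files of this PR); the URL and model name are placeholders, and opentelemetry-sdk plus requests must be installed.

import requests
from opentelemetry import trace
from opentelemetry.propagate import inject
from opentelemetry.sdk.trace import TracerProvider

# Minimal client-side tracer setup (no exporter is needed just to create a span).
trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("my-client")

with tracer.start_as_current_span("client_request"):
    headers = {}
    inject(headers)  # adds the W3C 'traceparent' header for the active span
    # Placeholder endpoint and model; adjust to your deployment.
    resp = requests.post(
        "http://localhost:8000/v1/completions",
        headers=headers,
        json={"model": "facebook/opt-125m", "prompt": "Hello", "max_tokens": 8},
    )
    print(resp.json())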
Author:    Ronen Schaffer
Date:      2024-06-18 19:17:03 +03:00
Committed: GitHub
Parent:    13db4369d9
Commit:    7879f24dcc
15 changed files with 567 additions and 41 deletions

vllm/engine/llm_engine.py

@@ -1,14 +1,14 @@
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
+from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Type, TypeVar, Union

 from transformers import GenerationConfig, PreTrainedTokenizer

 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
-                         LoRAConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, SpeculativeConfig,
+                         LoRAConfig, ModelConfig, ObservabilityConfig,
+                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VisionLanguageConfig)
 from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
                                  SchedulerOutputs)
@@ -31,6 +31,8 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
                            PoolerOutput, SamplerOutput, Sequence,
                            SequenceGroup, SequenceGroupMetadata,
                            SequenceStatus)
+from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
+                          init_tracer)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -154,6 +156,7 @@ class LLMEngine:
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
         decoding_config: Optional[DecodingConfig],
+        observability_config: Optional[ObservabilityConfig],
         executor_class: Type[ExecutorBase],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -168,7 +171,8 @@
             "disable_custom_all_reduce=%s, quantization=%s, "
             "enforce_eager=%s, kv_cache_dtype=%s, "
             "quantization_param_path=%s, device_config=%s, "
-            "decoding_config=%r, seed=%d, served_model_name=%s)",
+            "decoding_config=%r, observability_config=%r, "
+            "seed=%d, served_model_name=%s)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,
@@ -192,6 +196,7 @@
             model_config.quantization_param_path,
             device_config.device,
             decoding_config,
+            observability_config,
             model_config.seed,
             model_config.served_model_name,
         )
@@ -207,6 +212,8 @@
         self.speculative_config = speculative_config
         self.load_config = load_config
         self.decoding_config = decoding_config or DecodingConfig()
+        self.observability_config = observability_config or ObservabilityConfig(
+        )
         self.log_stats = log_stats

         if not self.model_config.skip_tokenizer_init:
@@ -288,6 +295,12 @@
                 max_model_len=self.model_config.max_model_len)
             self.stat_logger.info("cache_config", self.cache_config)

+        self.tracer = None
+        if self.observability_config.otlp_traces_endpoint:
+            self.tracer = init_tracer(
+                "vllm.llm_engine",
+                self.observability_config.otlp_traces_endpoint)
+
         # Create sequence output processor, e.g. for beam search or
         # speculative decoding.
         self.output_processor = (
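The init_tracer helper above lives in vllm/tracing.py, which this PR adds but which is not part of this file view. The following is only a guess at the shape of such a helper, built from standard OpenTelemetry SDK pieces (TracerProvider, BatchSpanProcessor, and the gRPC OTLP exporter); the function name and the service name are illustrative, not the PR's actual implementation.

from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor


def init_tracer_sketch(instrumenting_module_name: str, otlp_traces_endpoint: str):
    # Export spans in batches to the configured OTLP/gRPC endpoint.
    provider = TracerProvider(
        resource=Resource.create({"service.name": "vllm"}))  # name is a guess
    provider.add_span_processor(
        BatchSpanProcessor(OTLPSpanExporter(endpoint=otlp_traces_endpoint)))
    return provider.get_tracer(instrumenting_module_name)

Used this way, every span created from the returned tracer ends up at whatever collector is listening on otlp_traces_endpoint.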
@@ -444,6 +457,7 @@
         params: Union[SamplingParams, PoolingParams],
         arrival_time: float,
         lora_request: Optional[LoRARequest],
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> None:
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -461,6 +475,7 @@
                 params,
                 arrival_time=arrival_time,
                 lora_request=lora_request,
+                trace_headers=trace_headers,
             )
         elif isinstance(params, PoolingParams):
             seq_group = self._create_sequence_group_with_pooling(
@@ -507,6 +522,7 @@
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> None:
         """Add a request to the engine's request pool.

@@ -524,6 +540,7 @@
                 :class:`~vllm.PoolingParams` for pooling.
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
+            trace_headers: OpenTelemetry trace headers.

         Details:
             - Set arrival_time to the current time if it is None.
@@ -565,6 +582,7 @@
             params=params,
             arrival_time=arrival_time,
             lora_request=lora_request,
+            trace_headers=trace_headers,
         )

     def _create_sequence_group_with_sampling(
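To make the new trace_headers parameter concrete, here is a small self-contained sketch (not vLLM code) of what such a headers dict typically carries and how OpenTelemetry turns it back into a parent context on the serving side. extract_trace_context in vllm.tracing presumably wraps something like propagate.extract, but that is an assumption since the module is not shown in this file view.

from opentelemetry import trace
from opentelemetry.propagate import extract, inject
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer("example")

trace_headers = {}
with tracer.start_as_current_span("caller"):
    # Fills in e.g. {'traceparent': '00-<trace_id>-<span_id>-01'}.
    inject(trace_headers)

# Serving side: rebuild a Context so the request span joins the caller's trace.
parent_context = extract(trace_headers)
with tracer.start_as_current_span("llm_request", context=parent_context) as span:
    # Both spans share the same trace id.
    print(trace_headers, format(span.get_span_context().trace_id, "032x"))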
@@ -574,6 +592,7 @@
         sampling_params: SamplingParams,
         arrival_time: float,
         lora_request: Optional[LoRARequest],
+        trace_headers: Optional[Dict[str, str]] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
         max_logprobs = self.get_model_config().max_logprobs
@@ -595,11 +614,14 @@
             self.generation_config_fields)

         # Create the sequence group.
-        seq_group = SequenceGroup(request_id=request_id,
-                                  seqs=[seq],
-                                  arrival_time=arrival_time,
-                                  sampling_params=sampling_params,
-                                  lora_request=lora_request)
+        seq_group = SequenceGroup(
+            request_id=request_id,
+            seqs=[seq],
+            arrival_time=arrival_time,
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+            trace_headers=trace_headers,
+        )

         return seq_group
@@ -793,6 +815,9 @@
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)

+        # Tracing
+        self.do_tracing(scheduler_outputs)
+
         if not request_outputs:
             # Stop the execute model loop in parallel workers until there are
             # more requests to process. This avoids waiting indefinitely in
@@ -986,3 +1011,62 @@
     def check_health(self) -> None:
         self.model_executor.check_health()

+    def is_tracing_enabled(self) -> bool:
+        return self.tracer is not None
+
+    def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None:
+        if self.tracer is None:
+            return
+
+        for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
+            seq_group = scheduled_seq_group.seq_group
+            if seq_group.is_finished():
+                self.create_trace_span(seq_group)
+
+    def create_trace_span(self, seq_group: SequenceGroup) -> None:
+        if self.tracer is None or seq_group.sampling_params is None:
+            return
+        arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9)
+
+        trace_context = extract_trace_context(seq_group.trace_headers)
+
+        with self.tracer.start_as_current_span(
+                "llm_request",
+                kind=SpanKind.SERVER,
+                context=trace_context,
+                start_time=arrival_time_nano_seconds) as seq_span:
+            metrics = seq_group.metrics
+            ttft = metrics.first_token_time - metrics.arrival_time
+            e2e_time = metrics.finished_time - metrics.arrival_time
+            # attribute names are based on
+            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md
+            seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL,
+                                   self.model_config.model)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID,
+                                   seq_group.request_id)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE,
+                                   seq_group.sampling_params.temperature)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P,
+                                   seq_group.sampling_params.top_p)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
+                                   seq_group.sampling_params.max_tokens)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
+                                   seq_group.sampling_params.best_of)
+            seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
+                                   seq_group.sampling_params.n)
+            seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
+                                   seq_group.num_seqs())
+            seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+                                   len(seq_group.prompt_token_ids))
+            seq_span.set_attribute(
+                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+                sum([
+                    seq.get_output_len()
+                    for seq in seq_group.get_finished_seqs()
+                ]))
+            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE,
+                                   metrics.time_in_queue)
+            seq_span.set_attribute(
+                SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
+            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
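As a rough illustration of the span layout create_trace_span produces (one SERVER span per finished request whose start time is the arrival time and whose attributes carry request and latency metadata), here is a standalone sketch using an in-memory exporter. The attribute strings and timing values below are placeholders; the real names come from vllm.tracing's SpanAttributes, which is not part of this file.

import time

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
tracer = provider.get_tracer("vllm.llm_engine")

arrival_time = time.time()
first_token_time = arrival_time + 0.05   # placeholder timings
finished_time = arrival_time + 0.40

with tracer.start_as_current_span(
        "llm_request",
        kind=trace.SpanKind.SERVER,
        start_time=int(arrival_time * 1e9)) as span:
    # Placeholder attribute names; the PR's SpanAttributes constants may differ.
    span.set_attribute("gen_ai.latency.time_to_first_token",
                       first_token_time - arrival_time)
    span.set_attribute("gen_ai.latency.e2e", finished_time - arrival_time)

for finished in exporter.get_finished_spans():
    print(finished.name, dict(finished.attributes))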