[Metrics] Some small refactoring for better maintainability (#33898)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
This commit is contained in:
Martin Hickey
2026-03-20 16:11:34 +00:00
committed by GitHub
parent c0f5fae601
commit 880be2b1b8
7 changed files with 123 additions and 135 deletions

View File

@@ -126,28 +126,17 @@ class KVConnectorPromMetrics:
self._labelnames = labelnames
self.per_engine_labelvalues = per_engine_labelvalues
def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]:
"""
Create a per-engine child of a prometheus_client.Metric with
the appropriate labels set. The parent metric must be created
using the labelnames list.
"""
return {
idx: metric.labels(*labelvalues)
for idx, labelvalues in self.per_engine_labelvalues.items()
}
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""
Record the supplied transfer statistics to Prometheus metrics. These
statistics are engine-specific, and should be recorded to a metric
with the appropriate 'engine' label. These metric instances can be
created using the make_per_engine() helper method.
created using the create_metric_per_engine() helper function.
"""
raise NotImplementedError
class KVConnectorPrometheus:
class KVConnectorProm:
"""
Support for registering per-connector Prometheus metrics, and
recording transfer statistics to those metrics. Uses

View File

@@ -65,6 +65,7 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.metrics.utils import create_metric_per_engine
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.utils import select_common_block_size
@@ -3057,7 +3058,9 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets[1:],
labelnames=labelnames,
)
self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time)
self.nixl_histogram_xfer_time = create_metric_per_engine(
nixl_histogram_xfer_time, self.per_engine_labelvalues
)
nixl_histogram_post_time = self._histogram_cls(
name="vllm:nixl_post_time_seconds",
documentation="Histogram of transfer post time for NIXL KV"
@@ -3065,7 +3068,9 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time)
self.nixl_histogram_post_time = create_metric_per_engine(
nixl_histogram_post_time, self.per_engine_labelvalues
)
# uniform 2kb to 16gb range
buckets = [2 ** (10 + i) for i in range(1, 25, 2)]
nixl_histogram_bytes_transferred = self._histogram_cls(
@@ -3074,8 +3079,8 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_bytes_transferred = self.make_per_engine(
nixl_histogram_bytes_transferred
self.nixl_histogram_bytes_transferred = create_metric_per_engine(
nixl_histogram_bytes_transferred, self.per_engine_labelvalues
)
buckets = [
10,
@@ -3100,24 +3105,24 @@ class NixlPromMetrics(KVConnectorPromMetrics):
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_num_descriptors = self.make_per_engine(
nixl_histogram_num_descriptors
self.nixl_histogram_num_descriptors = create_metric_per_engine(
nixl_histogram_num_descriptors, self.per_engine_labelvalues
)
counter_nixl_num_failed_transfers = self._counter_cls(
name="vllm:nixl_num_failed_transfers",
documentation="Number of failed NIXL KV Cache transfers.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_transfers = self.make_per_engine(
counter_nixl_num_failed_transfers
self.counter_nixl_num_failed_transfers = create_metric_per_engine(
counter_nixl_num_failed_transfers, self.per_engine_labelvalues
)
counter_nixl_num_failed_notifications = self._counter_cls(
name="vllm:nixl_num_failed_notifications",
documentation="Number of failed NIXL KV Cache notifications.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_notifications = self.make_per_engine(
counter_nixl_num_failed_notifications
self.counter_nixl_num_failed_notifications = create_metric_per_engine(
counter_nixl_num_failed_notifications, self.per_engine_labelvalues
)
counter_nixl_num_kv_expired_reqs = self._counter_cls(
@@ -3126,8 +3131,8 @@ class NixlPromMetrics(KVConnectorPromMetrics):
"NOTE: This metric is tracked on the P instance.",
labelnames=labelnames,
)
self.counter_nixl_num_kv_expired_reqs = self.make_per_engine(
counter_nixl_num_kv_expired_reqs
self.counter_nixl_num_kv_expired_reqs = create_metric_per_engine(
counter_nixl_num_kv_expired_reqs, self.per_engine_labelvalues
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):

View File

@@ -5,7 +5,6 @@ import logging
import time
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TypeAlias
from prometheus_client import Counter, Gauge, Histogram
@@ -14,7 +13,7 @@ from vllm.compilation.cuda_graph import CUDAGraphLogging
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorLogging,
KVConnectorPrometheus,
KVConnectorProm,
)
from vllm.logger import init_logger
from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group
@@ -28,6 +27,7 @@ from vllm.v1.metrics.stats import (
PromptTokenStats,
SchedulerStats,
)
from vllm.v1.metrics.utils import create_metric_per_engine
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
@@ -391,7 +391,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
_counter_cls = Counter
_histogram_cls = Histogram
_spec_decoding_cls = SpecDecodingProm
_kv_connector_cls = KVConnectorPrometheus
_kv_connector_cls = KVConnectorProm
_perf_metrics_cls = PerfMetricsProm
def __init__(
@@ -415,9 +415,10 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
model_name = vllm_config.model_config.served_model_name
max_model_len = vllm_config.model_config.max_model_len
per_engine_labelvalues: dict[int, list[object]] = {
self.per_engine_labelvalues: dict[int, list[object]] = {
idx: [model_name, str(idx)] for idx in engine_indexes
}
per_engine_labelvalues = self.per_engine_labelvalues
self.spec_decoding_prom = self._spec_decoding_cls(
vllm_config.speculative_config, labelnames, per_engine_labelvalues
@@ -438,8 +439,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
multiprocess_mode="mostrecent",
labelnames=labelnames,
)
self.gauge_scheduler_running = make_per_engine(
gauge_scheduler_running, engine_indexes, model_name
self.gauge_scheduler_running = create_metric_per_engine(
gauge_scheduler_running, per_engine_labelvalues
)
gauge_scheduler_waiting = self._gauge_cls(
@@ -448,8 +449,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
multiprocess_mode="mostrecent",
labelnames=labelnames,
)
self.gauge_scheduler_waiting = make_per_engine(
gauge_scheduler_waiting, engine_indexes, model_name
self.gauge_scheduler_waiting = create_metric_per_engine(
gauge_scheduler_waiting, per_engine_labelvalues
)
gauge_engine_sleep_state = self._gauge_cls(
@@ -484,8 +485,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
multiprocess_mode="mostrecent",
labelnames=labelnames,
)
self.gauge_kv_cache_usage = make_per_engine(
gauge_kv_cache_usage, engine_indexes, model_name
self.gauge_kv_cache_usage = create_metric_per_engine(
gauge_kv_cache_usage, per_engine_labelvalues
)
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
@@ -497,8 +498,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_corrupted_requests = make_per_engine(
counter_corrupted_requests, engine_indexes, model_name
self.counter_corrupted_requests = create_metric_per_engine(
counter_corrupted_requests, per_engine_labelvalues
)
counter_prefix_cache_queries = self._counter_cls(
@@ -508,8 +509,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_prefix_cache_queries = make_per_engine(
counter_prefix_cache_queries, engine_indexes, model_name
self.counter_prefix_cache_queries = create_metric_per_engine(
counter_prefix_cache_queries, per_engine_labelvalues
)
counter_prefix_cache_hits = self._counter_cls(
@@ -517,8 +518,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation=("Prefix cache hits, in terms of number of cached tokens."),
labelnames=labelnames,
)
self.counter_prefix_cache_hits = make_per_engine(
counter_prefix_cache_hits, engine_indexes, model_name
self.counter_prefix_cache_hits = create_metric_per_engine(
counter_prefix_cache_hits, per_engine_labelvalues
)
#
@@ -533,8 +534,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_connector_prefix_cache_queries = make_per_engine(
counter_connector_prefix_cache_queries, engine_indexes, model_name
self.counter_connector_prefix_cache_queries = create_metric_per_engine(
counter_connector_prefix_cache_queries, per_engine_labelvalues
)
counter_connector_prefix_cache_hits = self._counter_cls(
@@ -545,8 +546,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_connector_prefix_cache_hits = make_per_engine(
counter_connector_prefix_cache_hits, engine_indexes, model_name
self.counter_connector_prefix_cache_hits = create_metric_per_engine(
counter_connector_prefix_cache_hits, per_engine_labelvalues
)
#
@@ -560,8 +561,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_mm_cache_queries = make_per_engine(
counter_mm_cache_queries, engine_indexes, model_name
self.counter_mm_cache_queries = create_metric_per_engine(
counter_mm_cache_queries, per_engine_labelvalues
)
counter_mm_cache_hits = self._counter_cls(
@@ -571,8 +572,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
),
labelnames=labelnames,
)
self.counter_mm_cache_hits = make_per_engine(
counter_mm_cache_hits, engine_indexes, model_name
self.counter_mm_cache_hits = create_metric_per_engine(
counter_mm_cache_hits, per_engine_labelvalues
)
#
@@ -583,8 +584,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames,
)
self.counter_num_preempted_reqs = make_per_engine(
counter_num_preempted_reqs, engine_indexes, model_name
self.counter_num_preempted_reqs = create_metric_per_engine(
counter_num_preempted_reqs, per_engine_labelvalues
)
counter_prompt_tokens = self._counter_cls(
@@ -592,8 +593,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
)
self.counter_prompt_tokens = make_per_engine(
counter_prompt_tokens, engine_indexes, model_name
self.counter_prompt_tokens = create_metric_per_engine(
counter_prompt_tokens, per_engine_labelvalues
)
# Labeled prompt token counters by source
@@ -617,8 +618,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of cached prompt tokens (local + external).",
labelnames=labelnames,
)
self.counter_prompt_tokens_cached = make_per_engine(
counter_prompt_tokens_cached, engine_indexes, model_name
self.counter_prompt_tokens_cached = create_metric_per_engine(
counter_prompt_tokens_cached, per_engine_labelvalues
)
# Recomputed tokens (last token recomputed when entire prompt is cached)
@@ -627,8 +628,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of cached tokens recomputed for forward pass.",
labelnames=labelnames,
)
self.counter_prompt_tokens_recomputed = make_per_engine(
counter_prompt_tokens_recomputed, engine_indexes, model_name
self.counter_prompt_tokens_recomputed = create_metric_per_engine(
counter_prompt_tokens_recomputed, per_engine_labelvalues
)
counter_generation_tokens = self._counter_cls(
@@ -636,8 +637,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
documentation="Number of generation tokens processed.",
labelnames=labelnames,
)
self.counter_generation_tokens = make_per_engine(
counter_generation_tokens, engine_indexes, model_name
self.counter_generation_tokens = create_metric_per_engine(
counter_generation_tokens, per_engine_labelvalues
)
self.counter_request_success: dict[FinishReason, dict[int, Counter]] = {}
@@ -663,8 +664,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames,
)
self.histogram_num_prompt_tokens_request = make_per_engine(
histogram_num_prompt_tokens_request, engine_indexes, model_name
self.histogram_num_prompt_tokens_request = create_metric_per_engine(
histogram_num_prompt_tokens_request, per_engine_labelvalues
)
histogram_num_generation_tokens_request = self._histogram_cls(
@@ -673,8 +674,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames,
)
self.histogram_num_generation_tokens_request = make_per_engine(
histogram_num_generation_tokens_request, engine_indexes, model_name
self.histogram_num_generation_tokens_request = create_metric_per_engine(
histogram_num_generation_tokens_request, per_engine_labelvalues
)
# TODO: This metric might be incorrect in case of using multiple
@@ -686,8 +687,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
labelnames=labelnames,
)
self.histogram_iteration_tokens = make_per_engine(
histogram_iteration_tokens, engine_indexes, model_name
self.histogram_iteration_tokens = create_metric_per_engine(
histogram_iteration_tokens, per_engine_labelvalues
)
histogram_max_num_generation_tokens_request = self._histogram_cls(
@@ -696,8 +697,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames,
)
self.histogram_max_num_generation_tokens_request = make_per_engine(
histogram_max_num_generation_tokens_request, engine_indexes, model_name
self.histogram_max_num_generation_tokens_request = create_metric_per_engine(
histogram_max_num_generation_tokens_request, per_engine_labelvalues
)
histogram_n_request = self._histogram_cls(
@@ -706,8 +707,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=[1, 2, 5, 10, 20],
labelnames=labelnames,
)
self.histogram_n_request = make_per_engine(
histogram_n_request, engine_indexes, model_name
self.histogram_n_request = create_metric_per_engine(
histogram_n_request, per_engine_labelvalues
)
histogram_max_tokens_request = self._histogram_cls(
@@ -716,8 +717,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames,
)
self.histogram_max_tokens_request = make_per_engine(
histogram_max_tokens_request, engine_indexes, model_name
self.histogram_max_tokens_request = create_metric_per_engine(
histogram_max_tokens_request, per_engine_labelvalues
)
#
@@ -752,8 +753,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
],
labelnames=labelnames,
)
self.histogram_time_to_first_token = make_per_engine(
histogram_time_to_first_token, engine_indexes, model_name
self.histogram_time_to_first_token = create_metric_per_engine(
histogram_time_to_first_token, per_engine_labelvalues
)
histogram_inter_token_latency = self._histogram_cls(
@@ -782,8 +783,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
],
labelnames=labelnames,
)
self.histogram_inter_token_latency = make_per_engine(
histogram_inter_token_latency, engine_indexes, model_name
self.histogram_inter_token_latency = create_metric_per_engine(
histogram_inter_token_latency, per_engine_labelvalues
)
histogram_request_time_per_output_token = self._histogram_cls(
@@ -812,8 +813,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
],
labelnames=labelnames,
)
self.histogram_request_time_per_output_token = make_per_engine(
histogram_request_time_per_output_token, engine_indexes, model_name
self.histogram_request_time_per_output_token = create_metric_per_engine(
histogram_request_time_per_output_token, per_engine_labelvalues
)
request_latency_buckets = [
@@ -845,8 +846,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets,
labelnames=labelnames,
)
self.histogram_e2e_time_request = make_per_engine(
histogram_e2e_time_request, engine_indexes, model_name
self.histogram_e2e_time_request = create_metric_per_engine(
histogram_e2e_time_request, per_engine_labelvalues
)
histogram_queue_time_request = self._histogram_cls(
@@ -855,8 +856,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets,
labelnames=labelnames,
)
self.histogram_queue_time_request = make_per_engine(
histogram_queue_time_request, engine_indexes, model_name
self.histogram_queue_time_request = create_metric_per_engine(
histogram_queue_time_request, per_engine_labelvalues
)
histogram_inference_time_request = self._histogram_cls(
@@ -865,8 +866,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets,
labelnames=labelnames,
)
self.histogram_inference_time_request = make_per_engine(
histogram_inference_time_request, engine_indexes, model_name
self.histogram_inference_time_request = create_metric_per_engine(
histogram_inference_time_request, per_engine_labelvalues
)
histogram_prefill_time_request = self._histogram_cls(
@@ -875,8 +876,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets,
labelnames=labelnames,
)
self.histogram_prefill_time_request = make_per_engine(
histogram_prefill_time_request, engine_indexes, model_name
self.histogram_prefill_time_request = create_metric_per_engine(
histogram_prefill_time_request, per_engine_labelvalues
)
histogram_decode_time_request = self._histogram_cls(
@@ -885,8 +886,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=request_latency_buckets,
labelnames=labelnames,
)
self.histogram_decode_time_request = make_per_engine(
histogram_decode_time_request, engine_indexes, model_name
self.histogram_decode_time_request = create_metric_per_engine(
histogram_decode_time_request, per_engine_labelvalues
)
histogram_prefill_kv_computed_request = self._histogram_cls(
@@ -898,8 +899,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames,
)
self.histogram_prefill_kv_computed_request = make_per_engine(
histogram_prefill_kv_computed_request, engine_indexes, model_name
self.histogram_prefill_kv_computed_request = create_metric_per_engine(
histogram_prefill_kv_computed_request, per_engine_labelvalues
)
#
@@ -939,8 +940,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=kv_cache_residency_buckets,
labelnames=labelnames,
)
self.histogram_kv_block_lifetime = make_per_engine(
histogram_kv_block_lifetime, engine_indexes, model_name
self.histogram_kv_block_lifetime = create_metric_per_engine(
histogram_kv_block_lifetime, per_engine_labelvalues
)
histogram_kv_block_idle_before_evict = self._histogram_cls(
@@ -952,8 +953,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=kv_cache_residency_buckets,
labelnames=labelnames,
)
self.histogram_kv_block_idle_before_evict = make_per_engine(
histogram_kv_block_idle_before_evict, engine_indexes, model_name
self.histogram_kv_block_idle_before_evict = create_metric_per_engine(
histogram_kv_block_idle_before_evict, per_engine_labelvalues
)
histogram_kv_block_reuse_gap = self._histogram_cls(
@@ -967,8 +968,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
buckets=kv_cache_residency_buckets,
labelnames=labelnames,
)
self.histogram_kv_block_reuse_gap = make_per_engine(
histogram_kv_block_reuse_gap, engine_indexes, model_name
self.histogram_kv_block_reuse_gap = create_metric_per_engine(
histogram_kv_block_reuse_gap, per_engine_labelvalues
)
else:
self.histogram_kv_block_lifetime = {}
@@ -1203,15 +1204,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
self.log_metrics_info("cache_config", self.vllm_config.cache_config)
PromMetric: TypeAlias = Gauge | Counter | Histogram
def make_per_engine(
metric: PromMetric, engine_idxs: list[int], model_name: object
) -> dict[int, PromMetric]:
return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs}
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
"""
Builds a list of buckets with increasing powers of 10 multiplied by

View File

@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import (
get_kv_cache_torch_dtype,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.metrics.utils import create_metric_per_engine
logger = init_logger(__name__)
@@ -1291,7 +1292,9 @@ class PerfMetricsProm:
),
labelnames=labelnames,
)
self.counter_flops = make_per_engine(counter_flops, per_engine_labelvalues)
self.counter_flops = create_metric_per_engine(
counter_flops, per_engine_labelvalues
)
counter_read_bytes = self._counter_cls(
name="vllm:estimated_read_bytes_per_gpu_total",
@@ -1301,7 +1304,7 @@ class PerfMetricsProm:
),
labelnames=labelnames,
)
self.counter_read_bytes = make_per_engine(
self.counter_read_bytes = create_metric_per_engine(
counter_read_bytes, per_engine_labelvalues
)
@@ -1313,7 +1316,7 @@ class PerfMetricsProm:
),
labelnames=labelnames,
)
self.counter_write_bytes = make_per_engine(
self.counter_write_bytes = create_metric_per_engine(
counter_write_bytes, per_engine_labelvalues
)
@@ -1329,16 +1332,6 @@ class PerfMetricsProm:
self.counter_write_bytes[engine_idx].inc(perf_stats.num_write_bytes_per_gpu)
def make_per_engine(
counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[object]]
):
"""Create a counter for each label value."""
return {
idx: counter.labels(*labelvalues)
for idx, labelvalues in per_engine_labelvalues.items()
}
## util functions

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPrometheus
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorProm
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.metrics.perf import PerfMetricsProm
from vllm.v1.spec_decode.metrics import SpecDecodingProm
@@ -168,9 +168,9 @@ class RaySpecDecodingProm(SpecDecodingProm):
_counter_cls = RayCounterWrapper
class RayKVConnectorPrometheus(KVConnectorPrometheus):
class RayKVConnectorProm(KVConnectorProm):
"""
RayKVConnectorPrometheus is used by RayMetrics to log Ray
RayKVConnectorProm is used by RayMetrics to log Ray
metrics. Provides the same metrics as KV connectors but
uses Ray's util.metrics library.
"""
@@ -197,7 +197,7 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
_counter_cls = RayCounterWrapper
_histogram_cls = RayHistogramWrapper
_spec_decoding_cls = RaySpecDecodingProm
_kv_connector_cls = RayKVConnectorPrometheus
_kv_connector_cls = RayKVConnectorProm
_perf_metrics_cls = RayPerfMetricsProm
@staticmethod

19
vllm/v1/metrics/utils.py Normal file
View File

@@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias
from prometheus_client import Counter, Gauge, Histogram
PromMetric: TypeAlias = Gauge | Counter | Histogram
def create_metric_per_engine(
    metric: PromMetric,
    per_engine_labelvalues: dict[int, list[object]],
) -> dict[int, PromMetric]:
    """Bind a parent metric to the label values of every engine.

    Args:
        metric: Parent metric whose ``labels()`` is invoked once per engine.
        per_engine_labelvalues: Maps an engine index to the positional
            label values passed to ``metric.labels()`` for that engine.
            NOTE(review): presumably these must line up with the
            labelnames the parent metric was declared with - confirm
            against the callers.

    Returns:
        A dict keyed by engine index whose values are the results of
        ``metric.labels(*labelvalues)`` for the matching engine.
    """
    children: dict[int, PromMetric] = {}
    for engine_idx, values in per_engine_labelvalues.items():
        # Label values are unpacked positionally, in the order given.
        children[engine_idx] = metric.labels(*values)
    return children

View File

@@ -9,6 +9,7 @@ import prometheus_client
from vllm.config import SpeculativeConfig
from vllm.logger import init_logger
from vllm.v1.metrics.utils import create_metric_per_engine
logger = init_logger(__name__)
@@ -155,7 +156,7 @@ class SpecDecodingProm:
documentation="Number of spec decoding drafts.",
labelnames=labelnames,
)
self.counter_spec_decode_num_drafts = make_per_engine(
self.counter_spec_decode_num_drafts = create_metric_per_engine(
counter_drafts, per_engine_labelvalues
)
@@ -164,7 +165,7 @@ class SpecDecodingProm:
documentation="Number of draft tokens.",
labelnames=labelnames,
)
self.counter_spec_decode_num_draft_tokens = make_per_engine(
self.counter_spec_decode_num_draft_tokens = create_metric_per_engine(
counter_draft_tokens, per_engine_labelvalues
)
@@ -173,7 +174,7 @@ class SpecDecodingProm:
documentation="Number of accepted tokens.",
labelnames=labelnames,
)
self.counter_spec_decode_num_accepted_tokens = make_per_engine(
self.counter_spec_decode_num_accepted_tokens = create_metric_per_engine(
counter_accepted_tokens, per_engine_labelvalues
)
@@ -212,14 +213,3 @@ class SpecDecodingProm:
self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]
):
counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
def make_per_engine(
counter: prometheus_client.Counter,
per_engine_labelvalues: dict[int, list[object]],
):
"""Create a counter for each label value."""
return {
idx: counter.labels(*labelvalues)
for idx, labelvalues in per_engine_labelvalues.items()
}