diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py index db77d41c4..faaffd72e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py @@ -126,28 +126,17 @@ class KVConnectorPromMetrics: self._labelnames = labelnames self.per_engine_labelvalues = per_engine_labelvalues - def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]: - """ - Create a per-engine child of a prometheus_client.Metric with - the appropriate labels set. The parent metric must be created - using the labelnames list. - """ - return { - idx: metric.labels(*labelvalues) - for idx, labelvalues in self.per_engine_labelvalues.items() - } - def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): """ Record the supplied transfer statistics to Prometheus metrics. These statistics are engine-specific, and should be recorded to a metric with the appropriate 'engine' label. These metric instances can be - created using the make_per_engine() helper method. + created using the create_metric_per_engine() helper method. """ raise NotImplementedError -class KVConnectorPrometheus: +class KVConnectorProm: """ Support for registering per-connector Prometheus metrics, and recording transfer statistics to those metrics. Uses diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ed53c35c9..a86a52a6a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -65,6 +65,7 @@ from vllm.v1.kv_cache_interface import ( SlidingWindowSpec, UniformTypeKVCacheSpecs, ) +from vllm.v1.metrics.utils import create_metric_per_engine from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.utils import select_common_block_size @@ -3057,7 +3058,9 @@ class NixlPromMetrics(KVConnectorPromMetrics): buckets=buckets[1:], labelnames=labelnames, ) - self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time) + self.nixl_histogram_xfer_time = create_metric_per_engine( + nixl_histogram_xfer_time, self.per_engine_labelvalues + ) nixl_histogram_post_time = self._histogram_cls( name="vllm:nixl_post_time_seconds", documentation="Histogram of transfer post time for NIXL KV" @@ -3065,7 +3068,9 @@ class NixlPromMetrics(KVConnectorPromMetrics): buckets=buckets, labelnames=labelnames, ) - self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time) + self.nixl_histogram_post_time = create_metric_per_engine( + nixl_histogram_post_time, self.per_engine_labelvalues + ) # uniform 2kb to 16gb range buckets = [2 ** (10 + i) for i in range(1, 25, 2)] nixl_histogram_bytes_transferred = self._histogram_cls( @@ -3074,8 +3079,8 @@ class NixlPromMetrics(KVConnectorPromMetrics): buckets=buckets, labelnames=labelnames, ) - self.nixl_histogram_bytes_transferred = self.make_per_engine( - nixl_histogram_bytes_transferred + self.nixl_histogram_bytes_transferred = create_metric_per_engine( + nixl_histogram_bytes_transferred, self.per_engine_labelvalues ) buckets = [ 10, @@ -3100,24 +3105,24 @@ class NixlPromMetrics(KVConnectorPromMetrics): buckets=buckets, labelnames=labelnames, ) - self.nixl_histogram_num_descriptors = self.make_per_engine( - nixl_histogram_num_descriptors + self.nixl_histogram_num_descriptors = create_metric_per_engine( + nixl_histogram_num_descriptors, self.per_engine_labelvalues ) counter_nixl_num_failed_transfers = self._counter_cls( name="vllm:nixl_num_failed_transfers", documentation="Number of failed NIXL KV Cache transfers.", labelnames=labelnames, ) - self.counter_nixl_num_failed_transfers = self.make_per_engine( - counter_nixl_num_failed_transfers + self.counter_nixl_num_failed_transfers = create_metric_per_engine( + counter_nixl_num_failed_transfers, self.per_engine_labelvalues ) counter_nixl_num_failed_notifications = self._counter_cls( name="vllm:nixl_num_failed_notifications", documentation="Number of failed NIXL KV Cache notifications.", labelnames=labelnames, ) - self.counter_nixl_num_failed_notifications = self.make_per_engine( - counter_nixl_num_failed_notifications + self.counter_nixl_num_failed_notifications = create_metric_per_engine( + counter_nixl_num_failed_notifications, self.per_engine_labelvalues ) counter_nixl_num_kv_expired_reqs = self._counter_cls( @@ -3126,8 +3131,8 @@ class NixlPromMetrics(KVConnectorPromMetrics): "NOTE: This metric is tracked on the P instance.", labelnames=labelnames, ) - self.counter_nixl_num_kv_expired_reqs = self.make_per_engine( - counter_nixl_num_kv_expired_reqs + self.counter_nixl_num_kv_expired_reqs = create_metric_per_engine( + counter_nixl_num_kv_expired_reqs, self.per_engine_labelvalues ) def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index f20d78542..5d5877d16 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -5,7 +5,6 @@ import logging import time from abc import ABC, abstractmethod from collections.abc import Callable -from typing import TypeAlias from prometheus_client import Counter, Gauge, Histogram @@ -14,7 +13,7 @@ from vllm.compilation.cuda_graph import CUDAGraphLogging from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorLogging, - KVConnectorPrometheus, + KVConnectorProm, ) from vllm.logger import init_logger from vllm.plugins import STAT_LOGGER_PLUGINS_GROUP, load_plugins_by_group @@ -28,6 +27,7 @@ from vllm.v1.metrics.stats import ( PromptTokenStats, SchedulerStats, ) +from vllm.v1.metrics.utils import create_metric_per_engine from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm logger = init_logger(__name__) @@ -391,7 +391,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase): _counter_cls = Counter _histogram_cls = Histogram _spec_decoding_cls = SpecDecodingProm - _kv_connector_cls = KVConnectorPrometheus + _kv_connector_cls = KVConnectorProm _perf_metrics_cls = PerfMetricsProm def __init__( @@ -415,9 +415,10 @@ class PrometheusStatLogger(AggregateStatLoggerBase): model_name = vllm_config.model_config.served_model_name max_model_len = vllm_config.model_config.max_model_len - per_engine_labelvalues: dict[int, list[object]] = { + self.per_engine_labelvalues: dict[int, list[object]] = { idx: [model_name, str(idx)] for idx in engine_indexes } + per_engine_labelvalues = self.per_engine_labelvalues self.spec_decoding_prom = self._spec_decoding_cls( vllm_config.speculative_config, labelnames, per_engine_labelvalues @@ -438,8 +439,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): multiprocess_mode="mostrecent", labelnames=labelnames, ) - self.gauge_scheduler_running = make_per_engine( - gauge_scheduler_running, engine_indexes, model_name + self.gauge_scheduler_running = create_metric_per_engine( + gauge_scheduler_running, per_engine_labelvalues ) gauge_scheduler_waiting = self._gauge_cls( @@ -448,8 +449,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): multiprocess_mode="mostrecent", labelnames=labelnames, ) - self.gauge_scheduler_waiting = make_per_engine( - gauge_scheduler_waiting, engine_indexes, model_name + self.gauge_scheduler_waiting = create_metric_per_engine( + gauge_scheduler_waiting, per_engine_labelvalues ) gauge_engine_sleep_state = self._gauge_cls( @@ -484,8 +485,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): multiprocess_mode="mostrecent", labelnames=labelnames, ) - self.gauge_kv_cache_usage = make_per_engine( - gauge_kv_cache_usage, engine_indexes, model_name + self.gauge_kv_cache_usage = create_metric_per_engine( + gauge_kv_cache_usage, per_engine_labelvalues ) if envs.VLLM_COMPUTE_NANS_IN_LOGITS: @@ -497,8 +498,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ), labelnames=labelnames, ) - self.counter_corrupted_requests = make_per_engine( - counter_corrupted_requests, engine_indexes, model_name + self.counter_corrupted_requests = create_metric_per_engine( + counter_corrupted_requests, per_engine_labelvalues ) counter_prefix_cache_queries = self._counter_cls( @@ -508,8 +509,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ), labelnames=labelnames, ) - self.counter_prefix_cache_queries = make_per_engine( - counter_prefix_cache_queries, engine_indexes, model_name + self.counter_prefix_cache_queries = create_metric_per_engine( + counter_prefix_cache_queries, per_engine_labelvalues ) counter_prefix_cache_hits = self._counter_cls( @@ -517,8 +518,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): documentation=("Prefix cache hits, in terms of number of cached tokens."), labelnames=labelnames, ) - self.counter_prefix_cache_hits = make_per_engine( - counter_prefix_cache_hits, engine_indexes, model_name + self.counter_prefix_cache_hits = create_metric_per_engine( + counter_prefix_cache_hits, per_engine_labelvalues ) # @@ -533,8 +534,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ), labelnames=labelnames, ) - self.counter_connector_prefix_cache_queries = make_per_engine( - counter_connector_prefix_cache_queries, engine_indexes, model_name + self.counter_connector_prefix_cache_queries = create_metric_per_engine( + counter_connector_prefix_cache_queries, per_engine_labelvalues ) counter_connector_prefix_cache_hits = self._counter_cls( @@ -545,8 +546,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ), labelnames=labelnames, ) - self.counter_connector_prefix_cache_hits = make_per_engine( - counter_connector_prefix_cache_hits, engine_indexes, model_name + self.counter_connector_prefix_cache_hits = create_metric_per_engine( + counter_connector_prefix_cache_hits, per_engine_labelvalues ) # @@ -560,8 +561,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ), labelnames=labelnames, ) - self.counter_mm_cache_queries = make_per_engine( - counter_mm_cache_queries, engine_indexes, model_name + self.counter_mm_cache_queries = create_metric_per_engine( + counter_mm_cache_queries, per_engine_labelvalues ) counter_mm_cache_hits = self._counter_cls( @@ -571,8 +572,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ), labelnames=labelnames, ) - self.counter_mm_cache_hits = make_per_engine( - counter_mm_cache_hits, engine_indexes, model_name + self.counter_mm_cache_hits = create_metric_per_engine( + counter_mm_cache_hits, per_engine_labelvalues ) # @@ -583,8 +584,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): documentation="Cumulative number of preemption from the engine.", labelnames=labelnames, ) - self.counter_num_preempted_reqs = make_per_engine( - counter_num_preempted_reqs, engine_indexes, model_name + self.counter_num_preempted_reqs = create_metric_per_engine( + counter_num_preempted_reqs, per_engine_labelvalues ) counter_prompt_tokens = self._counter_cls( @@ -592,8 +593,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): documentation="Number of prefill tokens processed.", labelnames=labelnames, ) - self.counter_prompt_tokens = make_per_engine( - counter_prompt_tokens, engine_indexes, model_name + self.counter_prompt_tokens = create_metric_per_engine( + counter_prompt_tokens, per_engine_labelvalues ) # Labeled prompt token counters by source @@ -617,8 +618,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): documentation="Number of cached prompt tokens (local + external).", labelnames=labelnames, ) - self.counter_prompt_tokens_cached = make_per_engine( - counter_prompt_tokens_cached, engine_indexes, model_name + self.counter_prompt_tokens_cached = create_metric_per_engine( + counter_prompt_tokens_cached, per_engine_labelvalues ) # Recomputed tokens (last token recomputed when entire prompt is cached) @@ -627,8 +628,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): documentation="Number of cached tokens recomputed for forward pass.", labelnames=labelnames, ) - self.counter_prompt_tokens_recomputed = make_per_engine( - counter_prompt_tokens_recomputed, engine_indexes, model_name + self.counter_prompt_tokens_recomputed = create_metric_per_engine( + counter_prompt_tokens_recomputed, per_engine_labelvalues ) counter_generation_tokens = self._counter_cls( @@ -636,8 +637,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): documentation="Number of generation tokens processed.", labelnames=labelnames, ) - self.counter_generation_tokens = make_per_engine( - counter_generation_tokens, engine_indexes, model_name + self.counter_generation_tokens = create_metric_per_engine( + counter_generation_tokens, per_engine_labelvalues ) self.counter_request_success: dict[FinishReason, dict[int, Counter]] = {} @@ -663,8 +664,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames, ) - self.histogram_num_prompt_tokens_request = make_per_engine( - histogram_num_prompt_tokens_request, engine_indexes, model_name + self.histogram_num_prompt_tokens_request = create_metric_per_engine( + histogram_num_prompt_tokens_request, per_engine_labelvalues ) histogram_num_generation_tokens_request = self._histogram_cls( @@ -673,8 +674,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames, ) - self.histogram_num_generation_tokens_request = make_per_engine( - histogram_num_generation_tokens_request, engine_indexes, model_name + self.histogram_num_generation_tokens_request = create_metric_per_engine( + histogram_num_generation_tokens_request, per_engine_labelvalues ) # TODO: This metric might be incorrect in case of using multiple @@ -686,8 +687,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], labelnames=labelnames, ) - self.histogram_iteration_tokens = make_per_engine( - histogram_iteration_tokens, engine_indexes, model_name + self.histogram_iteration_tokens = create_metric_per_engine( + histogram_iteration_tokens, per_engine_labelvalues ) histogram_max_num_generation_tokens_request = self._histogram_cls( @@ -696,8 +697,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames, ) - self.histogram_max_num_generation_tokens_request = make_per_engine( - histogram_max_num_generation_tokens_request, engine_indexes, model_name + self.histogram_max_num_generation_tokens_request = create_metric_per_engine( + histogram_max_num_generation_tokens_request, per_engine_labelvalues ) histogram_n_request = self._histogram_cls( @@ -706,8 +707,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=[1, 2, 5, 10, 20], labelnames=labelnames, ) - self.histogram_n_request = make_per_engine( - histogram_n_request, engine_indexes, model_name + self.histogram_n_request = create_metric_per_engine( + histogram_n_request, per_engine_labelvalues ) histogram_max_tokens_request = self._histogram_cls( @@ -716,8 +717,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames, ) - self.histogram_max_tokens_request = make_per_engine( - histogram_max_tokens_request, engine_indexes, model_name + self.histogram_max_tokens_request = create_metric_per_engine( + histogram_max_tokens_request, per_engine_labelvalues ) # @@ -752,8 +753,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ], labelnames=labelnames, ) - self.histogram_time_to_first_token = make_per_engine( - histogram_time_to_first_token, engine_indexes, model_name + self.histogram_time_to_first_token = create_metric_per_engine( + histogram_time_to_first_token, per_engine_labelvalues ) histogram_inter_token_latency = self._histogram_cls( @@ -782,8 +783,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ], labelnames=labelnames, ) - self.histogram_inter_token_latency = make_per_engine( - histogram_inter_token_latency, engine_indexes, model_name + self.histogram_inter_token_latency = create_metric_per_engine( + histogram_inter_token_latency, per_engine_labelvalues ) histogram_request_time_per_output_token = self._histogram_cls( @@ -812,8 +813,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): ], labelnames=labelnames, ) - self.histogram_request_time_per_output_token = make_per_engine( - histogram_request_time_per_output_token, engine_indexes, model_name + self.histogram_request_time_per_output_token = create_metric_per_engine( + histogram_request_time_per_output_token, per_engine_labelvalues ) request_latency_buckets = [ @@ -845,8 +846,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=request_latency_buckets, labelnames=labelnames, ) - self.histogram_e2e_time_request = make_per_engine( - histogram_e2e_time_request, engine_indexes, model_name + self.histogram_e2e_time_request = create_metric_per_engine( + histogram_e2e_time_request, per_engine_labelvalues ) histogram_queue_time_request = self._histogram_cls( @@ -855,8 +856,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=request_latency_buckets, labelnames=labelnames, ) - self.histogram_queue_time_request = make_per_engine( - histogram_queue_time_request, engine_indexes, model_name + self.histogram_queue_time_request = create_metric_per_engine( + histogram_queue_time_request, per_engine_labelvalues ) histogram_inference_time_request = self._histogram_cls( @@ -865,8 +866,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=request_latency_buckets, labelnames=labelnames, ) - self.histogram_inference_time_request = make_per_engine( - histogram_inference_time_request, engine_indexes, model_name + self.histogram_inference_time_request = create_metric_per_engine( + histogram_inference_time_request, per_engine_labelvalues ) histogram_prefill_time_request = self._histogram_cls( @@ -875,8 +876,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=request_latency_buckets, labelnames=labelnames, ) - self.histogram_prefill_time_request = make_per_engine( - histogram_prefill_time_request, engine_indexes, model_name + self.histogram_prefill_time_request = create_metric_per_engine( + histogram_prefill_time_request, per_engine_labelvalues ) histogram_decode_time_request = self._histogram_cls( @@ -885,8 +886,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=request_latency_buckets, labelnames=labelnames, ) - self.histogram_decode_time_request = make_per_engine( - histogram_decode_time_request, engine_indexes, model_name + self.histogram_decode_time_request = create_metric_per_engine( + histogram_decode_time_request, per_engine_labelvalues ) histogram_prefill_kv_computed_request = self._histogram_cls( @@ -898,8 +899,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=build_1_2_5_buckets(max_model_len), labelnames=labelnames, ) - self.histogram_prefill_kv_computed_request = make_per_engine( - histogram_prefill_kv_computed_request, engine_indexes, model_name + self.histogram_prefill_kv_computed_request = create_metric_per_engine( + histogram_prefill_kv_computed_request, per_engine_labelvalues ) # @@ -939,8 +940,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=kv_cache_residency_buckets, labelnames=labelnames, ) - self.histogram_kv_block_lifetime = make_per_engine( - histogram_kv_block_lifetime, engine_indexes, model_name + self.histogram_kv_block_lifetime = create_metric_per_engine( + histogram_kv_block_lifetime, per_engine_labelvalues ) histogram_kv_block_idle_before_evict = self._histogram_cls( @@ -952,8 +953,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=kv_cache_residency_buckets, labelnames=labelnames, ) - self.histogram_kv_block_idle_before_evict = make_per_engine( - histogram_kv_block_idle_before_evict, engine_indexes, model_name + self.histogram_kv_block_idle_before_evict = create_metric_per_engine( + histogram_kv_block_idle_before_evict, per_engine_labelvalues ) histogram_kv_block_reuse_gap = self._histogram_cls( @@ -967,8 +968,8 @@ class PrometheusStatLogger(AggregateStatLoggerBase): buckets=kv_cache_residency_buckets, labelnames=labelnames, ) - self.histogram_kv_block_reuse_gap = make_per_engine( - histogram_kv_block_reuse_gap, engine_indexes, model_name + self.histogram_kv_block_reuse_gap = create_metric_per_engine( + histogram_kv_block_reuse_gap, per_engine_labelvalues ) else: self.histogram_kv_block_lifetime = {} @@ -1203,15 +1204,6 @@ class PrometheusStatLogger(AggregateStatLoggerBase): self.log_metrics_info("cache_config", self.vllm_config.cache_config) -PromMetric: TypeAlias = Gauge | Counter | Histogram - - -def make_per_engine( - metric: PromMetric, engine_idxs: list[int], model_name: object -) -> dict[int, PromMetric]: - return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs} - - def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]: """ Builds a list of buckets with increasing powers of 10 multiplied by diff --git a/vllm/v1/metrics/perf.py b/vllm/v1/metrics/perf.py index 81348efc1..91629cb57 100644 --- a/vllm/v1/metrics/perf.py +++ b/vllm/v1/metrics/perf.py @@ -27,6 +27,7 @@ from vllm.utils.torch_utils import ( get_kv_cache_torch_dtype, ) from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.metrics.utils import create_metric_per_engine logger = init_logger(__name__) @@ -1291,7 +1292,9 @@ class PerfMetricsProm: ), labelnames=labelnames, ) - self.counter_flops = make_per_engine(counter_flops, per_engine_labelvalues) + self.counter_flops = create_metric_per_engine( + counter_flops, per_engine_labelvalues + ) counter_read_bytes = self._counter_cls( name="vllm:estimated_read_bytes_per_gpu_total", @@ -1301,7 +1304,7 @@ class PerfMetricsProm: ), labelnames=labelnames, ) - self.counter_read_bytes = make_per_engine( + self.counter_read_bytes = create_metric_per_engine( counter_read_bytes, per_engine_labelvalues ) @@ -1313,7 +1316,7 @@ class PerfMetricsProm: ), labelnames=labelnames, ) - self.counter_write_bytes = make_per_engine( + self.counter_write_bytes = create_metric_per_engine( counter_write_bytes, per_engine_labelvalues ) @@ -1329,16 +1332,6 @@ class PerfMetricsProm: self.counter_write_bytes[engine_idx].inc(perf_stats.num_write_bytes_per_gpu) -def make_per_engine( - counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[object]] -): - """Create a counter for each label value.""" - return { - idx: counter.labels(*labelvalues) - for idx, labelvalues in per_engine_labelvalues.items() - } - - ## util functions diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index abc53f380..a11b92680 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPrometheus +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorProm from vllm.v1.metrics.loggers import PrometheusStatLogger from vllm.v1.metrics.perf import PerfMetricsProm from vllm.v1.spec_decode.metrics import SpecDecodingProm @@ -168,9 +168,9 @@ class RaySpecDecodingProm(SpecDecodingProm): _counter_cls = RayCounterWrapper -class RayKVConnectorPrometheus(KVConnectorPrometheus): +class RayKVConnectorProm(KVConnectorProm): """ - RayKVConnectorPrometheus is used by RayMetrics to log Ray + RayKVConnectorProm is used by RayMetrics to log Ray metrics. Provides the same metrics as KV connectors but uses Ray's util.metrics library. """ @@ -197,7 +197,7 @@ class RayPrometheusStatLogger(PrometheusStatLogger): _counter_cls = RayCounterWrapper _histogram_cls = RayHistogramWrapper _spec_decoding_cls = RaySpecDecodingProm - _kv_connector_cls = RayKVConnectorPrometheus + _kv_connector_cls = RayKVConnectorProm _perf_metrics_cls = RayPerfMetricsProm @staticmethod diff --git a/vllm/v1/metrics/utils.py b/vllm/v1/metrics/utils.py new file mode 100644 index 000000000..1ef56fc94 --- /dev/null +++ b/vllm/v1/metrics/utils.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TypeAlias + +from prometheus_client import Counter, Gauge, Histogram + +PromMetric: TypeAlias = Gauge | Counter | Histogram + + +def create_metric_per_engine( + metric: PromMetric, + per_engine_labelvalues: dict[int, list[object]], +) -> dict[int, PromMetric]: + """Create a labeled metric child for each engine index.""" + return { + idx: metric.labels(*labelvalues) + for idx, labelvalues in per_engine_labelvalues.items() + } diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 6c16bc686..9a41ff5c8 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -9,6 +9,7 @@ import prometheus_client from vllm.config import SpeculativeConfig from vllm.logger import init_logger +from vllm.v1.metrics.utils import create_metric_per_engine logger = init_logger(__name__) @@ -155,7 +156,7 @@ class SpecDecodingProm: documentation="Number of spec decoding drafts.", labelnames=labelnames, ) - self.counter_spec_decode_num_drafts = make_per_engine( + self.counter_spec_decode_num_drafts = create_metric_per_engine( counter_drafts, per_engine_labelvalues ) @@ -164,7 +165,7 @@ class SpecDecodingProm: documentation="Number of draft tokens.", labelnames=labelnames, ) - self.counter_spec_decode_num_draft_tokens = make_per_engine( + self.counter_spec_decode_num_draft_tokens = create_metric_per_engine( counter_draft_tokens, per_engine_labelvalues ) @@ -173,7 +174,7 @@ class SpecDecodingProm: documentation="Number of accepted tokens.", labelnames=labelnames, ) - self.counter_spec_decode_num_accepted_tokens = make_per_engine( + self.counter_spec_decode_num_accepted_tokens = create_metric_per_engine( counter_accepted_tokens, per_engine_labelvalues ) @@ -212,14 +213,3 @@ class SpecDecodingProm: self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx] ): counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) - - -def make_per_engine( - counter: prometheus_client.Counter, - per_engine_labelvalues: dict[int, list[object]], -): - """Create a counter for each label value.""" - return { - idx: counter.labels(*labelvalues) - for idx, labelvalues in per_engine_labelvalues.items() - }