diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index a319ffb1d..4b46669d5 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -7,37 +7,55 @@ from vllm.v1.metrics.loggers import PrometheusStatLogger from vllm.v1.spec_decode.metrics import SpecDecodingProm try: + from ray import serve as ray_serve from ray.util import metrics as ray_metrics from ray.util.metrics import Metric except ImportError: ray_metrics = None + ray_serve = None import regex as re +def _get_replica_id() -> str | None: + """Get the current Ray Serve replica ID, or None if not in a Serve context.""" + if ray_serve is None: + return None + try: + return ray_serve.get_replica_context().replica_id.unique_id + except ray_serve.exceptions.RayServeException: + return None + + class RayPrometheusMetric: def __init__(self): if ray_metrics is None: raise ImportError("RayPrometheusMetric requires Ray to be installed.") - self.metric: Metric = None + @staticmethod + def _get_tag_keys(labelnames: list[str] | None) -> tuple[str, ...]: + labels = list(labelnames) if labelnames else [] + labels.append("ReplicaId") + return tuple(labels) + def labels(self, *labels, **labelskwargs): + if labels: + # -1 because ReplicaId was added automatically + expected = len(self.metric._tag_keys) - 1 + if len(labels) != expected: + raise ValueError( + "Number of labels must match the number of tag keys. " + f"Expected {expected}, got {len(labels)}" + ) + labelskwargs.update(zip(self.metric._tag_keys, labels)) + + labelskwargs["ReplicaId"] = _get_replica_id() or "" + if labelskwargs: for k, v in labelskwargs.items(): if not isinstance(v, str): labelskwargs[k] = str(v) - self.metric.set_default_tags(labelskwargs) - - if labels: - if len(labels) != len(self.metric._tag_keys): - raise ValueError( - "Number of labels must match the number of tag keys. " - f"Expected {len(self.metric._tag_keys)}, got {len(labels)}" - ) - - self.metric.set_default_tags(dict(zip(self.metric._tag_keys, labels))) - return self @staticmethod @@ -71,10 +89,14 @@ class RayGaugeWrapper(RayPrometheusMetric): # "mostrecent", "all", "sum" do not apply. This logic can be manually # implemented at the observability layer (Prometheus/Grafana). del multiprocess_mode - labelnames_tuple = tuple(labelnames) if labelnames else None + + tag_keys = self._get_tag_keys(labelnames) name = self._get_sanitized_opentelemetry_name(name) + self.metric = ray_metrics.Gauge( - name=name, description=documentation, tag_keys=labelnames_tuple + name=name, + description=documentation, + tag_keys=tag_keys, ) def set(self, value: int | float): @@ -95,10 +117,12 @@ class RayCounterWrapper(RayPrometheusMetric): documentation: str | None = "", labelnames: list[str] | None = None, ): - labelnames_tuple = tuple(labelnames) if labelnames else None + tag_keys = self._get_tag_keys(labelnames) name = self._get_sanitized_opentelemetry_name(name) self.metric = ray_metrics.Counter( - name=name, description=documentation, tag_keys=labelnames_tuple + name=name, + description=documentation, + tag_keys=tag_keys, ) def inc(self, value: int | float = 1.0): @@ -118,13 +142,14 @@ class RayHistogramWrapper(RayPrometheusMetric): labelnames: list[str] | None = None, buckets: list[float] | None = None, ): - labelnames_tuple = tuple(labelnames) if labelnames else None + tag_keys = self._get_tag_keys(labelnames) name = self._get_sanitized_opentelemetry_name(name) + boundaries = buckets if buckets else [] self.metric = ray_metrics.Histogram( name=name, description=documentation, - tag_keys=labelnames_tuple, + tag_keys=tag_keys, boundaries=boundaries, )