[V1][Metrics] Add GPU prefix cache hit rate % gauge (#12592)

This commit is contained in:
Cody Yu
2025-02-11 00:27:25 -08:00
committed by GitHub
parent fc6485d277
commit 41c5dd45b9
7 changed files with 174 additions and 5 deletions

View File

@@ -203,6 +203,8 @@ EXPECTED_METRICS_V1 = [
"vllm:num_requests_running", "vllm:num_requests_running",
"vllm:num_requests_waiting", "vllm:num_requests_waiting",
"vllm:gpu_cache_usage_perc", "vllm:gpu_cache_usage_perc",
"vllm:gpu_prefix_cache_queries",
"vllm:gpu_prefix_cache_hits",
"vllm:prompt_tokens_total", "vllm:prompt_tokens_total",
"vllm:generation_tokens_total", "vllm:generation_tokens_total",
"vllm:request_success_total", "vllm:request_success_total",

View File

@@ -5,10 +5,11 @@ import pytest
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, KVCacheBlock, PrefixCachingMetrics,
generate_block_hash_extra_keys, generate_block_hash_extra_keys,
hash_block_tokens, hash_block_tokens,
hash_request_tokens) hash_request_tokens)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
@@ -277,3 +278,39 @@ def test_hash_request_tokens_no_mm_inputs():
assert block_hashes[0].extra_keys is None assert block_hashes[0].extra_keys is None
assert block_hashes[1].token_ids == (3, 4, 5) assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys is None assert block_hashes[1].extra_keys is None
def test_metrics():
    """Exercise PrefixCachingMetrics aggregation over a sliding window."""

    def make_stats(num_requests, num_queries, num_hits):
        return PrefixCacheStats(requests=num_requests,
                                queries=num_queries,
                                hits=num_hits)

    metrics = PrefixCachingMetrics(interval=5)
    assert metrics.hit_rate == 0.0

    # First update: 9 hits out of 20 queried blocks.
    metrics.observe(make_stats(1, 20, 9))
    assert metrics.hit_rate == 0.45

    # Second update: cumulative 25 hits / 100 queries.
    metrics.observe(make_stats(4, 80, 16))
    assert metrics.hit_rate == 0.25

    # Third update overflows the 5-request window: the oldest entry
    # (1 request, 20 queries, 9 hits) is evicted -> 18 / 90.
    metrics.observe(make_stats(1, 10, 2))
    assert metrics.aggregated_requests == 5
    assert metrics.aggregated_query_total == 90
    assert metrics.aggregated_query_hit == 18
    assert metrics.hit_rate == 0.2

    # Reset clears both the aggregates and the window.
    metrics.reset()
    assert metrics.hit_rate == 0.0
    assert metrics.aggregated_requests == 0
    assert metrics.aggregated_query_total == 0
    assert metrics.aggregated_query_hit == 0
    assert not metrics.query_queue

View File

@@ -10,6 +10,7 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
generate_block_hash_extra_keys, generate_block_hash_extra_keys,
hash_block_tokens, hash_block_tokens,
hash_request_tokens) hash_request_tokens)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -78,11 +79,28 @@ class KVCacheManager:
self.req_to_block_hashes: DefaultDict[ self.req_to_block_hashes: DefaultDict[
str, List[BlockHashType]] = defaultdict(list) str, List[BlockHashType]] = defaultdict(list)
self.prefix_cache_stats = PrefixCacheStats()
@property @property
def usage(self) -> float: def usage(self) -> float:
"""Get the KV cache usage.
Returns:
The KV cache usage (between 0.0 and 1.0).
"""
return 1.0 - (self.free_block_queue.num_free_blocks / return 1.0 - (self.free_block_queue.num_free_blocks /
self.num_gpu_blocks) self.num_gpu_blocks)
def make_prefix_cache_stats(self) -> PrefixCacheStats:
    """Return the prefix cache stats collected so far and start a fresh
    collection period.

    Returns:
        The prefix caching stats accumulated since the previous call.
    """
    # Swap the accumulator for an empty one so the caller gets a
    # consistent snapshot and new lookups start from zero.
    collected, self.prefix_cache_stats = (self.prefix_cache_stats,
                                          PrefixCacheStats())
    return collected
def get_computed_blocks( def get_computed_blocks(
self, request: Request) -> Tuple[List[KVCacheBlock], int]: self, request: Request) -> Tuple[List[KVCacheBlock], int]:
"""Get the computed (cached) blocks for the request. """Get the computed (cached) blocks for the request.
@@ -118,6 +136,10 @@ class KVCacheManager:
else: else:
break break
self.prefix_cache_stats.requests += 1
self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)
# NOTE(woosuk): Since incomplete blocks are not eligible for # NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of # sharing, `num_computed_tokens` is always a multiple of
# `block_size`. # `block_size`.
@@ -280,6 +302,8 @@ class KVCacheManager:
for block in self.block_pool: for block in self.block_pool:
block.reset_hash() block.reset_hash()
self.prefix_cache_stats.reset = True
logger.info("Successfully reset prefix cache") logger.info("Successfully reset prefix cache")
return True return True

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""KV-Cache Utilities.""" """KV-Cache Utilities."""
from collections import deque
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple from typing import Any, List, NamedTuple, Optional, Tuple
@@ -8,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec, from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
KVCacheTensor) KVCacheTensor)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -28,6 +30,68 @@ class BlockHashType(NamedTuple):
extra_keys: Optional[Any] = None extra_keys: Optional[Any] = None
class PrefixCachingMetrics:
    """Metrics for prefix caching with a hit rate of the most recent N requests.

    Args:
        interval: The number of the most recent requests to aggregate.
            Defaults to 1000.
    """

    def __init__(self, interval: int = 1000):
        self.interval = interval
        # The current aggregated values over the window.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        # A deque of (requests, queries, hits) for the most recent updates.
        self.query_queue: deque[Tuple[int, int, int]] = deque()

    def observe(self, stats: "PrefixCacheStats"):
        """Observe the prefix caching for a set of requests.

        This function is called with information gathered when new requests
        are being scheduled and are looking for computed blocks.

        When there are more than `interval` requests, the oldest sets of
        requests are removed from the metrics.

        Args:
            stats: The prefix cache stats.
        """
        # reset_prefix_cache was invoked before the current update.
        # Reset the metrics before aggregating the current stats.
        if stats.reset:
            self.reset()

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Evict the oldest entries until the window invariant
        # (aggregated_requests <= interval) holds again. A `while` (not `if`)
        # is required: a single update may carry more requests than several
        # old entries combined, so one eviction may not be enough.
        while self.aggregated_requests > self.interval:
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self):
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests."""
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total
@dataclass @dataclass
class KVCacheBlock: class KVCacheBlock:
"""KV-cache block metadata.""" """KV-cache block metadata."""

View File

@@ -593,4 +593,5 @@ class Scheduler:
num_running_reqs=len(self.running), num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting), num_waiting_reqs=len(self.waiting),
gpu_cache_usage=self.kv_cache_manager.usage, gpu_cache_usage=self.kv_cache_manager.usage,
prefix_cache_stats=self.kv_cache_manager.make_prefix_cache_stats(),
) )

View File

@@ -9,6 +9,7 @@ import prometheus_client
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import IterationStats, SchedulerStats
@@ -37,6 +38,9 @@ class LoggingStatLogger(StatLoggerBase):
self.num_prompt_tokens: List[int] = [] self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = [] self.num_generation_tokens: List[int] = []
# Prefix cache metrics. TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
def _local_interval_elapsed(self, now: float) -> bool: def _local_interval_elapsed(self, now: float) -> bool:
# Log every _LOCAL_LOGGING_INTERVAL_SEC. # Log every _LOCAL_LOGGING_INTERVAL_SEC.
elapsed_time = now - self.last_log_time elapsed_time = now - self.last_log_time
@@ -58,6 +62,8 @@ class LoggingStatLogger(StatLoggerBase):
self._track_iteration_stats(iteration_stats) self._track_iteration_stats(iteration_stats)
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
now = time.monotonic() now = time.monotonic()
if not self._local_interval_elapsed(now): if not self._local_interval_elapsed(now):
return return
@@ -72,13 +78,15 @@ class LoggingStatLogger(StatLoggerBase):
logger.info( logger.info(
"Avg prompt throughput: %.1f tokens/s, " "Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs " "Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%.", "GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
prompt_throughput, prompt_throughput,
generation_throughput, generation_throughput,
scheduler_stats.num_running_reqs, scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs, scheduler_stats.num_waiting_reqs,
scheduler_stats.gpu_cache_usage * 100, scheduler_stats.gpu_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
) )
@@ -107,6 +115,18 @@ class PrometheusStatLogger(StatLoggerBase):
documentation="GPU KV-cache usage. 1 means 100 percent usage.", documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.counter_gpu_prefix_cache_queries = prometheus_client.Counter(
name="vllm:gpu_prefix_cache_queries",
documentation=
"GPU prefix cache queries, in terms of number of queried blocks.",
labelnames=labelnames).labels(*labelvalues)
self.counter_gpu_prefix_cache_hits = prometheus_client.Counter(
name="vllm:gpu_prefix_cache_hits",
documentation=
"GPU prefix cache hits, in terms of number of cached blocks.",
labelnames=labelnames).labels(*labelvalues)
self.counter_prompt_tokens = prometheus_client.Counter( self.counter_prompt_tokens = prometheus_client.Counter(
name="vllm:prompt_tokens_total", name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.", documentation="Number of prefill tokens processed.",
@@ -170,6 +190,11 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)
self.counter_gpu_prefix_cache_queries.inc(
scheduler_stats.prefix_cache_stats.queries)
self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits)
self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens) self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
self.counter_generation_tokens.inc( self.counter_generation_tokens.inc(
iteration_stats.num_generation_tokens) iteration_stats.num_generation_tokens)

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import time import time
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -9,6 +9,20 @@ if TYPE_CHECKING:
from vllm.v1.engine import EngineCoreOutput, FinishReason from vllm.v1.engine import EngineCoreOutput, FinishReason
@dataclass
class PrefixCacheStats:
    """Stores prefix cache hit statistics.

    Accumulated by the KV cache manager between stat snapshots and
    consumed by PrefixCachingMetrics to compute a windowed hit rate.
    """
    # Whether reset_prefix_cache was invoked since the last snapshot;
    # consumers drop their aggregated window when this is set.
    reset: bool = False
    # The number of requests in this update.
    requests: int = 0
    # The number of queries in these requests. Note that "queries" here
    # means the number of blocks that were queried from the cache.
    queries: int = 0
    # The number of hits in these requests, i.e. the number of queried
    # blocks that were already cached.
    hits: int = 0
@dataclass @dataclass
class SchedulerStats: class SchedulerStats:
"""Stats associated with the scheduler.""" """Stats associated with the scheduler."""
@@ -17,7 +31,9 @@ class SchedulerStats:
num_waiting_reqs: int = 0 num_waiting_reqs: int = 0
gpu_cache_usage: float = 0.0 gpu_cache_usage: float = 0.0
# gpu_prefix_cache_hit_rate: float = 0.0
prefix_cache_stats: PrefixCacheStats = field(
default_factory=PrefixCacheStats)
@dataclass @dataclass