Migrate docs from Sphinx to MkDocs (#18145)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
720
docs/design/v1/metrics.md
Normal file
@@ -0,0 +1,720 @@
|
||||
# Metrics
|
||||
|
||||
Ensure the v1 LLM Engine exposes a superset of the metrics available in v0.
|
||||
|
||||
## Objectives
|
||||
|
||||
- Achieve parity of metrics between v0 and v1.
|
||||
- The priority use case is accessing these metrics via Prometheus as this is what we expect to be used in production environments.
|
||||
- Logging support - i.e. printing metrics to the info log - is provided for more ad-hoc testing, debugging, development, and exploratory use cases.
|
||||
|
||||
## Background
|
||||
|
||||
Metrics in vLLM can be categorized as follows:
|
||||
|
||||
1. Server-level metrics: these are global metrics that track the state and performance of the LLM engine. These are typically exposed as Gauges or Counters in Prometheus.
|
||||
2. Request-level metrics: these are metrics that track the characteristics - e.g. size and timing - of individual requests. These are typically exposed as Histograms in Prometheus, and are often the basis of the SLOs that an SRE monitoring vLLM will be tracking.
|
||||
|
||||
The mental model is that the "Server-level Metrics" explain why the "Request-level Metrics" are what they are.
|
||||
|
||||
### v0 Metrics
|
||||
|
||||
In v0, the following metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix:
|
||||
|
||||
- `vllm:num_requests_running` (Gauge)
|
||||
- `vllm:num_requests_swapped` (Gauge)
|
||||
- `vllm:num_requests_waiting` (Gauge)
|
||||
- `vllm:gpu_cache_usage_perc` (Gauge)
|
||||
- `vllm:cpu_cache_usage_perc` (Gauge)
|
||||
- `vllm:gpu_prefix_cache_hit_rate` (Gauge)
|
||||
- `vllm:cpu_prefix_cache_hit_rate` (Gauge)
|
||||
- `vllm:prompt_tokens_total` (Counter)
|
||||
- `vllm:generation_tokens_total` (Counter)
|
||||
- `vllm:request_success_total` (Counter)
|
||||
- `vllm:request_prompt_tokens` (Histogram)
|
||||
- `vllm:request_generation_tokens` (Histogram)
|
||||
- `vllm:time_to_first_token_seconds` (Histogram)
|
||||
- `vllm:time_per_output_token_seconds` (Histogram)
|
||||
- `vllm:e2e_request_latency_seconds` (Histogram)
|
||||
- `vllm:request_queue_time_seconds` (Histogram)
|
||||
- `vllm:request_inference_time_seconds` (Histogram)
|
||||
- `vllm:request_prefill_time_seconds` (Histogram)
|
||||
- `vllm:request_decode_time_seconds` (Histogram)
|
||||
- `vllm:request_max_num_generation_tokens` (Histogram)
|
||||
- `vllm:num_preemptions_total` (Counter)
|
||||
- `vllm:cache_config_info` (Gauge)
|
||||
- `vllm:lora_requests_info` (Gauge)
|
||||
- `vllm:tokens_total` (Counter)
|
||||
- `vllm:iteration_tokens_total` (Histogram)
|
||||
- `vllm:time_in_queue_requests` (Histogram)
|
||||
- `vllm:model_forward_time_milliseconds` (Histogram)
|
||||
- `vllm:model_execute_time_milliseconds` (Histogram)
|
||||
- `vllm:request_params_n` (Histogram)
|
||||
- `vllm:request_params_max_tokens` (Histogram)
|
||||
- `vllm:spec_decode_draft_acceptance_rate` (Gauge)
|
||||
- `vllm:spec_decode_efficiency` (Gauge)
|
||||
- `vllm:spec_decode_num_accepted_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_draft_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)
|
||||
|
||||
These are documented under [Inferencing and Serving -> Production Metrics](../../serving/metrics.md).
|
||||
|
||||
### Grafana Dashboard
|
||||
|
||||
vLLM also provides [a reference example](https://docs.vllm.ai/en/latest/getting_started/examples/prometheus_grafana.html) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard.
|
||||
|
||||
The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important:
|
||||
|
||||
- `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds
|
||||
- `vllm:prompt_tokens_total` - Prompt Tokens
|
||||
- `vllm:generation_tokens_total` - Generation Tokens
|
||||
- `vllm:time_per_output_token_seconds` - Inter-token latency (Time Per Output Token, TPOT) in seconds.
|
||||
- `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds.
|
||||
- `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in the RUNNING, SWAPPED, and WAITING states
|
||||
- `vllm:gpu_cache_usage_perc` - Percentage of used cache blocks by vLLM.
|
||||
- `vllm:request_prompt_tokens` - Request prompt length
|
||||
- `vllm:request_generation_tokens` - Request generation length
|
||||
- `vllm:request_success_total` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached
|
||||
- `vllm:request_queue_time_seconds` - Queue Time
|
||||
- `vllm:request_prefill_time_seconds` - Requests Prefill Time
|
||||
- `vllm:request_decode_time_seconds` - Requests Decode Time
|
||||
- `vllm:request_max_num_generation_tokens` - Max Generation Token in Sequence Group
|
||||
|
||||
See [the PR which added this Dashboard](gh-pr:2316) for interesting and useful background on the choices made here.
|
||||
|
||||
### Prometheus Client Library
|
||||
|
||||
Prometheus support was initially added [using the aioprometheus library](gh-pr:1890), but a switch was made quickly to [prometheus_client](gh-pr:2730). The rationale is discussed in both linked PRs.
|
||||
|
||||
With the switch to `prometheus_client`, we lost the `MetricsMiddleware` that tracked HTTP metrics, but this was reinstated [using prometheus_fastapi_instrumentator](gh-pr:15657):
|
||||
|
||||
```bash
|
||||
$ curl http://0.0.0.0:8000/metrics 2>/dev/null | grep -P '^http_(?!.*(_bucket|_created|_sum)).*'
|
||||
http_requests_total{handler="/v1/completions",method="POST",status="2xx"} 201.0
|
||||
http_request_size_bytes_count{handler="/v1/completions"} 201.0
|
||||
http_response_size_bytes_count{handler="/v1/completions"} 201.0
|
||||
http_request_duration_highr_seconds_count 201.0
|
||||
http_request_duration_seconds_count{handler="/v1/completions",method="POST"} 201.0
|
||||
```
|
||||
|
||||
### Multi-process Mode
|
||||
|
||||
In v0, metrics are collected in the engine core process and we use multi-process mode to make them available in the API server process. See <gh-pr:7279>.
|
||||
|
||||
### Built in Python/Process Metrics
|
||||
|
||||
The following metrics are supported by default by `prometheus_client`, but they are not exposed when multiprocess mode is used:
|
||||
|
||||
- `python_gc_objects_collected_total`
|
||||
- `python_gc_objects_uncollectable_total`
|
||||
- `python_gc_collections_total`
|
||||
- `python_info`
|
||||
- `process_virtual_memory_bytes`
|
||||
- `process_resident_memory_bytes`
|
||||
- `process_start_time_seconds`
|
||||
- `process_cpu_seconds_total`
|
||||
- `process_open_fds`
|
||||
- `process_max_fds`
|
||||
|
||||
This is relevant because if we move away from multiprocess mode in v1,
we get these back. However, it's questionable how useful these metrics
are if they don't aggregate the stats for all of the processes that make
up a vLLM instance.
|
||||
|
||||
### v0 PRs and Issues
|
||||
|
||||
For background, these are some of the relevant PRs which added the v0 metrics:
|
||||
|
||||
- <gh-pr:1890>
|
||||
- <gh-pr:2316>
|
||||
- <gh-pr:2730>
|
||||
- <gh-pr:4464>
|
||||
- <gh-pr:7279>
|
||||
|
||||
Also note the ["Even Better Observability"](gh-issue:3616) feature where e.g. [a detailed roadmap was laid out](gh-issue:3616#issuecomment-2030858781).
|
||||
|
||||
## v1 Design
|
||||
|
||||
### v1 PRs
|
||||
|
||||
For background, here are the relevant v1 PRs relating to the v1
|
||||
metrics issue <gh-issue:10582>:
|
||||
|
||||
- <gh-pr:11962>
|
||||
- <gh-pr:11973>
|
||||
- <gh-pr:10907>
|
||||
- <gh-pr:12416>
|
||||
- <gh-pr:12478>
|
||||
- <gh-pr:12516>
|
||||
- <gh-pr:12530>
|
||||
- <gh-pr:12561>
|
||||
- <gh-pr:12579>
|
||||
- <gh-pr:12592>
|
||||
- <gh-pr:12644>
|
||||
|
||||
### Metrics Collection
|
||||
|
||||
In v1, we wish to move computation and overhead out of the engine core
|
||||
process to minimize the time between each forward pass.
|
||||
|
||||
The overall idea of the V1 EngineCore design is:

- EngineCore is the inner loop. Performance is most critical here.
- AsyncLLM is the outer loop. This is overlapped with GPU execution
  (ideally), so this is where any "overheads" should be if
  possible. `AsyncLLM.output_handler_loop` is therefore the ideal
  place for the metrics bookkeeping.
|
||||
|
||||
We will achieve this by collecting metrics in the frontend API server,
|
||||
and base these metrics on information we can glean from the
|
||||
`EngineCoreOutputs` returned by the engine core process to the
|
||||
frontend.
|
||||
|
||||
### Interval Calculations
|
||||
|
||||
Many of our metrics are the time interval between various events in
|
||||
the processing of a request. It is best practice to use timestamps
|
||||
based on "monotonic time" (`time.monotonic()`) rather than "wall-clock
|
||||
time" (`time.time()`) to calculate intervals as the former is
|
||||
unaffected by system clock changes (e.g. from NTP).
|
||||
|
||||
It's also important to note that monotonic clocks differ between
processes - each process has its own reference point. So it is
meaningless to compare monotonic timestamps from different processes.
|
||||
|
||||
Therefore, in order to calculate an interval, we must compare two
|
||||
monotonic timestamps from the same process.
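
As a minimal sketch of this rule (the `measure_interval` helper is hypothetical and not part of vLLM), both timestamps are taken with `time.monotonic()` inside the same process before they are subtracted:

```python
import time


def measure_interval(work) -> float:
    """Return how long `work()` took, in seconds, using the monotonic clock."""
    start = time.monotonic()  # reference point is private to this process
    work()
    end = time.monotonic()    # same process, so the two values are comparable
    return end - start


elapsed = measure_interval(lambda: time.sleep(0.01))
print(f"elapsed: {elapsed:.3f}s")
```

Using `time.time()` here instead would make the result sensitive to any clock adjustment that happens between the two calls.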
|
||||
|
||||
### Scheduler Stats
|
||||
|
||||
The engine core process will collect some key statistics from the
|
||||
scheduler - e.g. the number of requests that were scheduled or waiting
|
||||
after the last scheduler pass - and include those statistics in
|
||||
`EngineCoreOutputs`.
|
||||
|
||||
### Engine Core Events
|
||||
|
||||
The engine core will also record the timestamp of certain per-request
|
||||
events so that the frontend can calculate the interval between these
|
||||
events.
|
||||
|
||||
The events are:
|
||||
|
||||
- `QUEUED` - when the request was received by the engine core and
|
||||
added to the scheduler queue.
|
||||
- `SCHEDULED` - when the request was first scheduled for execution.
|
||||
- `PREEMPTED` - the request has been put back in the waiting queue
|
||||
in order to make room for other requests to complete. It will be
|
||||
re-scheduled in future and re-start its prefill phase.
|
||||
- `NEW_TOKENS` - when the output included in `EngineCoreOutput` was
|
||||
generated. Since this is common to all requests in a given
|
||||
iteration, we use a single timestamp on `EngineCoreOutputs` to
|
||||
record this event.
|
||||
|
||||
And the calculated intervals are:
|
||||
|
||||
- Queue interval - between `QUEUED` and most recent `SCHEDULED`.
|
||||
- Prefill interval - between most recent `SCHEDULED` and the subsequent
|
||||
first `NEW_TOKENS`.
|
||||
- Decode interval - between first (after the most recent `SCHEDULED`) and
|
||||
last `NEW_TOKENS`.
|
||||
- Inference interval - between most recent `SCHEDULED` and last `NEW_TOKENS`.
|
||||
- Inter-token interval - between successive `NEW_TOKENS`.
|
||||
|
||||
Put another way:
|
||||
|
||||

|
||||
|
||||
We explored the possibility of having the frontend calculate these
intervals using the timing of events visible to the frontend. However,
the frontend does not have visibility into the timing of the `QUEUED`
and `SCHEDULED` events and, since we need to calculate intervals based
on monotonic timestamps from the same process, we need the engine
core to record timestamps for all of these events.
|
||||
|
||||
#### Interval Calculations vs Preemptions
|
||||
|
||||
When a preemption occurs during decode, since any already generated
|
||||
tokens are reused, we consider the preemption as affecting the
|
||||
inter-token, decode, and inference intervals.
|
||||
|
||||

|
||||
|
||||
When a preemption occurs during prefill (assuming such an event
|
||||
is possible), we consider the preemption as affecting the
|
||||
time-to-first-token and prefill intervals.
|
||||
|
||||

|
||||
|
||||
### Frontend Stats Collection
|
||||
|
||||
As the frontend processes a single `EngineCoreOutputs` - i.e. the
|
||||
output from a single engine core iteration - it collects various
|
||||
statistics relating to that iteration:
|
||||
|
||||
- The total number of new tokens generated in this iteration.
|
||||
- The total number of prompt tokens processed by the prefills that
|
||||
completed in this iteration.
|
||||
- The queue intervals for any requests that were scheduled in this
|
||||
iteration.
|
||||
- The prefill intervals for any requests that completed prefill in
|
||||
this iteration.
|
||||
- The inter-token intervals (Time Per Output Token, TPOT), for all
|
||||
requests included in this iteration.
|
||||
- The Time-To-First-Token (TTFT) for any requests that completed
|
||||
prefill in this iteration. However, we calculate this interval
|
||||
relative to when the request was first received by the frontend
|
||||
(`arrival_time`) in order to account for input processing time.
|
||||
|
||||
For any requests that were completed in a given iteration, we also
|
||||
record:
|
||||
|
||||
- The inference and decode intervals - relative to the scheduled and
|
||||
first token events, as described above.
|
||||
- End-to-end latency - the interval between frontend `arrival_time`
|
||||
and the frontend receiving the final token.
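
A rough sketch of this per-iteration bookkeeping follows. The `IterationStats` container, the `record_output` function and its parameters are hypothetical names chosen for illustration, and the sketch assumes `arrival_time` and `now` both come from the frontend's own clock.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class IterationStats:
    """Hypothetical container for stats gathered from one engine core iteration."""
    num_generation_tokens: int = 0
    num_prompt_tokens: int = 0
    time_to_first_tokens: List[float] = field(default_factory=list)
    e2e_latencies: List[float] = field(default_factory=list)


def record_output(stats: IterationStats, *, arrival_time: float, now: float,
                  num_new_tokens: int, num_prompt_tokens: int,
                  is_first_output: bool, is_finished: bool) -> None:
    """Fold one request's output from this iteration into the stats."""
    stats.num_generation_tokens += num_new_tokens
    if is_first_output:
        # Prefill completed in this iteration: count its prompt tokens and
        # measure TTFT from when the frontend first received the request.
        stats.num_prompt_tokens += num_prompt_tokens
        stats.time_to_first_tokens.append(now - arrival_time)
    if is_finished:
        stats.e2e_latencies.append(now - arrival_time)


stats = IterationStats()
record_output(stats, arrival_time=100.0, now=100.5, num_new_tokens=1,
              num_prompt_tokens=12, is_first_output=True, is_finished=False)
record_output(stats, arrival_time=98.0, now=100.5, num_new_tokens=1,
              num_prompt_tokens=7, is_first_output=False, is_finished=True)
print(stats)
```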
|
||||
|
||||
### Metrics Publishing - Logging
|
||||
|
||||
The `LoggingStatLogger` metrics publisher outputs a log `INFO` message
|
||||
every 5 seconds with some key metrics:
|
||||
|
||||
- The current number of running/waiting requests
|
||||
- The current GPU cache usage
|
||||
- The number of prompt tokens processed per second over the past 5
|
||||
seconds
|
||||
- The number of new tokens generated per second over the past 5
|
||||
seconds
|
||||
- The prefix cache hit rate over the most recent 1k kv-cache block queries
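
The per-second figures can be derived by accumulating token counts between log lines and dividing by the elapsed time. The sketch below is illustrative only; `ThroughputTracker` is a hypothetical helper, not the `LoggingStatLogger` implementation.

```python
from typing import Tuple


class ThroughputTracker:
    """Hypothetical helper: tokens per second since the previous log line."""

    def __init__(self, now: float) -> None:
        self.last_log_time = now
        self.prompt_tokens = 0
        self.generation_tokens = 0

    def record(self, prompt_tokens: int, generation_tokens: int) -> None:
        """Accumulate the token counts reported for one iteration."""
        self.prompt_tokens += prompt_tokens
        self.generation_tokens += generation_tokens

    def flush(self, now: float) -> Tuple[float, float]:
        """Return (prompt tokens/s, generation tokens/s) and start a new interval."""
        elapsed = now - self.last_log_time
        prompt_tps = self.prompt_tokens / elapsed if elapsed > 0 else 0.0
        generation_tps = self.generation_tokens / elapsed if elapsed > 0 else 0.0
        self.last_log_time = now
        self.prompt_tokens = 0
        self.generation_tokens = 0
        return prompt_tps, generation_tps


tracker = ThroughputTracker(now=0.0)
tracker.record(prompt_tokens=500, generation_tokens=100)
tracker.record(prompt_tokens=0, generation_tokens=150)
throughput = tracker.flush(now=5.0)
print(throughput)
```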
|
||||
|
||||
### Metrics Publishing - Prometheus
|
||||
|
||||
The `PrometheusStatLogger` metrics publisher makes the metrics
|
||||
available via a `/metrics` HTTP endpoint in a Prometheus-compatible
|
||||
format. A Prometheus instance can then be configured to poll this
|
||||
endpoint (e.g. every second) and record the values in its time-series
|
||||
database. Prometheus is often used via Grafana, allowing these metrics
|
||||
to be graphed over time.
|
||||
|
||||
Prometheus supports the following metric types:
|
||||
|
||||
- Counter: a value that will increase over time, never reducing, and
|
||||
generally reset to zero when the vLLM instance restarts. For
|
||||
example, the number of tokens generated over the lifetime of the
|
||||
instance.
|
||||
- Gauge: a value that goes up and down, for example the number of
|
||||
requests currently scheduled for execution.
|
||||
- Histogram: a count of metric samples, recorded in buckets. For
|
||||
example, the number of requests whose TTFT was <1ms, <5ms, <10ms,
|
||||
<20ms, and so on.
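
To make the bucket semantics concrete, here is a small stand-in for a histogram written in plain Python. It is not the `prometheus_client` implementation, and `BucketedHistogram` is a hypothetical name. It reflects two properties of Prometheus histograms: each bucket is identified by an inclusive upper bound (`le`, "less than or equal"), and the exposed bucket counts are cumulative.

```python
from bisect import bisect_left
from typing import List, Tuple


class BucketedHistogram:
    """Minimal stand-in for a Prometheus-style histogram (not the real client)."""

    def __init__(self, upper_bounds: List[float]) -> None:
        # A final +Inf bucket guarantees every observation lands somewhere.
        self.upper_bounds = sorted(upper_bounds) + [float("inf")]
        self.counts = [0] * len(self.upper_bounds)
        self.total = 0.0

    def observe(self, value: float) -> None:
        # Index of the first bucket whose upper bound is >= value.
        self.counts[bisect_left(self.upper_bounds, value)] += 1
        self.total += value

    def cumulative(self) -> List[Tuple[float, int]]:
        """Return (le, count) pairs; each count includes all smaller buckets."""
        running, out = 0, []
        for bound, count in zip(self.upper_bounds, self.counts):
            running += count
            out.append((bound, running))
        return out


ttft = BucketedHistogram([0.001, 0.005, 0.01, 0.02])
for seconds in (0.004, 0.015, 0.015, 0.5):
    ttft.observe(seconds)
print(ttft.cumulative())
```

The count reported for the `+Inf` bucket equals the total number of observations, which matches the `_count` series in the Prometheus output.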
|
||||
|
||||
Prometheus metrics can also be labelled, allowing metrics to be
|
||||
combined according to matching labels. In vLLM, we add a `model_name`
|
||||
label to every metric which includes the name of the model served by
|
||||
that instance.
|
||||
|
||||
Example output:
|
||||
|
||||
```bash
|
||||
$ curl http://0.0.0.0:8000/metrics
|
||||
# HELP vllm:num_requests_running Number of requests in model execution batches.
|
||||
# TYPE vllm:num_requests_running gauge
|
||||
vllm:num_requests_running{model_name="meta-llama/Llama-3.1-8B-Instruct"} 8.0
|
||||
...
|
||||
# HELP vllm:generation_tokens_total Number of generation tokens processed.
|
||||
# TYPE vllm:generation_tokens_total counter
|
||||
vllm:generation_tokens_total{model_name="meta-llama/Llama-3.1-8B-Instruct"} 27453.0
|
||||
...
|
||||
# HELP vllm:request_success_total Count of successfully processed requests.
|
||||
# TYPE vllm:request_success_total counter
|
||||
vllm:request_success_total{finished_reason="stop",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0
|
||||
vllm:request_success_total{finished_reason="length",model_name="meta-llama/Llama-3.1-8B-Instruct"} 131.0
|
||||
vllm:request_success_total{finished_reason="abort",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0
|
||||
...
|
||||
# HELP vllm:time_to_first_token_seconds Histogram of time to first token in seconds.
|
||||
# TYPE vllm:time_to_first_token_seconds histogram
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.001",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.005",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.01",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.02",model_name="meta-llama/Llama-3.1-8B-Instruct"} 13.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.04",model_name="meta-llama/Llama-3.1-8B-Instruct"} 97.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.06",model_name="meta-llama/Llama-3.1-8B-Instruct"} 123.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.08",model_name="meta-llama/Llama-3.1-8B-Instruct"} 138.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.1",model_name="meta-llama/Llama-3.1-8B-Instruct"} 140.0
|
||||
vllm:time_to_first_token_seconds_count{model_name="meta-llama/Llama-3.1-8B-Instruct"} 140.0
|
||||
```
|
||||
|
||||
Note - the choice of histogram buckets to be most useful to users
|
||||
across a broad set of use cases is not straightforward and will
|
||||
require refinement over time.
|
||||
|
||||
### Cache Config Info
|
||||
|
||||
`prometheus_client` has support for [Info
|
||||
metrics](https://prometheus.github.io/client_python/instrumenting/info/)
|
||||
which are equivalent to a `Gauge` whose value is permanently set to 1,
|
||||
but exposes interesting key/value pair information via labels. This is
|
||||
used for information about an instance that does not change - so it
|
||||
only needs to be observed at startup - and allows comparing across
|
||||
instances in Prometheus.
|
||||
|
||||
We use this concept for the `vllm:cache_config_info` metric:
|
||||
|
||||
```
|
||||
# HELP vllm:cache_config_info Information of the LLMEngine CacheConfig
|
||||
# TYPE vllm:cache_config_info gauge
|
||||
vllm:cache_config_info{block_size="16",cache_dtype="auto",calculate_kv_scales="False",cpu_offload_gb="0",enable_prefix_caching="False",gpu_memory_utilization="0.9",...} 1.0
|
||||
|
||||
```
|
||||
|
||||
However, `prometheus_client` has [never supported Info metrics in
|
||||
multiprocessing
|
||||
mode](https://github.com/prometheus/client_python/pull/300) - for
|
||||
[unclear
|
||||
reasons](gh-pr:7279#discussion_r1710417152). We
|
||||
simply use a `Gauge` metric set to 1 and
|
||||
`multiprocess_mode="mostrecent"` instead.
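
The idea of a gauge whose labels carry the information can be illustrated without the client library. The `render_info_gauge` function below is a hypothetical helper that formats one sample line in the style of the output above; it is a simplification that does not escape label values.

```python
def render_info_gauge(name: str, labels: dict) -> str:
    """Render one sample line: the labels carry the data, the value is always 1."""
    label_str = ",".join(f'{key}="{value}"' for key, value in sorted(labels.items()))
    return f"{name}{{{label_str}}} 1.0"


line = render_info_gauge(
    "vllm:cache_config_info",
    {"block_size": 16, "cache_dtype": "auto", "enable_prefix_caching": False},
)
print(line)
```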
|
||||
|
||||
### LoRA Metrics
|
||||
|
||||
The `vllm:lora_requests_info` `Gauge` is somewhat similar, except the
|
||||
value is the current wall-clock time, and is updated every iteration.
|
||||
|
||||
The label names used are:
|
||||
|
||||
- `running_lora_adapters`: a per-adapter count of the number of requests
  running using that adapter, formatted as a comma-separated string.
- `waiting_lora_adapters`: similar, except counting requests that are
  waiting to be scheduled.
- `max_lora`: the static "max number of LoRAs in a single batch"
  configuration.
|
||||
|
||||
Encoding the running/waiting counts for multiple adapters in a
comma-separated string seems quite misguided - we could use labels to
distinguish between per-adapter counts. This should be revisited.
|
||||
|
||||
Note that `multiprocess_mode="livemostrecent"` is used - the most
|
||||
recent metric is used, but only from currently running processes.
|
||||
|
||||
This was added in
|
||||
<gh-pr:9477> and there is
|
||||
[at least one known
|
||||
user](https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/54). If
|
||||
we revisit this design and deprecate the old metric, we should reduce
|
||||
the need for a significant deprecation period by making the change in
|
||||
v0 also and asking this project to move to the new metric.
|
||||
|
||||
### Prefix Cache metrics
|
||||
|
||||
The discussion in <gh-issue:10582> about adding prefix cache metrics yielded
|
||||
some interesting points which may be relevant to how we approach
|
||||
future metrics.
|
||||
|
||||
Every time the prefix cache is queried, we record the number of tokens
|
||||
queried and the number of queried tokens present in the cache
|
||||
(i.e. hits).
|
||||
|
||||
However, the metric of interest is the hit rate - i.e. the number of
|
||||
hits per query.
|
||||
|
||||
In the case of logging, we expect the user is best served by
|
||||
calculating the hit rate over a fixed number of the most recent
|
||||
queries (the interval is fixed to 1k most recent queries for now).
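
A fixed-size window like this can be sketched with a bounded deque. `RecentHitRate` is a hypothetical name, and the sketch assumes the hit rate is the number of hit tokens divided by the number of queried tokens across the window.

```python
from collections import deque


class RecentHitRate:
    """Hypothetical tracker: hit rate over the most recent `max_queries` queries."""

    def __init__(self, max_queries: int = 1000) -> None:
        # Each entry is (queried_tokens, hit_tokens); old entries fall off the left.
        self.window = deque(maxlen=max_queries)

    def record(self, queried_tokens: int, hit_tokens: int) -> None:
        self.window.append((queried_tokens, hit_tokens))

    @property
    def hit_rate(self) -> float:
        queried = sum(q for q, _ in self.window)
        hits = sum(h for _, h in self.window)
        return hits / queried if queried else 0.0


tracker = RecentHitRate(max_queries=2)
tracker.record(queried_tokens=100, hit_tokens=0)    # evicted by the third query
tracker.record(queried_tokens=100, hit_tokens=50)
tracker.record(queried_tokens=100, hit_tokens=100)
print(tracker.hit_rate)
```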
|
||||
|
||||
In the case of Prometheus though, we should take advantage of the
|
||||
time-series nature of Prometheus and allow the user to calculate the
|
||||
hit rate over an interval of their choosing. For example, a PromQL
query to calculate the hit rate over the past 5 minutes:
|
||||
|
||||
```text
|
||||
rate(cache_query_hit[5m]) / rate(cache_query_total[5m])
|
||||
```
|
||||
|
||||
To achieve this, we should record the queries and hits as counters in
|
||||
Prometheus, rather than recording the hit rate as a gauge.
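
To see why counters are sufficient, consider what a ratio of two rates over the same window reduces to: the increase in hits divided by the increase in queries. The sketch below assumes a list of scraped `(time, queries, hits)` samples sorted by time. The `hit_rate` function is hypothetical, and it ignores the counter resets and extrapolation that PromQL's `rate()` handles.

```python
from typing import List, Tuple

# (scrape time in seconds, queries counter, hits counter)
Sample = Tuple[float, float, float]


def hit_rate(samples: List[Sample], now: float, window: float) -> float:
    """Hit rate over the last `window` seconds, from two monotonic counters."""
    in_window = [s for s in samples if now - window <= s[0] <= now]
    if len(in_window) < 2:
        return 0.0  # need at least two samples to measure an increase
    (_, q0, h0), (_, q1, h1) = in_window[0], in_window[-1]
    delta_queries = q1 - q0
    return (h1 - h0) / delta_queries if delta_queries > 0 else 0.0


samples = [(0.0, 1000.0, 200.0), (60.0, 1600.0, 500.0), (120.0, 2000.0, 800.0)]
print(hit_rate(samples, now=120.0, window=60.0))
print(hit_rate(samples, now=120.0, window=120.0))
```

The same samples yield different hit rates for different windows, which is exactly the flexibility a pre-computed gauge cannot offer.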
|
||||
|
||||
## Deprecated Metrics
|
||||
|
||||
### How To Deprecate
|
||||
|
||||
Deprecating metrics shouldn't be taken lightly. Users may not notice a
metric has been deprecated, and may be quite inconvenienced when it is
suddenly (from their perspective) removed, even if there is
an equivalent metric for them to use.
|
||||
|
||||
As an example, see how `vllm:avg_prompt_throughput_toks_per_s` was
|
||||
[deprecated](gh-pr:2764) (with a
|
||||
comment in the code),
|
||||
[removed](gh-pr:12383), and then
|
||||
[noticed by a
|
||||
user](gh-issue:13218).
|
||||
|
||||
In general:
|
||||
|
||||
1) We should be cautious about deprecating metrics, especially since
|
||||
it can be hard to predict the user impact.
|
||||
2) We should include a prominent deprecation notice in the help string
|
||||
that is included in the `/metrics` output.
|
||||
3) We should list deprecated metrics in user-facing documentation and
|
||||
release notes.
|
||||
4) We should consider hiding deprecated metrics behind a CLI argument
|
||||
in order to give administrators [an escape
|
||||
hatch](https://kubernetes.io/docs/concepts/cluster-administration/system-metrics/#show-hidden-metrics)
|
||||
for some time before deleting them.
|
||||
|
||||
See the [deprecation policy](../../contributing/deprecation_policy.md) for
|
||||
the project-wide deprecation policy.
|
||||
|
||||
### Unimplemented - `vllm:tokens_total`
|
||||
|
||||
Added by <gh-pr:4464>, but apparently never implemented. This can just be
|
||||
removed.
|
||||
|
||||
### Duplicated - Queue Time
|
||||
|
||||
The `vllm:time_in_queue_requests` Histogram metric was added by
|
||||
<gh-pr:9659> and its calculation is:
|
||||
|
||||
```
|
||||
self.metrics.first_scheduled_time = now
|
||||
self.metrics.time_in_queue = now - self.metrics.arrival_time
|
||||
```
|
||||
|
||||
Two weeks later, <gh-pr:4464> added `vllm:request_queue_time_seconds` leaving
|
||||
us with:
|
||||
|
||||
```
if seq_group.is_finished():
    if (seq_group.metrics.first_scheduled_time is not None and
            seq_group.metrics.first_token_time is not None):
        time_queue_requests.append(
            seq_group.metrics.first_scheduled_time -
            seq_group.metrics.arrival_time)
    ...
    if seq_group.metrics.time_in_queue is not None:
        time_in_queue_requests.append(
            seq_group.metrics.time_in_queue)
```
|
||||
|
||||
This seems duplicative, and one of them should be removed. The latter
|
||||
is used by the Grafana dashboard, so we should deprecate or remove the
|
||||
former from v0.
|
||||
|
||||
### Prefix Cache Hit Rate
|
||||
|
||||
See above - we now expose 'queries' and 'hits' counters rather than a
|
||||
'hit rate' gauge.
|
||||
|
||||
### KV Cache Offloading
|
||||
|
||||
Two v0 metrics relate to a "swapped" preemption mode that is no
|
||||
longer relevant in v1:
|
||||
|
||||
- `vllm:num_requests_swapped`
|
||||
- `vllm:cpu_cache_usage_perc`
|
||||
|
||||
In this mode, when a request is preempted (e.g. to make room in KV
|
||||
cache to complete other requests), we swap kv cache blocks out to CPU
|
||||
memory. This is also known as "KV cache offloading" and is configured
|
||||
with `--swap-space` and `--preemption-mode`.
|
||||
|
||||
In v0, [vLLM has long supported beam
|
||||
search](gh-issue:6226). The
|
||||
SequenceGroup encapsulated the idea of N Sequences which
|
||||
all shared the same prompt kv blocks. This enabled KV cache block
|
||||
sharing between requests, and copy-on-write to do branching. CPU
|
||||
swapping was intended for these beam-search-like cases.
|
||||
|
||||
Later, the concept of prefix caching was introduced, which allowed KV
|
||||
cache blocks to be shared implicitly. This proved to be a better
|
||||
option than CPU swapping since blocks can be evicted slowly on demand
|
||||
and the part of the prompt that was evicted can be recomputed.
|
||||
|
||||
SequenceGroup was removed in V1, although a replacement will be
|
||||
required for "parallel sampling" (`n>1`). [Beam search was moved out of
|
||||
the core (in
|
||||
V0)](gh-issue:8306). There was a
|
||||
lot of complex code for a very uncommon feature.
|
||||
|
||||
In V1, with prefix caching being better (zero overhead) and therefore
|
||||
on by default, the preemption and recompute strategy should work
|
||||
better.
|
||||
|
||||
## Future Work
|
||||
|
||||
### Parallel Sampling
|
||||
|
||||
Some v0 metrics are only relevant in the context of "parallel
|
||||
sampling". This is where the `n` parameter in a request is used to
|
||||
request multiple completions from the same prompt.
|
||||
|
||||
As part of adding parallel sampling support in <gh-pr:10980> we should
|
||||
also add these metrics.
|
||||
|
||||
- `vllm:request_params_n` (Histogram)
|
||||
|
||||
Observes the value of the 'n' parameter of every finished request.
|
||||
|
||||
- `vllm:request_max_num_generation_tokens` (Histogram)
|
||||
|
||||
Observes the maximum output length of all sequences in every finished
|
||||
sequence group. In the absence of parallel sampling, this is
|
||||
equivalent to `vllm:request_generation_tokens`.
|
||||
|
||||
### Speculative Decoding
|
||||
|
||||
Some v0 metrics are specific to "speculative decoding". This is where
|
||||
we generate candidate tokens using a faster, approximate method or
|
||||
model and then validate those tokens with the larger model.
|
||||
|
||||
- `vllm:spec_decode_draft_acceptance_rate` (Gauge)
|
||||
- `vllm:spec_decode_efficiency` (Gauge)
|
||||
- `vllm:spec_decode_num_accepted_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_draft_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)
|
||||
|
||||
There is a PR under review (<gh-pr:12193>) to add "prompt lookup (ngram)"
|
||||
speculative decoding to v1. Other techniques will follow. We should
|
||||
revisit the v0 metrics in this context.
|
||||
|
||||
Note - we should probably expose acceptance rate as separate accepted
|
||||
and draft counters, like we do for prefix caching hit rate. Efficiency
|
||||
likely also needs similar treatment.
|
||||
|
||||
### Autoscaling and Load-balancing
|
||||
|
||||
A common use case for our metrics is to support automated scaling of
|
||||
vLLM instances.
|
||||
|
||||
For related discussion from the [Kubernetes Serving Working
|
||||
Group](https://github.com/kubernetes/community/tree/master/wg-serving),
|
||||
see:
|
||||
|
||||
- [Standardizing Large Model Server Metrics in
|
||||
Kubernetes](https://docs.google.com/document/d/1SpSp1E6moa4HSrJnS4x3NpLuj88sMXr2tbofKlzTZpk)
|
||||
- [Benchmarking LLM Workloads for Performance Evaluation and
|
||||
Autoscaling in
|
||||
Kubernetes](https://docs.google.com/document/d/1k4Q4X14hW4vftElIuYGDu5KDe2LtV1XammoG-Xi3bbQ)
|
||||
- [Inference
|
||||
Perf](https://github.com/kubernetes-sigs/wg-serving/tree/main/proposals/013-inference-perf)
|
||||
- <gh-issue:5041> and <gh-pr:12726>.
|
||||
|
||||
This is a non-trivial topic. Consider this comment from Rob:
|
||||
|
||||
> I think this metric should focus on trying to estimate what the max
|
||||
> concurrency that will cause the average request length > queries per
|
||||
> second ... since this is really what will "saturate" the server.
|
||||
|
||||
A clear goal is that we should expose the metrics required to detect
|
||||
this saturation point, so administrators can implement auto-scaling
|
||||
rules based on those. However, in order to do so, we need to have a
|
||||
clear view on how an administrator (and automated monitoring system)
|
||||
should judge an instance as approaching saturation:
|
||||
|
||||
> To identify, what is the saturation point for model server compute
|
||||
> (the inflection point where we cannot get more throughput with a
|
||||
> higher request rate, but start to incur additional latency) so we
|
||||
> can autoscale effectively?
|
||||
|
||||
### Metric Naming
|
||||
|
||||
Our approach to naming metrics probably deserves to be revisited:
|
||||
|
||||
1. The use of colons in metric names seems contrary to ["colons are
|
||||
reserved for user defined recording
|
||||
rules"](https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels)
|
||||
2. Most of our metrics follow the convention of ending with units, but
|
||||
not all do.
|
||||
3. Some of our metric names end with `_total`:
|
||||
|
||||
```
|
||||
If there is a suffix of `_total` on the metric name, it will be removed. When
|
||||
exposing the time series for counter, a `_total` suffix will be added. This is
|
||||
for compatibility between OpenMetrics and the Prometheus text format, as OpenMetrics
|
||||
requires the `_total` suffix.
|
||||
```
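
The practical effect of the quoted rule is that a counter is exposed under the same name whether or not it was declared with the suffix. A minimal sketch of that normalization follows; `exposed_counter_name` is a hypothetical function that mimics the quoted behavior and is not the library's code.

```python
def exposed_counter_name(name: str) -> str:
    """Strip a trailing `_total` if present, then add it back for exposition."""
    suffix = "_total"
    base = name[: -len(suffix)] if name.endswith(suffix) else name
    return base + suffix


print(exposed_counter_name("vllm:prompt_tokens"))
print(exposed_counter_name("vllm:prompt_tokens_total"))
```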
|
||||
|
||||
### Adding More Metrics
|
||||
|
||||
There is no shortage of ideas for new metrics:
|
||||
|
||||
- Examples from other projects like
|
||||
[TGI](https://github.com/IBM/text-generation-inference?tab=readme-ov-file#metrics)
|
||||
- Proposals arising from specific use cases, like the Kubernetes
|
||||
auto-scaling topic above
|
||||
- Proposals that might arise out of standardisation efforts like
|
||||
[OpenTelemetry Semantic Conventions for Gen
|
||||
AI](https://github.com/open-telemetry/semantic-conventions/tree/main/docs/gen-ai).
|
||||
|
||||
We should be cautious in our approach to adding new metrics. While
|
||||
metrics are often relatively straightforward to add:
|
||||
|
||||
1. They can be difficult to remove - see the section on deprecation
|
||||
above.
|
||||
2. They can have a meaningful performance impact when enabled. And
|
||||
metrics are usually of very limited use unless they can be enabled
|
||||
by default and in production.
|
||||
3. They have an impact on development and maintenance of the
|
||||
project. Every metric added to v0 has made this v1 effort more
|
||||
time-consuming, and perhaps not all metrics justify this ongoing
|
||||
investment in their maintenance.
|
||||
|
||||
## Tracing - OpenTelemetry

Metrics provide an aggregated view over time of the system's
performance and health. Tracing, on the other hand, tracks individual
requests as they move through different services and components. Both
fall under the more general heading of "Observability".

v0 has support for OpenTelemetry tracing:

- Added by <gh-pr:4687>
- Configured with `--otlp-traces-endpoint` and
  `--collect-detailed-traces`
- [OpenTelemetry blog post](https://opentelemetry.io/blog/2024/llm-observability/)
- [User-facing docs](https://docs.vllm.ai/en/latest/getting_started/examples/opentelemetry.html)
- [Blog post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f)
- [IBM product docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview)

OpenTelemetry has a [Gen AI Working Group](https://github.com/open-telemetry/community/blob/main/projects/gen-ai.md).

Since metrics is a big enough topic on its own, we are going to tackle
the topic of tracing in v1 separately.

### OpenTelemetry Model Forward vs Execute Time

In v0, we have the following two metrics:

- `vllm:model_forward_time_milliseconds` (Histogram) - The time spent
  in the model forward pass when this request was in the batch.
- `vllm:model_execute_time_milliseconds` (Histogram) - The time spent
  in the model execute function. This will include model forward,
  block/sync across workers, cpu-gpu sync time and sampling time.

These metrics are only enabled when OpenTelemetry tracing is enabled
and if `--collect-detailed-traces=all/model/worker` is used. The
documentation for this option states:

> collect detailed traces for the specified modules. This involves
> use of possibly costly and or blocking operations and hence might
> have a performance impact.

The metrics were added by <gh-pr:7089> and show up in an OpenTelemetry trace
as:

```
-> gen_ai.latency.time_in_scheduler: Double(0.017550230026245117)
-> gen_ai.latency.time_in_model_forward: Double(3.151565277099609)
-> gen_ai.latency.time_in_model_execute: Double(3.6468167304992676)
```

We already have `inference_time` and `decode_time` metrics, so the
question is whether there are sufficiently common use cases for the
higher-resolution timings to justify the overhead.

Since we are going to treat the question of OpenTelemetry support
separately, we will include these particular metrics under that topic.
231
docs/design/v1/prefix_caching.md
Normal file
@@ -0,0 +1,231 @@
|
||||
# Automatic Prefix Caching
|
||||
|
||||
Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. The core idea is simple – we cache the kv-cache blocks of processed requests, and reuse these blocks when a new request comes in with the same prefix as previous requests. Since prefix caching is almost a free lunch and won’t change model outputs, it has been widely used by many public endpoints (e.g., OpenAI, Anthropic, etc) and most open source LLM inference frameworks (e.g., SGLang).
|
||||
|
||||
While there are many ways to implement prefix caching, vLLM chooses a hash-based approach. Specifically, we hash each kv-cache block by the tokens in the block and the tokens in the prefix before the block:
|
||||
|
||||
```text
                  Block 1                  Block 2                   Block 3
         [A gentle breeze stirred] [the leaves as children] [laughed in the distance]
Block 1: |<--- block tokens ---->|
Block 2: |<------- prefix ------>| |<--- block tokens --->|
Block 3: |<------------------ prefix -------------------->| |<--- block tokens ---->|
```
|
||||
|
||||
In the example above, the KV cache in the first block can be uniquely identified with the tokens “A gentle breeze stirred”. The third block can be uniquely identified with the tokens in the block, “laughed in the distance”, along with the prefix tokens “A gentle breeze stirred the leaves as children”. Therefore, we can build the block hash as `hash(tuple[components])`, where the components are:
|
||||
|
||||
* Parent hash value: The hash value of the parent block.
|
||||
* Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
|
||||
* Extra hashes: Other values required to make this block unique, such as LoRA IDs, multi-modality input hashes (see the example below), and cache salts to isolate caches in multi-tenant environments.
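
The three components above can be chained as in the following minimal sketch. It assumes SHA-256 over a pickled tuple as the hash function and a block size of 4; the function names `hash_block` and `hash_full_blocks` are illustrative and are not vLLM's actual API.

```python
import hashlib
import pickle


def hash_block(parent_hash, block_tokens, extra_keys=None):
    """Hash one full block from (parent hash, block tokens, extra hashes)."""
    payload = pickle.dumps((parent_hash, tuple(block_tokens), extra_keys))
    return hashlib.sha256(payload).digest()


def hash_full_blocks(token_ids, block_size=4, extra_keys=None):
    """Return one hash per full block; a trailing partial block is skipped."""
    hashes = []
    parent = None  # the first block has no parent
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        parent = hash_block(parent, token_ids[start:start + block_size], extra_keys)
        hashes.append(parent)
    return hashes


a = hash_full_blocks([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
b = hash_full_blocks([1, 2, 3, 4, 5, 6, 7, 8, 99, 100, 101, 102])
print(a == b[:2])  # the shared 8-token prefix yields the same two block hashes
```

Because each hash includes its parent's hash, two blocks with identical tokens but different prefixes get different hashes.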
|
||||
|
||||
> **Note 1:** We only cache full blocks.
|
||||
|
||||
> **Note 2:** The above hash key structure is not 100% collision free. Theoretically, it is still possible for different prefix tokens to have the same hash value. To avoid any hash collisions **in a multi-tenant setup, we advise using SHA256** as the hash function instead of the default builtin hash.
> SHA256 is supported since vLLM v0.8.3 and must be enabled with a command line argument. It comes with a performance impact of about 100-200ns per token (~6ms for 50k tokens of context).
|
||||
|
||||
**A hashing example with multi-modality inputs**
|
||||
In this example, we illustrate how prefix caching works with multi-modality inputs (e.g., images). Assuming we have a request with the following messages:
|
||||
|
||||
```text
messages = [
    {"role": "user",
     "content": [
         {"type": "text",
          "text": "What's in this image?"
         },
         {"type": "image_url",
          "image_url": {"url": image_url},
         },
    ]},
]
```
|
||||
|
||||
It will become the following prompt:
|
||||
|
||||
```text
Prompt:
    <s>[INST]What's in this image?\n[IMG][/INST]

Tokenized prompt:
    [1, 3, 7493, 1681, 1294, 1593, 3937, 9551, 10, 4]

Prompt with placeholders (<P>):
    [1, 3, 7493, 1681, 1294, 1593, 3937, 9551, <P>, <P>, ..., <P>, 4]
```
|
||||
|
||||
As we can see, after tokenization, the `[IMG]` token is replaced by a sequence of placeholder tokens, and these placeholders are replaced by image embeddings during prefill. The challenge for prefix caching is that the placeholder tokens look the same for every image, so the token IDs alone cannot tell two different images apart. To address this problem, we include the image hash generated by the frontend image processor in the block hash. For example, the hash of the blocks in the above prompt would be (assuming block size 16, and we have 41 placeholder tokens):
|
||||
|
||||
```text
Block 0
    Parent hash: None
    Token IDs: 1, 3, 7493, 1681, 1294, 1593, 3937, 9551, <p>, ..., <p>
    Extra hash: <image hash>
Block 1
    Parent hash: Block 0 hash
    Token IDs: <p>, ..., <p>
    Extra hash: <image hash>
Block 2
    Parent hash: Block 1 hash
    Token IDs: <p>, ..., <p>
    Extra hash: <image hash>
Block 3
    Parent hash: Block 2 hash
    Token IDs: <p>, ..., <p>, 4
    Extra hash: <image hash>
```
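
Which blocks carry the image hash can be worked out with simple index arithmetic. The sketch below assumes that the image hash is attached to every block that overlaps the placeholder range; the helper name `blocks_with_image_hash` is hypothetical.

```python
def blocks_with_image_hash(placeholder_start, num_placeholders, block_size):
    """Return the indices of blocks that contain at least one placeholder token."""
    first = placeholder_start // block_size
    last = (placeholder_start + num_placeholders - 1) // block_size
    return list(range(first, last + 1))


# The prompt above: 8 text tokens, then 41 placeholders, block size 16.
print(blocks_with_image_hash(placeholder_start=8, num_placeholders=41, block_size=16))
```

For the prompt above, the placeholders span token positions 8 to 48, so blocks 0 through 3 all need the image hash, matching the listing.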
|
||||
|
||||
In the rest of this document, we first introduce the data structure used for prefix caching in vLLM v1, followed by the prefix caching workflow of major KV cache operators (e.g., allocate, append, free, eviction). Finally, we use an example to illustrate the end to end prefix caching workflow.
|
||||
|
||||
**Cache Isolation for Security**
|
||||
To improve privacy in shared environments, vLLM supports isolating prefix cache reuse through optional per-request salting. When a request includes a `cache_salt`, this value is injected into the hash of the first block, ensuring that only requests with the same salt can reuse cached KV blocks. This prevents timing-based attacks where an adversary could infer cached content by observing latency differences. This offers protection without compromising performance.
|
||||
|
||||
```json
{
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Here is a document with details about the world series: ..."},
    {"role": "user", "content": "Who won the world series in 2020?"}
  ],
  "cache_salt": "your-cache-salt"
}
```
|
||||
|
||||
With this setup, cache sharing is limited to users or requests that explicitly agree on a common salt, enabling cache reuse within a trust group while isolating others.
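
The effect of salting only the first block can be sketched as follows. This is an illustration under assumptions (SHA-256 over a pickled tuple, the hypothetical helper `salted_block_hashes`), not vLLM's implementation. Because every block hash includes its parent hash, a salt in the first block changes the hash of every later block as well.

```python
import hashlib
import pickle


def salted_block_hashes(token_ids, block_size, cache_salt=None):
    """Hash full blocks, mixing the salt into the first block only."""
    hashes = []
    parent = None
    starts = range(0, len(token_ids) - block_size + 1, block_size)
    for index, start in enumerate(starts):
        extra = cache_salt if index == 0 else None
        block = tuple(token_ids[start:start + block_size])
        parent = hashlib.sha256(pickle.dumps((parent, block, extra))).digest()
        hashes.append(parent)
    return hashes


tokens = list(range(8))
tenant_a = salted_block_hashes(tokens, 4, cache_salt="tenant-a")
tenant_b = salted_block_hashes(tokens, 4, cache_salt="tenant-b")
print(any(x == y for x, y in zip(tenant_a, tenant_b)))  # no block hash is shared
```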
|
||||
|
||||
> **Note:** Cache isolation is not supported in engine V0.
|
||||
|
||||
## Data Structure
|
||||
|
||||
The prefix caching in vLLM v1 is implemented in the KV cache manager. The basic building block is the “Block” data class (simplified):
|
||||
|
||||
```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class KVCacheBlock:
    # The block ID (immutable)
    block_id: int
    # The block hash (will be assigned when the block is full,
    # and will be reset when the block is evicted).
    # `BlockHashType` is defined elsewhere, so it is referenced by name only.
    block_hash: Optional["BlockHashType"] = None
    # The number of requests using this block now.
    ref_cnt: int = 0

    # The pointers to form a doubly linked list for the free queue.
    prev_free_block: Optional["KVCacheBlock"] = None
    next_free_block: Optional["KVCacheBlock"] = None
```
|
||||
|
||||
There are two design points to highlight:
|
||||
|
||||
1. We allocate all KVCacheBlock objects up front, as a block pool, when initializing the KV cache manager. This avoids Python object creation overhead at runtime and makes it easy to track every block at all times.
2. We introduce doubly linked list pointers directly in the KVCacheBlock, so that we can construct the free queue from the blocks themselves. This gives us two benefits:
    1. We can move an element from the middle of the queue to the tail in O(1) time.
    2. We avoid introducing another Python queue (e.g., `deque`), which would wrap each element in an additional object.
|
||||
|
||||
As a result, we will have the following components when the KV cache manager is initialized:
|
||||
|
||||

|
||||
|
||||
* Block Pool: A list of KVCacheBlock.
|
||||
* Free Block Queue: Only store the pointers of head and tail blocks for manipulations.
|
||||
* Cache blocks: Mapping from hash key to block IDs.
|
||||
* Request blocks: Mapping from request ID to allocated block IDs.
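
A free queue that stores only head and tail pointers, with the links kept in the blocks themselves, might look like the sketch below. The class names `Block` and `FreeBlockQueue` and their methods are illustrative assumptions, not the actual vLLM classes.

```python
class Block:
    def __init__(self, block_id):
        self.block_id = block_id
        self.prev_free_block = None
        self.next_free_block = None


class FreeBlockQueue:
    """Doubly linked list of free blocks; only head and tail are stored."""

    def __init__(self, blocks):
        self.head = None
        self.tail = None
        for block in blocks:
            self.append(block)

    def append(self, block):
        """Add a block at the tail (evicted last)."""
        block.prev_free_block = self.tail
        block.next_free_block = None
        if self.tail is None:
            self.head = block
        else:
            self.tail.next_free_block = block
        self.tail = block

    def remove(self, block):
        """Unlink a block from anywhere in the queue in O(1)."""
        prev, nxt = block.prev_free_block, block.next_free_block
        if prev is None:
            self.head = nxt
        else:
            prev.next_free_block = nxt
        if nxt is None:
            self.tail = prev
        else:
            nxt.prev_free_block = prev
        block.prev_free_block = block.next_free_block = None

    def popleft(self):
        """Take the block at the head (evicted first)."""
        block = self.head
        if block is None:
            raise IndexError("no free blocks")
        self.remove(block)
        return block

    def ids(self):
        out, node = [], self.head
        while node is not None:
            out.append(node.block_id)
            node = node.next_free_block
        return out


pool = [Block(i) for i in range(5)]
queue = FreeBlockQueue(pool)
queue.remove(pool[2])   # "touch": take a block out of the middle
queue.append(pool[2])   # later freed again: goes to the tail
print(queue.ids())
```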
|
||||
|
||||
## Operations
|
||||
|
||||
### Block Allocation
|
||||
|
||||
**New request:** Workflow for the scheduler to schedule a new request with KV cache block allocation:
|
||||
|
||||
1. The scheduler calls `kv_cache_manager.get_computed_blocks()` to get a sequence of blocks that have already been computed. This is done by hashing the prompt tokens in the request and looking up Cache Blocks.
2. The scheduler calls `kv_cache_manager.allocate_slots()`. It does the following steps:
    1. Compute the number of new required blocks, and return if there are not enough free blocks to allocate.
    2. “Touch” the computed blocks. It increases the reference count of each computed block by one, and removes the block from the free queue if the block wasn’t used by other requests. This is to avoid these computed blocks being evicted. See the example in the next section for illustration.
    3. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on.
    4. If an allocated block is already full of tokens, we immediately add it to the Cache Blocks, so that the block can be reused by other requests in the same batch.
|
||||
|
||||
**Running request:** Workflow for the scheduler to schedule a running request with KV cache block allocation:
|
||||
|
||||
1. The scheduler calls `kv_cache_manager.append_slots()`. It does the following steps:
    1. Compute the number of new required blocks, and return if there are not enough free blocks to allocate.
    2. Allocate new blocks by popping the heads of the free queue. If the head block is a cached block, this also “evicts” the block so that no other requests can reuse it anymore from now on.
    3. Append token IDs to the slots in existing blocks as well as the new blocks. If a block is full, we add it to the Cache Blocks to cache it.
|
||||
|
||||
**Duplicated blocks**
|
||||
Assume the block size is 4 and you send a request (Request 1) with prompt ABCDEF and decoding length 3:
|
||||
|
||||
```text
Prompt: [A, B, C, D, E, F]
Output: [G, H, I]

Time 0:
  Tokens: [A, B, C, D, E, F, G]
  Block Table: [0 (ABCD), 1 (EFG)]
  Cache Blocks: 0
Time 1:
  Tokens: [A, B, C, D, E, F, G, H]
  Block Table: [0 (ABCD), 1 (EFGH)]
  Cache Blocks: 0, 1
Time 2:
  Tokens: [A, B, C, D, E, F, G, H, I]
  Block Table: [0 (ABCD), 1 (EFGH), 2 (I)]
  Cache Blocks: 0, 1
```
|
||||
|
||||
Now block 0 and block 1 are cached, and we send the same request again (Request 2) with greedy sampling, so that it will produce exactly the same outputs as Request 1:
|
||||
|
||||
```text
Prompt: [A, B, C, D, E, F]
Output: [G, H, I]

Time 0:
  Tokens: [A, B, C, D, E, F, G]
  Block Table: [0 (ABCD), 3 (EFG)]
  Cache Blocks: 0, 1
Time 1:
  Tokens: [A, B, C, D, E, F, G, H]
  Block Table: [0 (ABCD), 3 (EFGH)]
  Cache Blocks: 0, 1, 3
```
|
||||
|
||||
As can be seen, block 3 is a new full block and is cached. However, it duplicates block 1, meaning that we have cached the same block twice. In v0, when we detect that block 3 is a duplicate, we free block 3 and let Request 2 use block 1 instead, so its block table becomes `[0, 1]` at Time 1. However, the block table in vLLM v1 is append-only, meaning that changing the block table from `[0, 3]` to `[0, 1]` is not allowed. As a result, we will have duplicated blocks for the hash key E-H. This duplication will be eliminated when the request is freed.
|
||||
|
||||
### Free
|
||||
|
||||
When a request is finished, we free all its blocks that no other requests are using (reference count = 0). In this example, we free request 1 and blocks 2, 3, 4 and 8 associated with it. We can see that the freed blocks are added to the tail of the free queue in *reverse* order. This is because the last block of a request hashes more tokens (its hash covers the longest prefix) and is therefore less likely to be reused by other requests. As a result, it should be evicted first.
|
||||
|
||||

|
||||
|
||||
### Eviction (LRU)
|
||||
|
||||
When the head block (least recently used block) of the free queue is cached, we have to evict the block to prevent it from being used by other requests. Specifically, eviction involves the following steps:
|
||||
|
||||
1. Pop the block from the head of the free queue. This is the LRU block to be evicted.
|
||||
2. Remove the block ID from the Cache Block.
|
||||
3. Remove the block hash.
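
The three eviction steps can be sketched as below. The sketch assumes that Cache Blocks maps each hash to a single block ID and uses plain dictionaries; the names `pop_and_evict`, `block_hash_of` and `cached_blocks` are illustrative only.

```python
from collections import deque


def pop_and_evict(free_queue, block_hash_of, cached_blocks):
    """Pop the LRU block and, if it is cached, drop it from the cache."""
    block_id = free_queue.popleft()              # step 1: pop the head
    block_hash = block_hash_of.get(block_id)
    if block_hash is not None:
        if cached_blocks.get(block_hash) == block_id:
            del cached_blocks[block_hash]        # step 2: remove from Cache Blocks
        block_hash_of[block_id] = None           # step 3: remove the block hash
    return block_id


free_queue = deque([3, 7])
block_hash_of = {3: "hash-of-block-3", 7: None}
cached_blocks = {"hash-of-block-3": 3}
print(pop_and_evict(free_queue, block_hash_of, cached_blocks), cached_blocks)
```

A block that was never cached (block 7 here) is simply popped; steps 2 and 3 are skipped.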
|
||||
|
||||
## Example
|
||||
|
||||
In this example, we assume the block size is 4 (each block can cache 4 tokens), and we have 10 blocks in the KV-cache manager in total.
|
||||
|
||||
**Time 1: The cache is empty and a new request comes in.** We allocate 4 blocks. 3 of them are already full and cached. The fourth block is partially full with 3 of 4 tokens.
|
||||
|
||||

|
||||
|
||||
**Time 3: Request 0 makes the block 3 full and asks for a new block to keep decoding.** We cache block 3 and allocate block 4.
|
||||
|
||||

|
||||
|
||||
**Time 4: Request 1 comes in with 14 prompt tokens, where the first 10 tokens are the same as request 0.** We can see that only the first 2 blocks (8 tokens) hit the cache, because the 3rd block only matches 2 of 4 tokens.
|
||||
|
||||

|
||||
|
||||
**Time 5: Request 0 is finished and freed.** Blocks 2, 3 and 4 are added to the free queue in reverse order (but blocks 2 and 3 are still cached). Blocks 0 and 1 are not added to the free queue because they are being used by Request 1.
|
||||
|
||||

|
||||
|
||||
**Time 6: Request 1 is finished and freed.**
|
||||
|
||||

|
||||
|
||||
**Time 7: Request 2 comes in with 29 prompt tokens, where the first 12 tokens are the same as request 0.** Note that even though the block order in the free queue was `7 - 8 - 9 - 4 - 3 - 2 - 6 - 5 - 1 - 0`, the cache hit blocks (i.e., 0, 1, 2) are touched and removed from the queue before allocation, so the free queue becomes `7 - 8 - 9 - 4 - 3 - 6 - 5`. As a result, the allocated blocks are 0 (cached), 1 (cached), 2 (cached), 7, 8, 9, 4, 3 (evicted).
|
||||
|
||||

|
||||
149
docs/design/v1/torch_compile.md
Normal file
@@ -0,0 +1,149 @@
|
||||
# vLLM's `torch.compile` integration
|
||||
|
||||
In vLLM's V1 architecture, `torch.compile` is enabled by default and is a critical part of the framework. This document gives a simple walk-through example to show how to understand the `torch.compile` usage.
|
||||
|
||||
Throughout the example, we will run a common Llama model using v1, and turn on debug level logging to show all the details. The command to be used is `VLLM_USE_V1=1 VLLM_LOGGING_LEVEL=DEBUG vllm serve meta-llama/Llama-3.2-1B`.
|
||||
|
||||
## Compilation Cache
|
||||
|
||||
In the very verbose logs, we can see:
|
||||
|
||||
```
|
||||
INFO 03-07 03:06:55 [backends.py:409] Using cache directory: ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0 for vLLM's torch.compile
|
||||
```
|
||||
|
||||
vLLM takes all the available factors into consideration and derives a directory in which to store all the compilation artifacts. This means you can copy the whole `~/.cache/vllm/torch_compile_cache` directory to your deployment environment to save a great amount of compilation time, and hence speed up the startup of the vLLM instance.
|
||||
|
||||
The factors considered include:
|
||||
|
||||
- All the related configs (see the `compute_hash` functions in the [config.py](gh-file:vllm/config.py))
|
||||
- PyTorch configs (see the `compute_hash` functions in the [compiler_interface.py](gh-file:vllm/compilation/compiler_interface.py))
|
||||
- The model's forward function and the relevant functions called by the forward function (see below)
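
How such factors could be folded into a single directory name is sketched below. This is a hypothetical illustration only: the function `cache_dir_key`, its arguments, and the choice of SHA-256 are assumptions, and the real logic lives in the `compute_hash` functions referenced above.

```python
import hashlib


def cache_dir_key(config_factors, traced_sources):
    """Combine config values and traced source code into one short key."""
    hasher = hashlib.sha256()
    for factor in config_factors:
        hasher.update(repr(factor).encode())
    # Sort by path so the key does not depend on the order files were traced in.
    for path in sorted(traced_sources):
        hasher.update(path.encode())
        hasher.update(traced_sources[path].encode())
    return hasher.hexdigest()[:10]


factors = ["model=llama", "dtype=bfloat16"]
sources = {"models/llama.py": "def forward(self, x): return x"}
key_before = cache_dir_key(factors, sources)
sources["models/llama.py"] = "def forward(self, x): return x + 1"
key_after = cache_dir_key(factors, sources)
print(key_before != key_after)  # a code change selects a different cache directory
```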
|
||||
|
||||
With all these factors taken into consideration, usually we can guarantee that the cache is safe to use, and will not cause any unexpected behavior. Therefore, the cache is enabled by default. If you want to debug the compilation process, or if you suspect the cache is causing some issues, you can disable it by setting the environment variable `VLLM_DISABLE_COMPILE_CACHE=1`.
|
||||
|
||||
A unique aspect of vLLM's `torch.compile` integration is that we guarantee all compilation finishes before we serve any requests. No request will trigger a new compilation. Otherwise, the engine would be blocked on that request, and the response time would have unexpected spikes.
|
||||
|
||||
## Python Code Compilation
|
||||
|
||||
In the very verbose logs, we can see:
|
||||
|
||||
```
|
||||
DEBUG 03-07 03:06:52 [decorators.py:203] Start compiling function <code object forward at 0x7f08acf40c90, file "xxx/vllm/model_executor/models/llama.py", line 339>
|
||||
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] Traced files (to be considered for compilation cache):
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/_dynamo/polyfills/builtins.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/nn/modules/container.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/torch/nn/modules/module.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/attention/layer.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/distributed/communication_op.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/distributed/parallel_state.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/custom_op.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/activation.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/layernorm.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/linear.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/rotary_embedding.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/layers/vocab_parallel_embedding.py
|
||||
DEBUG 03-07 03:06:54 [backends.py:370] xxx/vllm/model_executor/models/llama.py
|
||||
|
||||
DEBUG 03-07 03:07:07 [backends.py:462] Computation graph saved to ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py
|
||||
DEBUG 03-07 03:07:07 [wrapper.py:105] Dynamo transformed code saved to ~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/transformed_code.py
|
||||
```
|
||||
|
||||
This is about the Python code compilation, i.e. graph capture by Dynamo. It tries to trace the function with code `xxx/vllm/model_executor/models/llama.py:339`, which is the `forward` function of the model we compile. During the forward pass, there are also other functions called and inlined by Dynamo, as shown by the logs, including some PyTorch functions from `xxx/torch/nn/modules/module.py` (used by PyTorch `nn.Module`, because module attribute access will trigger a function call), some communication / attention / activation functions from vLLM. All the traced files will be considered when we decide the cache directory to use. This way, any code change in the above files will trigger compilation cache miss, and therefore recompilation.
|
||||
|
||||
The result of the Dynamo compilation is a new function stored in `~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/transformed_code.py`. Usually, this function unpacks tensors from the module and then passes them to the traced computation graph. The computation graph is stored in `~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py`.
|
||||
|
||||
## Computation Graph Processing
|
||||
|
||||
The computation graph has shape annotations for every tensor. The inputs are input ids, position ids, weights and buffers from the model, and the outputs are the final hidden states. Note that lm head projection and sampling operations are not considered in the graph.
|
||||
|
||||
Most of the inputs to the computation graph have static shapes, since they are model weights and buffers, and will not change during the lifetime of the model. Only the input ids and position ids have symbolic shapes, i.e. the shape can change from batch to batch. However, they share the same symbolic shape. That is to say, the only changing size in the computation graph is the batch size (number of tokens processed in the current forward pass).
|
||||
|
||||
The attention operation is complicated, and it needs to interact with KV caches that have complicated shapes. Fortunately, the output of the attention operation has the same shape as its input query. Therefore, we wrap the whole attention operation into a PyTorch custom op `torch.ops.vllm.unified_attention_with_output`, so that Dynamo will not try to inspect any of the internal operations. This way, although the attention operation is complicated, we can still capture the model's computation graph as a full graph, from Dynamo's perspective.
|
||||
|
||||
The computation graph is further split into pieces, by the `splitting_ops` (usually this is the attention operation). Therefore, in the `~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/computation_graph.py` file, we can see lots of submodules, each submodule is a piece of graph after splitting:
|
||||
|
||||
- Attention operation itself is a submodule.
|
||||
- The part of computation graph, from one attention operation to the next attention operation, is a submodule.
|
||||
|
||||
Every submodule can be identified by its index, and will be processed individually.
|
||||
|
||||
## Computation Graph Compilation
|
||||
|
||||
In the very verbose logs, we can also see:
|
||||
|
||||
```
|
||||
DEBUG 03-07 03:52:37 [backends.py:134] store the 0-th graph for shape None from inductor via handle ('fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py')
|
||||
DEBUG 03-07 03:52:39 [backends.py:134] store the 1-th graph for shape None from inductor via handle ('f7fmlodmf3h3by5iiu2c4zarwoxbg4eytwr3ujdd2jphl4pospfd', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/ly/clyfzxldfsj7ehaluis2mca2omqka4r7mgcedlf6xfjh645nw6k2.py')
|
||||
...
|
||||
DEBUG 03-07 03:52:45 [backends.py:134] store the 15-th graph for shape None from inductor via handle ('f7fmlodmf3h3by5iiu2c4zarwoxbg4eytwr3ujdd2jphl4pospfd', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/ly/clyfzxldfsj7ehaluis2mca2omqka4r7mgcedlf6xfjh645nw6k2.py')
|
||||
DEBUG 03-07 03:52:45 [backends.py:134] store the 16-th graph for shape None from inductor via handle ('fvj3ccoi7m34f3dnr4itmu55mmun44l5xymwhrjlwisylsk7q6jy', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/tf/ctfftkglj7b4lcttq5cymx6cew372uoauupqn6ldsvpiucavqcjc.py')
|
||||
```
|
||||
|
||||
This means the first piece of computation graph (with shape `None` for symbolic shape) is compiled by Inductor (with a key `fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw`). The compiled kernel is stored in `~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py`. You can open the file to see what is the code Inductor finally runs.
|
||||
|
||||
One more detail: you can see that the 1-th graph and the 15-th graph have the same key, while the 0-th graph and the 16-th graph are different. This is expected, since we split the graph by the attention op, we get 3 unique subgraphs:
|
||||
|
||||
- the first layer before attention
|
||||
- every middle layer, from one attention operation to the next attention operation
|
||||
- the final layer after attention
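
Why splitting at the attention op yields exactly three unique subgraphs can be shown with a toy list of op names. The op names and helper functions below are illustrative assumptions, not the real graph; 16 layers is chosen so that the pieces line up with the 0-th to 16-th graphs in the log above.

```python
def split_by_ops(ops, splitting_ops):
    """Split a flat op list into pieces; each splitting op becomes its own piece."""
    pieces, current = [], []
    for op in ops:
        if op in splitting_ops:
            pieces.append(tuple(current))
            pieces.append((op,))
            current = []
        else:
            current.append(op)
    pieces.append(tuple(current))
    return pieces


def toy_model_ops(num_layers):
    ops = ["embed"]
    for _ in range(num_layers):
        ops += ["input_norm", "qkv_proj", "attention", "o_proj", "post_norm", "mlp"]
    ops.append("final_norm")
    return ops


pieces = split_by_ops(toy_model_ops(16), {"attention"})
non_attention = [p for p in pieces if p != ("attention",)]
print(len(non_attention), len(set(non_attention)))  # 17 pieces, 3 unique
```

The first piece starts with the embedding, the last one ends with the final norm, and all 15 pieces in between are identical, which is why they can share one compiled artifact.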
|
||||
|
||||
If we already have the cache directory (e.g. run the same code for the second time), we will see the following logs:
|
||||
|
||||
```
|
||||
DEBUG 03-07 04:00:45 [backends.py:86] Directly load the 0-th graph for shape None from inductor via handle ('fpegyiq3v3wzjzphd45wkflpabggdbjpylgr7tta4hj6uplstsiw', '~/.cache/vllm/torch_compile_cache/1517964802/rank_0_0/inductor_cache/iw/ciwzrk3ittdqatuzwonnajywvno3llvjcs2vfdldzwzozn3zi3iy.py')
|
||||
```
|
||||
|
||||
This time, Inductor compilation is completely bypassed, and we load the compilation artifacts produced by the previous run from disk.
|
||||
|
||||
The above example just uses Inductor to compile for a general shape (i.e. symbolic shape). We can also use Inductor to compile for some of the specific shapes, for example:
|
||||
|
||||
```
|
||||
vllm serve meta-llama/Llama-3.2-1B --compilation_config '{"compile_sizes": [1, 2, 4, 8]}'
|
||||
```
|
||||
|
||||
Then it will also compile specific kernels just for batch sizes 1, 2, 4 and 8. In this case, all of the shapes in the computation graph are static and known, and we turn on auto-tuning to tune for max performance. This can be slow the first time you run it, but on later runs we can bypass the tuning and run the tuned kernel directly.
|
||||
|
||||
When all the shapes are known, `torch.compile` can compare different configs, and often find some better configs to run the kernel. For example, we can see the following log:
|
||||
|
||||
```
|
||||
AUTOTUNE mm(8x2048, 2048x3072)
|
||||
triton_mm_4 0.0130 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
|
||||
triton_mm_8 0.0134 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
|
||||
triton_mm_12 0.0148 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
|
||||
mm 0.0160 ms 81.6%
|
||||
triton_mm_16 0.0165 ms 78.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
|
||||
triton_mm_3 0.0199 ms 65.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
|
||||
triton_mm_1 0.0203 ms 64.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2
|
||||
triton_mm_7 0.0203 ms 64.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
|
||||
triton_mm_2 0.0208 ms 62.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
|
||||
triton_mm_11 0.0215 ms 60.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
|
||||
SingleProcess AUTOTUNE benchmarking takes 2.0428 seconds and 7.5727 seconds precompiling
|
||||
```
|
||||
|
||||
This means that, for a matrix multiplication with shape `8x2048x3072`, `torch.compile` tries the Triton template with various configs, and the best ones (e.g. `triton_mm_4` at 0.0130 ms) are faster than the default `mm` kernel (0.0160 ms), which dispatches to the cuBLAS library.
|
||||
|
||||
Unfortunately, because auto-tuning takes quite a long time (from seconds to minutes, depending on the model size and the batch size), even though it can be cached for later use, for the sake of user-friendliness, we turn it off by default. If you want to have max performance, it is recommended to try it, by compiling specific shapes.
|
||||
|
||||
## Cudagraph Capture
|
||||
|
||||
vLLM's V1 architecture uses piecewise cudagraphs. The full computation graph is split as mentioned above, and we only capture cudagraphs for the pieces of graph between attention operations (including the first piece before any attention operation, and the last piece after all the attention operations). This is based on a common observation: computation between attention operations is usually token-wise and easy to handle with cudagraphs, while making the attention operation cudagraph compatible is non-trivial. Thus, by running the attention operation in eager mode and the rest of the operations in cudagraphs, we keep the flexibility of the attention operation.
|
||||
|
||||
The piecewise cudagraph also has fine-grained memory management. The purpose is to only exclude the attention kernel from cudagraph, while keeping all the rest modules and the memory allocation operations in the cudagraph. This is why the attention operation in V1 has the output tensor as the input of the attention.
|
||||
|
||||
The cudagraphs are captured and managed by the compiler backend, and replayed when the batch size has corresponding cudagraph captured. The caller of the model (model runner) only needs to make sure it manages the input buffers correctly. All of the intermediate buffers are managed automatically by the compiler backend.
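
The replay decision described above can be sketched without any CUDA code. In this pure-Python stand-in, "capturing" just stores a callable per batch size; the class `CudagraphDispatcher` and its counters are hypothetical and only illustrate the lookup by batch size.

```python
class CudagraphDispatcher:
    """Replay a captured graph when the batch size matches, else run normally."""

    def __init__(self, run_piece, capture_sizes):
        self.run_piece = run_piece
        # Pretend every configured size is captured up front, before serving.
        self.captured = {size: run_piece for size in capture_sizes}
        self.replayed = 0
        self.fallbacks = 0

    def __call__(self, tokens):
        graph = self.captured.get(len(tokens))
        if graph is None:
            self.fallbacks += 1          # no cudagraph for this batch size
            return self.run_piece(tokens)
        self.replayed += 1               # batch size has a captured graph
        return graph(tokens)


dispatcher = CudagraphDispatcher(lambda tokens: [t * 2 for t in tokens], [1, 2, 4, 8])
dispatcher([1, 2])       # size 2 was captured: replay
dispatcher([1, 2, 3])    # size 3 was not captured: fall back
print(dispatcher.replayed, dispatcher.fallbacks)
```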
|
||||
|
||||
By default, vLLM will try to determine a set of sizes to capture cudagraph. You can also override it using the config `cudagraph_capture_sizes`:
|
||||
|
||||
```
|
||||
vllm serve meta-llama/Llama-3.2-1B --compilation-config '{"cudagraph_capture_sizes": [1, 2, 4, 8]}'
|
||||
```
|
||||
|
||||
Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture.
|
||||
|
||||
### Full Cudagraph capture
|
||||
|
||||
It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config '{"full_cuda_graph": true}'`.
|
||||
|
||||
Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled.
|
||||