Add more Prometheus metrics (#2764)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
@@ -22,7 +22,8 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
-                           SequenceGroup, SequenceGroupMetadata)
+                           SequenceGroup, SequenceGroupMetadata,
+                           SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -217,7 +218,8 @@ class LLMEngine:
         if self.log_stats:
             self.stat_logger = StatLogger(
                 local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
-                labels=dict(model_name=model_config.model))
+                labels=dict(model_name=model_config.model),
+                max_model_len=self.model_config.max_model_len)
             self.stat_logger.info("cache_config", self.cache_config)

         # Create sequence output processor, e.g. for beam search or
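Note on the new max_model_len argument: a natural use for it is deriving histogram bucket boundaries for the per-request token-count metrics introduced below. The following is a minimal sketch of that idea, not the code from vllm/engine/metrics.py; the helper name build_1_2_5_buckets and the exact bucket scheme are assumptions.

# Hypothetical sketch: derive histogram buckets from max_model_len.
# Helper name and bucket scheme are assumptions, not taken from this diff.
from typing import List


def build_1_2_5_buckets(max_value: int) -> List[int]:
    """Return buckets [1, 2, 5, 10, 20, 50, ...] capped at max_value."""
    mantissa_lst = [1, 2, 5]
    exponent = 0
    buckets: List[int] = []
    while True:
        for m in mantissa_lst:
            value = m * 10**exponent
            if value <= max_value:
                buckets.append(value)
            else:
                return buckets
        exponent += 1


# e.g. max_model_len=4096 -> [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000]
print(build_1_2_5_buckets(4096))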
@@ -619,59 +621,109 @@ class LLMEngine:
         """
         now = time.time()

-        # KV Cache Usage in %.
+        # System State
+        # Scheduler State
+        num_running_sys = len(self.scheduler.running)
+        num_swapped_sys = len(self.scheduler.swapped)
+        num_waiting_sys = len(self.scheduler.waiting)
+
+        # KV Cache Usage in %
         num_total_gpu = self.cache_config.num_gpu_blocks
         num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
-        gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
+        gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)

         num_total_cpu = self.cache_config.num_cpu_blocks
-        cpu_cache_usage = 0.
+        cpu_cache_usage_sys = 0.
         if num_total_cpu > 0:
             num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
             )
-            cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
+            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)

-        # Scheduler State
-        num_running = len(self.scheduler.running)
-        num_swapped = len(self.scheduler.swapped)
-        num_waiting = len(self.scheduler.waiting)
+        # Iteration stats
+        num_prompt_tokens_iter = 0
+        num_generation_tokens_iter = 0
+        time_to_first_tokens_iter: List[float] = []
+        time_per_output_tokens_iter: List[float] = []

-        # Iteration stats if we have scheduler output.
-        num_prompt_tokens = 0
-        num_generation_tokens = 0
-        time_to_first_tokens = []
-        time_per_output_tokens = []
-        time_e2e_requests = []
+        # Request stats
+        # Latency
+        time_e2e_requests: List[float] = []
+        # Metadata
+        num_prompt_tokens_requests: List[int] = []
+        num_generation_tokens_requests: List[int] = []
+        best_of_requests: List[int] = []
+        n_requests: List[int] = []
+        finished_reason_requests: List[str] = []
+
+        # NOTE: This loop assumes prefill seq_groups are before
+        # decode seq_groups in scheduled_seq_groups.
         if scheduler_outputs is not None:
-            prompt_run = scheduler_outputs.num_prefill_groups > 0
+            num_generation_tokens_from_prefill_groups = 0.
+            if scheduler_outputs.num_prefill_groups > 0 and len(
+                    scheduler_outputs.scheduled_seq_groups
+            ) != scheduler_outputs.num_prefill_groups:
+                print("DETECTED CHUNKED")

-            # Number of Tokens.
-            if prompt_run:
-                num_prompt_tokens = sum(
-                    len(scheduled_seq_group.seq_group.prompt_token_ids)
-                    for scheduled_seq_group in
-                    scheduler_outputs.scheduled_seq_groups)
-                num_generation_tokens = sum(
-                    scheduled_seq_group.seq_group.num_seqs()
-                    for scheduled_seq_group in
-                    scheduler_outputs.scheduled_seq_groups)
-            else:
-                num_generation_tokens = scheduler_outputs.num_batched_tokens

-            # Latency Timings.
-            time_last_iters = []
-            for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
+            for idx, scheduled_seq_group in enumerate(
+                    scheduler_outputs.scheduled_seq_groups):
+                group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                 seq_group = scheduled_seq_group.seq_group
-                # Time since last token.
-                # (n.b. updates seq_group.metrics.last_token_time)
-                time_last_iters.append(seq_group.get_last_latency(now))
-                # Time since arrival for all finished requests.
+
+                # NOTE: a seq_group that completed all of its prefill tokens
+                # in the last iteration will have seq_group.is_prefill() = False
+                # with group_was_prefill = True
+                if group_was_prefill:
+                    # Number of prompt tokens.
+                    num_prompt_tokens_iter += (
+                        scheduled_seq_group.token_chunk_size)
+
+                    # If the seq_group just finished the prefill state
+                    # get TTFT.
+                    if not seq_group.is_prefill():
+                        latency = seq_group.get_last_latency(now)
+                        time_to_first_tokens_iter.append(latency)
+
+                        # One generation token per finished prefill.
+                        num_generation_tokens_from_prefill_groups += (
+                            seq_group.num_seqs())
+                else:
+                    # TPOTs.
+                    latency = seq_group.get_last_latency(now)
+                    time_per_output_tokens_iter.append(latency)
+
+                # Because of chunked prefill, we can have a single sequence
+                # group that does multiple prompt_runs. To prevent logging
+                # the same metadata more than once per request, we standardize
+                # on logging request level information for finished requests,
+                # which can only happen once.
                 if seq_group.is_finished():
+                    # Latency timings
                     time_e2e_requests.append(now -
                                              seq_group.metrics.arrival_time)

-            time_to_first_tokens = time_last_iters if prompt_run else []
-            time_per_output_tokens = [] if prompt_run else time_last_iters
+                    # Metadata
+                    num_prompt_tokens_requests.append(
+                        len(seq_group.prompt_token_ids))
+                    num_generation_tokens_requests.extend([
+                        seq.get_output_len()
+                        for seq in seq_group.get_finished_seqs()
+                    ])
+                    best_of_requests.append(seq_group.sampling_params.best_of)
+                    n_requests.append(seq_group.sampling_params.n)
+                    finished_reason_requests.extend([
+                        SequenceStatus.get_finished_reason(seq.status)
+                        for seq in seq_group.get_finished_seqs()
+                    ])
+
+            # Number of generation tokens.
+            # num_batched_tokens equals the number of prompt_tokens plus the
+            # number of decode_tokens in a single iteration. So,
+            # num_generation_tokens = num_batched_tokens - num_prompt_tokens
+            # + num_generation_tokens_from_prefill_groups (since we generate
+            # one token on prefills on iters where the prefill finishes).
+            num_generation_tokens_iter = (
+                scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
+                num_generation_tokens_from_prefill_groups)

         # Spec decode, if enabled, emits specialized metrics from the worker in
         # sampler output.
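For orientation, here is a rough sketch of how the new split between system, iteration, and request stats can map onto prometheus_client metric types: gauges for point-in-time system state, counters/histograms fed every engine step, and histograms observed once per finished request. The metric and label names below are illustrative assumptions, not the definitions from vllm/engine/metrics.py.

# Illustrative sketch using prometheus_client; metric names are assumptions,
# not the actual metric definitions added by this PR.
from prometheus_client import Counter, Gauge, Histogram

labelnames = ["model_name"]

# System stats -> Gauges (point-in-time snapshots of scheduler state).
gauge_running = Gauge("vllm_num_requests_running",
                      "Number of requests currently running.", labelnames)

# Iteration stats -> Counters/Histograms updated on every engine step.
counter_prompt_tokens = Counter("vllm_prompt_tokens",
                                "Total prefill tokens processed.", labelnames)
histogram_ttft = Histogram("vllm_time_to_first_token_seconds",
                           "Time to first token.", labelnames)

# Request stats -> Histograms observed once per finished request.
histogram_e2e = Histogram("vllm_e2e_request_latency_seconds",
                          "End-to-end request latency.", labelnames)


def log_stats(stats, model_name: str) -> None:
    # One .set()/.inc()/.observe() per value collected in _get_stats().
    gauge_running.labels(model_name=model_name).set(stats.num_running_sys)
    counter_prompt_tokens.labels(model_name=model_name).inc(
        stats.num_prompt_tokens_iter)
    for ttft in stats.time_to_first_tokens_iter:
        histogram_ttft.labels(model_name=model_name).observe(ttft)
    for e2e in stats.time_e2e_requests:
        histogram_e2e.labels(model_name=model_name).observe(e2e)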
@@ -683,17 +735,32 @@ class LLMEngine:

         return Stats(
             now=now,
-            num_running=num_running,
-            num_swapped=num_swapped,
-            num_waiting=num_waiting,
-            gpu_cache_usage=gpu_cache_usage,
-            cpu_cache_usage=cpu_cache_usage,
-            num_prompt_tokens=num_prompt_tokens,
-            num_generation_tokens=num_generation_tokens,
-            time_to_first_tokens=time_to_first_tokens,
-            time_per_output_tokens=time_per_output_tokens,
-            time_e2e_requests=time_e2e_requests,
+
+            # System stats
+            # Scheduler State
+            num_running_sys=num_running_sys,
+            num_swapped_sys=num_swapped_sys,
+            num_waiting_sys=num_waiting_sys,
+            # KV Cache Usage in %
+            gpu_cache_usage_sys=gpu_cache_usage_sys,
+            cpu_cache_usage_sys=cpu_cache_usage_sys,
+
+            # Iteration stats
+            num_prompt_tokens_iter=num_prompt_tokens_iter,
+            num_generation_tokens_iter=num_generation_tokens_iter,
+            time_to_first_tokens_iter=time_to_first_tokens_iter,
+            time_per_output_tokens_iter=time_per_output_tokens_iter,
             spec_decode_metrics=spec_decode_metrics,
+
+            # Request stats
+            # Latency
+            time_e2e_requests=time_e2e_requests,
+            # Metadata
+            num_prompt_tokens_requests=num_prompt_tokens_requests,
+            num_generation_tokens_requests=num_generation_tokens_requests,
+            best_of_requests=best_of_requests,
+            n_requests=n_requests,
+            finished_reason_requests=finished_reason_requests,
         )

     def add_lora(self, lora_request: LoRARequest) -> bool:
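Pulling the new fields together, the Stats container now carries three groups of values. The dataclass below is a summary sketch inferred from this diff; the field types and the spec_decode_metrics type are assumptions, not copied from vllm/engine/metrics.py.

# Summary sketch of the Stats fields used above; types are inferred
# from this diff rather than taken from vllm/engine/metrics.py.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Stats:
    now: float

    # System stats (point-in-time scheduler / KV-cache state).
    num_running_sys: int
    num_swapped_sys: int
    num_waiting_sys: int
    gpu_cache_usage_sys: float
    cpu_cache_usage_sys: float

    # Iteration stats (collected on every engine step).
    num_prompt_tokens_iter: int
    num_generation_tokens_iter: int
    time_to_first_tokens_iter: List[float]
    time_per_output_tokens_iter: List[float]

    # Request stats (recorded once per finished request).
    time_e2e_requests: List[float]
    num_prompt_tokens_requests: List[int]
    num_generation_tokens_requests: List[int]
    best_of_requests: List[int]
    n_requests: List[int]
    finished_reason_requests: List[str]

    spec_decode_metrics: Optional[object] = None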