Refactor Prometheus and Add Request Level Metrics (#2316)

This commit is contained in:
Robert Shaw
2024-01-31 14:58:07 -08:00
committed by GitHub
parent d0d93b92b1
commit 93b38bea5d
7 changed files with 1230 additions and 98 deletions

View File

@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, LoRAConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import record_metrics
from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
@@ -28,8 +28,7 @@ if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 5
_LOCAL_LOGGING_INTERVAL_SEC = 5
class LLMEngine:
@@ -116,12 +115,10 @@ class LLMEngine:
# Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Logging.
self.last_logging_time = 0.0
# List of (timestamp, num_tokens)
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC)
def get_tokenizer_for_seq(self, sequence: Sequence):
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
@@ -537,6 +534,7 @@ class LLMEngine:
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None:
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
@@ -732,10 +730,10 @@ class LLMEngine:
and not seq_group.prefix.computed):
seq_group.prefix.computed = True
# Log stats.
if self.log_stats:
# Log the system stats.
self._log_system_stats(scheduler_outputs.prompt_run,
scheduler_outputs.num_batched_tokens)
self.stat_logger.log(self._get_stats(scheduler_outputs))
return request_outputs
def step(self) -> List[RequestOutput]:
@@ -810,81 +808,73 @@ class LLMEngine:
return self._process_model_outputs(output, scheduler_outputs)
def do_log_stats(self) -> None:
self._log_system_stats(False, 0)
"""Forced log when no requests active."""
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs=None))
def _log_system_stats(
self,
prompt_run: bool,
num_batched_tokens: int,
) -> None:
def _get_stats(self,
scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
"""Get Stats to be Logged to Prometheus."""
now = time.monotonic()
# Log the number of batched input tokens.
if prompt_run:
self.num_prompt_tokens.append((now, num_batched_tokens))
else:
self.num_generation_tokens.append((now, num_batched_tokens))
should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC
if not should_log:
return
# KV Cache Usage in %.
num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
# Discard the old stats.
self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
if now - t < _LOGGING_INTERVAL_SEC]
self.num_generation_tokens = [(t, n)
for t, n in self.num_generation_tokens
if now - t < _LOGGING_INTERVAL_SEC]
num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage = 0.
if num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
)
cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
if len(self.num_prompt_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
window = now - self.num_prompt_tokens[0][0]
avg_prompt_throughput = total_num_tokens / window
else:
avg_prompt_throughput = 0.0
if len(self.num_generation_tokens) > 1:
total_num_tokens = sum(n
for _, n in self.num_generation_tokens[:-1])
window = now - self.num_generation_tokens[0][0]
avg_generation_throughput = total_num_tokens / window
else:
avg_generation_throughput = 0.0
# Scheduler State
num_running = len(self.scheduler.running)
num_swapped = len(self.scheduler.swapped)
num_waiting = len(self.scheduler.waiting)
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = (
self.scheduler.block_manager.get_num_free_gpu_blocks())
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
# Iteration stats if we have scheduler output.
num_prompt_tokens = 0
num_generation_tokens = 0
time_to_first_tokens = []
time_per_output_tokens = []
time_e2e_requests = []
if scheduler_outputs is not None:
prompt_run = scheduler_outputs.prompt_run
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = (
self.scheduler.block_manager.get_num_free_cpu_blocks())
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
# Number of Tokens.
if prompt_run:
num_prompt_tokens = scheduler_outputs.num_batched_tokens
else:
num_generation_tokens = scheduler_outputs.num_batched_tokens
record_metrics(
avg_prompt_throughput=avg_prompt_throughput,
avg_generation_throughput=avg_generation_throughput,
scheduler_running=len(self.scheduler.running),
scheduler_swapped=len(self.scheduler.swapped),
scheduler_waiting=len(self.scheduler.waiting),
# Latency Timings.
time_last_iters = []
for seq_group in scheduler_outputs.scheduled_seq_groups:
# Time since last token. (n.b. updates seq_group.last_token_time)
time_last_iters.append(seq_group.get_last_latency(now))
# Time since arrival for all finished requests.
if seq_group.is_finished():
time_e2e_requests.append(now - seq_group.arrival_time)
time_to_first_tokens = time_last_iters if prompt_run else []
time_per_output_tokens = [] if prompt_run else time_last_iters
return Stats(
now=now,
num_running=num_running,
num_swapped=num_swapped,
num_waiting=num_waiting,
gpu_cache_usage=gpu_cache_usage,
cpu_cache_usage=cpu_cache_usage,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=num_generation_tokens,
time_to_first_tokens=time_to_first_tokens,
time_per_output_tokens=time_per_output_tokens,
time_e2e_requests=time_e2e_requests,
)
logger.info("Avg prompt throughput: "
f"{avg_prompt_throughput:.1f} tokens/s, "
"Avg generation throughput: "
f"{avg_generation_throughput:.1f} tokens/s, "
f"Running: {len(self.scheduler.running)} reqs, "
f"Swapped: {len(self.scheduler.swapped)} reqs, "
f"Pending: {len(self.scheduler.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
"""Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset,