diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 33d486263..ef71a05d3 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -14,7 +14,7 @@ from datetime import datetime
 from enum import IntEnum
 from functools import lru_cache
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, TypeVar, get_args
+from typing import TYPE_CHECKING, Any, Literal, TypeVar, get_args
 
 import torch
 from pydantic import ConfigDict, Field, model_validator
@@ -76,6 +76,8 @@ class OptimizationLevel(IntEnum):
     """O3: Currently the same as -O2s."""
 
 
+PerformanceMode = Literal["balanced", "interactivity", "throughput"]
+
 IS_QUANTIZED = False
 IS_DENSE = False
 # The optimizations that depend on these properties currently set to False
@@ -312,6 +314,13 @@ class VllmConfig:
     performance. -O2 is used by default. See OptimizationLevel for full
     description."""
 
+    performance_mode: PerformanceMode = "balanced"
+    """Performance mode for runtime behavior; 'balanced' is the default.
+    'interactivity' favors low end-to-end per-request latency at small batch
+    sizes (fine-grained CUDA graphs, latency-oriented kernels).
+    'throughput' favors aggregate tokens/sec at high concurrency (larger CUDA
+    graphs, more aggressive batching, throughput-oriented kernels)."""
+
     weight_transfer_config: WeightTransferConfig | None = None
     """The configurations for weight transfer during RL training."""
 
@@ -643,6 +652,11 @@ class VllmConfig:
         # To give each torch profile run a unique instance name.
         self.instance_id = f"{time.time_ns()}"
 
+        if self.performance_mode != "balanced":
+            logger.info_once(
+                "Performance mode set to '%s'.", self.performance_mode, scope="local"
+            )
+
         self.try_verify_and_update_config()
 
         if self.model_config is not None:
@@ -1332,9 +1346,15 @@ class VllmConfig:
             # sort to make sure the sizes are in ascending order
             cudagraph_capture_sizes.sort()
         else:
-            cudagraph_capture_sizes = [
-                i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
-            ]
+            if self.performance_mode == "interactivity":
+                # Fine-grained CUDA graphs at small batch sizes
+                # for minimal padding overhead
+                interactivity_max = min(max_cudagraph_capture_size, 32)
+                cudagraph_capture_sizes = list(range(1, interactivity_max + 1))
+            else:
+                cudagraph_capture_sizes = [
+                    i for i in [1, 2, 4] if i <= max_cudagraph_capture_size
+                ]
             if max_cudagraph_capture_size >= 8:
                 # Step size 8 for small batch sizes, up to 256(not included)
                 cudagraph_capture_sizes += list(
@@ -1345,6 +1365,8 @@ class VllmConfig:
                 cudagraph_capture_sizes += list(
                     range(256, max_cudagraph_capture_size + 1, 16)
                 )
+            # de-duplicate and sort the sizes
+            cudagraph_capture_sizes = sorted(set(cudagraph_capture_sizes))
 
         if (
             self.parallel_config.tensor_parallel_size > 1
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 15a662ba2..ca76454d6 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -89,7 +89,7 @@ from vllm.config.parallel import (
 )
 from vllm.config.scheduler import SchedulerPolicy
 from vllm.config.utils import get_field
-from vllm.config.vllm import OptimizationLevel
+from vllm.config.vllm import OptimizationLevel, PerformanceMode
 from vllm.logger import init_logger, suppress_logging
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -596,6 +596,7 @@ class EngineArgs:
     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
     optimization_level: OptimizationLevel = VllmConfig.optimization_level
+    performance_mode: PerformanceMode = VllmConfig.performance_mode
     kv_offloading_size: float | None = CacheConfig.kv_offloading_size
     kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
@@ -1264,6 +1265,7 @@ class EngineArgs:
         vllm_group.add_argument(
             "--optimization-level", **vllm_kwargs["optimization_level"]
         )
+        vllm_group.add_argument("--performance-mode", **vllm_kwargs["performance_mode"])
         vllm_group.add_argument(
             "--weight-transfer-config", **vllm_kwargs["weight_transfer_config"]
         )
@@ -1894,6 +1896,7 @@ class EngineArgs:
             profiler_config=self.profiler_config,
             additional_config=self.additional_config,
             optimization_level=self.optimization_level,
+            performance_mode=self.performance_mode,
             weight_transfer_config=self.weight_transfer_config,
         )
@@ -2110,6 +2113,13 @@ class EngineArgs:
                 SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
             )
 
+        # If throughput mode is set, double max_num_batched_tokens and max_num_seqs.
+        if self.performance_mode == "throughput":
+            if orig_max_num_batched_tokens is None:
+                self.max_num_batched_tokens *= 2
+            if orig_max_num_seqs is None:
+                self.max_num_seqs *= 2
+
         if orig_max_num_batched_tokens is None:
             assert model_config.max_model_len is not None, (
                 "max_model_len must be set by this point"
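
For reference, below is a minimal standalone sketch of the capture-size selection in the vllm/config/vllm.py hunks above. The helper name `capture_sizes` and the example value 64 are illustrative, not vLLM API, and the body of the step-8 `range(...)` call is an assumption reconstructed from its comment, since the first cudagraph hunk truncates it; only the branching mirrors the diff.

def capture_sizes(performance_mode: str, max_cudagraph_capture_size: int) -> list[int]:
    if performance_mode == "interactivity":
        # One graph per batch size from 1 up to min(max, 32), so small batches
        # never pad up to the next captured size.
        interactivity_max = min(max_cudagraph_capture_size, 32)
        sizes = list(range(1, interactivity_max + 1))
    else:
        sizes = [i for i in [1, 2, 4] if i <= max_cudagraph_capture_size]
    if max_cudagraph_capture_size >= 8:
        # Assumed body of the truncated hunk: step 8 up to 256 (not included).
        sizes += list(range(8, min(256, max_cudagraph_capture_size + 1), 8))
    # Step 16 from 256 up to the max (inclusive); empty when max < 256.
    sizes += list(range(256, max_cudagraph_capture_size + 1, 16))
    # The de-dup added by the diff matters in interactivity mode,
    # where sizes 8..32 would otherwise appear twice.
    return sorted(set(sizes))

print(capture_sizes("balanced", 64))       # [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]
print(capture_sizes("interactivity", 64))  # [1, 2, 3, ..., 31, 32, 40, 48, 56, 64]

On the CLI the mode flows through EngineArgs, e.g. `vllm serve <model> --performance-mode interactivity`. Note that in 'throughput' mode the arg_utils.py hunk doubles `max_num_batched_tokens` and `max_num_seqs` only when the user has not set them explicitly.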