[Perf][CLI] Improve overall startup time (#19941)

Author: Aaron Pham
Date: 2025-06-22 19:11:22 -04:00
Committed by: GitHub
parent 33d51f599e
commit c4cf260677
14 changed files with 293 additions and 103 deletions
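
The diff below applies one pattern throughout the config classes: validation that previously ran via an explicit self._verify_args() call in __post_init__ moves into a Pydantic @model_validator(mode='after') method that returns self, so Pydantic triggers it automatically once the dataclass is constructed. A minimal sketch of the pattern (ExampleConfig and its field are illustrative, not vLLM's actual config):

from pydantic import model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

@dataclass
class ExampleConfig:
    gpu_memory_utilization: float = 0.9

    @model_validator(mode='after')
    def _verify_args(self) -> Self:
        # Runs automatically after field validation; no explicit
        # self._verify_args() call in __post_init__ is needed.
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")
        return self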

@@ -28,7 +28,7 @@ from pydantic.dataclasses import dataclass
 from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
-from typing_extensions import deprecated, runtime_checkable
+from typing_extensions import Self, deprecated, runtime_checkable

 import vllm.envs as envs
 from vllm import version
@@ -1537,7 +1537,6 @@ class CacheConfig:
     def __post_init__(self) -> None:
         self.swap_space_bytes = self.swap_space * GiB_bytes
-        self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
@@ -1546,7 +1545,8 @@
         # metrics info
         return {key: str(value) for key, value in self.__dict__.items()}

-    def _verify_args(self) -> None:
+    @model_validator(mode='after')
+    def _verify_args(self) -> Self:
         if self.cpu_offload_gb < 0:
             raise ValueError("CPU offload space must be non-negative"
                              f", but got {self.cpu_offload_gb}")
@@ -1556,6 +1556,8 @@
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
+        return self
+
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
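
With the validator in place, invalid values fail at construction time; Pydantic wraps the ValueError raised inside the validator in a ValidationError. Continuing the illustrative sketch above:

from pydantic import ValidationError

try:
    ExampleConfig(gpu_memory_utilization=1.5)
except ValidationError as exc:
    print(exc)  # surfaces the "must be less than 1.0" message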
@@ -1942,15 +1944,14 @@ class ParallelConfig:
         if self.distributed_executor_backend is None and self.world_size == 1:
             self.distributed_executor_backend = "uni"
-        self._verify_args()

     @property
     def use_ray(self) -> bool:
         return self.distributed_executor_backend == "ray" or (
             isinstance(self.distributed_executor_backend, type)
             and self.distributed_executor_backend.uses_ray)

-    def _verify_args(self) -> None:
+    @model_validator(mode='after')
+    def _verify_args(self) -> Self:
         # Lazy import to avoid circular import
         from vllm.executor.executor_base import ExecutorBase
         from vllm.platforms import current_platform
@@ -1977,8 +1978,7 @@
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
-        assert isinstance(self.worker_extension_cls, str), (
-            "worker_extension_cls must be a string (qualified class name).")
+        return self

 PreemptionMode = Literal["swap", "recompute"]
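
Note that ParallelConfig._verify_args keeps its imports inside the method body ("Lazy import to avoid circular import"). Deferring imports like this also serves the commit's startup-time goal: the modules load only when validation actually runs, not when the config module is imported. A standalone sketch of the deferred-import idea, with json standing in for a genuinely heavy dependency:

def build_backend(spec: str):
    # Deferred import: the dependency loads on first call, not when this
    # module is imported, so unrelated CLI paths stay fast.
    import json  # stand-in for a heavy import such as torch
    return json.loads(spec)

print(build_backend('{"backend": "mp"}'))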
@@ -2202,9 +2202,8 @@ class SchedulerConfig:
                 self.max_num_partial_prefills, self.max_long_partial_prefills,
                 self.long_prefill_token_threshold)

-        self._verify_args()

-    def _verify_args(self) -> None:
+    @model_validator(mode='after')
+    def _verify_args(self) -> Self:
         if (self.max_num_batched_tokens < self.max_model_len
                 and not self.chunked_prefill_enabled):
             raise ValueError(
@@ -2263,6 +2262,8 @@
                 "must be greater than or equal to 1 and less than or equal to "
                 f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
+        return self
+
     @property
     def is_multi_step(self) -> bool:
         return self.num_scheduler_steps > 1
@@ -2669,8 +2670,6 @@ class SpeculativeConfig:
         if self.posterior_alpha is None:
             self.posterior_alpha = 0.3

-        self._verify_args()

     @staticmethod
     def _maybe_override_draft_max_model_len(
         speculative_max_model_len: Optional[int],
@@ -2761,7 +2760,8 @@
         return draft_parallel_config

-    def _verify_args(self) -> None:
+    @model_validator(mode='after')
+    def _verify_args(self) -> Self:
         if self.num_speculative_tokens is None:
             raise ValueError(
                 "num_speculative_tokens must be provided with "
@@ -2812,6 +2812,8 @@
                 "Eagle3 is only supported for Llama models. "
                 f"Got {self.target_model_config.hf_text_config.model_type=}")
+        return self
+
     @property
     def num_lookahead_slots(self) -> int:
         """The number of additional slots the scheduler should allocate per