[Config] Clean up SchedulerConfig initialization (#28665)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-11-14 22:41:02 +08:00
committed by GitHub
parent 96b23b8e3b
commit 511a6b611d
9 changed files with 182 additions and 163 deletions

View File

@@ -4,7 +4,7 @@
import hashlib
from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
@@ -12,11 +12,6 @@ from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import (
DEFAULT_MAX_NUM_BATCHED_TOKENS,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
@@ -33,25 +28,32 @@ SchedulerPolicy = Literal["fcfs", "priority"]
class SchedulerConfig:
"""Scheduler configuration."""
DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128
runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""
max_num_batched_tokens: int = Field(default=None, ge=1)
max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
"""Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_num_seqs: int = Field(default=None, ge=1)
max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
"""Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
max_model_len: int = Field(default=None, ge=1)
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
max_model_len: int = Field(default=8192, ge=1)
"""Maximum length of a sequence (including prompt and generated text).
The default value here is mainly for convenience when testing.
In real usage, this should duplicate `ModelConfig.max_model_len` via
`EngineArgs`."""
max_num_partial_prefills: int = Field(default=1, ge=1)
"""For chunked prefill, the maximum number of sequences that can be
@@ -76,9 +78,13 @@ class SchedulerConfig:
NOTE: This will be replaced by speculative config in the future; it is
present to enable correctness tests until then."""
enable_chunked_prefill: bool = Field(default=None)
enable_chunked_prefill: bool = True
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
on the remaining `max_num_batched_tokens`.
The default value here is mainly for convenience when testing.
In real usage, this should be set in `EngineArgs.create_engine_config`.
"""
is_multimodal_model: bool = False
"""True if the model is multimodal."""
@@ -111,9 +117,6 @@ class SchedulerConfig:
- "priority" means requests are handled based on given priority (lower
value means earlier handling), with time of arrival deciding any ties."""
chunked_prefill_enabled: bool = Field(init=False)
"""True if chunked prefill is enabled."""
disable_chunked_mm_input: bool = False
"""If set to true and chunked prefill is enabled, we do not want to
partially schedule a multimodal item. Only used in V1
@@ -188,15 +191,7 @@ class SchedulerConfig:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator(
"max_num_batched_tokens",
"max_num_seqs",
"max_model_len",
"enable_chunked_prefill",
"scheduler_cls",
"async_scheduling",
mode="wrap",
)
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
@@ -205,16 +200,9 @@ class SchedulerConfig:
return handler(value)
def __post_init__(self, is_encoder_decoder: bool) -> None:
if self.max_model_len is None:
self.max_model_len = 8192
if self.max_num_seqs is None:
self.max_num_seqs = 128
if is_encoder_decoder:
# Chunked prefill should be disabled for encoder-decoder models.
self.disable_chunked_mm_input = True
self.chunked_prefill_enabled = False
self.enable_chunked_prefill = False
self.long_prefill_token_threshold = 0
logger.info(
@@ -222,37 +210,6 @@ class SchedulerConfig:
" prefix caching; disabling both."
)
if self.max_num_batched_tokens is None:
if self.enable_chunked_prefill:
self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
else:
# If max_model_len is too short, use
# DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
# for higher throughput.
self.max_num_batched_tokens = max(
self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
)
if self.runner_type == "pooling":
# Choose specific value for higher throughput
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if self.is_multimodal_model:
# The value needs to be at least the number of multimodal tokens
self.max_num_batched_tokens = max(
self.max_num_batched_tokens,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)
# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
)
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
self.encoder_cache_size = self.max_num_batched_tokens
@@ -262,7 +219,6 @@ class SchedulerConfig:
self.max_num_batched_tokens,
)
self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -276,6 +232,14 @@ class SchedulerConfig:
self.long_prefill_token_threshold,
)
@property
def chunked_prefill_enabled(self) -> bool:
    """Return whether chunked prefill is enabled.

    This is a read-through alias: the value is taken directly from
    `enable_chunked_prefill`, so the two names can never disagree.
    """
    return self.enable_chunked_prefill

@chunked_prefill_enabled.setter
def chunked_prefill_enabled(self, value: bool) -> None:
    """Set whether chunked prefill is enabled.

    This is a write-through alias: the value is stored on
    `enable_chunked_prefill`; no separate state is kept.
    """
    self.enable_chunked_prefill = value
@model_validator(mode="after")
def _verify_args(self) -> Self:
if (