Improve configs - SchedulerConfig (#16533)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-04-14 10:24:16 +01:00
committed by GitHub
parent dc1b4a6f13
commit e51929ebca
4 changed files with 279 additions and 218 deletions

View File

@@ -1,25 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# yapf: disable
import argparse
import dataclasses
import json
import re
import threading
from dataclasses import MISSING, dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union, cast, get_args, get_origin)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping,
Optional, Tuple, Type, TypeVar, Union, cast, get_args,
get_origin)
import torch
from typing_extensions import TypeIs
import vllm.envs as envs
from vllm import version
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
DecodingConfig, DeviceConfig, HfOverrides,
DecodingConfig, DeviceConfig,
DistributedExecutorBackend, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelImpl, ObservabilityConfig,
ParallelConfig, PoolerConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig, TaskOption,
TokenizerPoolConfig, VllmConfig, get_attr_docs)
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerPoolConfig, VllmConfig,
get_attr_docs)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -28,7 +33,9 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
from vllm.utils import FlexibleArgumentParser, is_in_ray_actor
# yapf: enable
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
@@ -47,11 +54,32 @@ DEVICE_OPTIONS = [
"hpu",
]
# object is used to allow for special typing forms
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
def nullable_str(val: str):
if not val or val == "None":
def optional_arg(val: str, return_type: type[T]) -> Optional[T]:
if val == "" or val == "None":
return None
return val
try:
return cast(Callable, return_type)(val)
except ValueError as e:
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
def optional_str(val: str) -> Optional[str]:
return optional_arg(val, str)
def optional_int(val: str) -> Optional[int]:
return optional_arg(val, int)
def optional_float(val: str) -> Optional[float]:
return optional_arg(val, float)
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
@@ -112,7 +140,8 @@ class EngineArgs:
# is intended for expert use only. The API may change without
# notice.
distributed_executor_backend: Optional[Union[
str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
DistributedExecutorBackend,
Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
@@ -129,11 +158,13 @@ class EngineArgs:
swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_partial_prefills: Optional[int] = 1
max_long_partial_prefills: Optional[int] = 1
long_prefill_token_threshold: Optional[int] = 0
max_num_seqs: Optional[int] = None
max_num_batched_tokens: Optional[
int] = SchedulerConfig.max_num_batched_tokens
max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
long_prefill_token_threshold: int = \
SchedulerConfig.long_prefill_token_threshold
max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False
revision: Optional[str] = None
@@ -169,20 +200,21 @@ class EngineArgs:
lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = True
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
model_loader_extra_config: Optional[
dict] = LoadConfig.model_loader_extra_config
ignore_patterns: Optional[Union[str,
List[str]]] = LoadConfig.ignore_patterns
preemption_mode: Optional[str] = None
preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: Optional[bool] = None
disable_chunked_mm_input: bool = False
scheduler_delay_factor: float = SchedulerConfig.delay_factor
enable_chunked_prefill: Optional[
bool] = SchedulerConfig.enable_chunked_prefill
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
logits_processor_pattern: Optional[str] = None
@@ -194,8 +226,8 @@ class EngineArgs:
otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None
@@ -236,15 +268,33 @@ class EngineArgs:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""
def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
def is_type_in_union(cls: TypeHint, type: TypeHint) -> bool:
"""Check if the class is a type in a union type."""
return get_origin(cls) is Union and type in get_args(cls)
is_union = get_origin(cls) is Union
type_in_union = type in [get_origin(a) or a for a in get_args(cls)]
return is_union and type_in_union
def is_optional(cls: type[Any]) -> bool:
def get_type_from_union(cls: TypeHint, type: TypeHintT) -> TypeHintT:
"""Get the type in a union type."""
for arg in get_args(cls):
if (get_origin(arg) or arg) is type:
return arg
raise ValueError(f"Type {type} not found in union type {cls}.")
def is_optional(cls: TypeHint) -> TypeIs[Union[Any, None]]:
"""Check if the class is an optional type."""
return is_type_in_union(cls, type(None))
def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
def can_be_type(cls: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
"""Check if the class can be of type."""
return cls is type or get_origin(cls) is type or is_type_in_union(
cls, type)
def is_custom_type(cls: TypeHint) -> bool:
"""Check if the class is a custom type."""
return cls.__module__ != "builtins"
def get_kwargs(cls: type[Any]) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
@@ -253,19 +303,41 @@ class EngineArgs:
default = (field.default_factory
if field.default is MISSING else field.default)
kwargs[name] = {"default": default, "help": cls_docs[name]}
# When using action="store_true"
# add_argument doesn't accept type
if field.type is bool:
continue
# Handle optional fields
if is_optional(field.type):
kwargs[name]["type"] = nullable_str
continue
# Handle str in union fields
if is_type_in_union(field.type, str):
kwargs[name]["type"] = str
continue
kwargs[name]["type"] = field.type
# Make note of if the field is optional and get the actual
# type of the field if it is
optional = is_optional(field.type)
field_type = get_args(
field.type)[0] if optional else field.type
if can_be_type(field_type, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
kwargs[name]["type"] = bool
elif can_be_type(field_type, Literal):
# Creates choices from Literal arguments
if is_type_in_union(field_type, Literal):
field_type = get_type_from_union(field_type, Literal)
choices = get_args(field_type)
kwargs[name]["choices"] = choices
choice_type = type(choices[0])
assert all(type(c) is choice_type for c in choices), (
f"All choices must be of the same type. "
f"Got {choices} with types {[type(c) for c in choices]}"
)
kwargs[name]["type"] = choice_type
elif can_be_type(field_type, int):
kwargs[name]["type"] = optional_int if optional else int
elif can_be_type(field_type, float):
kwargs[name][
"type"] = optional_float if optional else float
elif (can_be_type(field_type, str)
or can_be_type(field_type, dict)
or is_custom_type(field_type)):
kwargs[name]["type"] = optional_str if optional else str
else:
raise ValueError(
f"Unsupported type {field.type} for argument {name}. ")
return kwargs
# Model arguments
@@ -285,13 +357,13 @@ class EngineArgs:
'which task to use.')
parser.add_argument(
'--tokenizer',
type=nullable_str,
type=optional_str,
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
"--hf-config-path",
type=nullable_str,
type=optional_str,
default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.')
@@ -303,21 +375,21 @@ class EngineArgs:
'the input. The generated output will contain token ids.')
parser.add_argument(
'--revision',
type=nullable_str,
type=optional_str,
default=None,
help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
type=nullable_str,
type=optional_str,
default=None,
help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
type=nullable_str,
type=optional_str,
default=None,
help='Revision of the huggingface tokenizer to use. '
'It can be a branch name, a tag name, or a commit id. '
@@ -357,7 +429,6 @@ class EngineArgs:
load_group.add_argument('--model-loader-extra-config',
**load_kwargs["model_loader_extra_config"])
load_group.add_argument('--use-tqdm-on-load',
action=argparse.BooleanOptionalAction,
**load_kwargs["use_tqdm_on_load"])
parser.add_argument(
@@ -413,7 +484,7 @@ class EngineArgs:
'the behavior is subject to change in each release.')
parser.add_argument(
'--logits-processor-pattern',
type=nullable_str,
type=optional_str,
default=None,
help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` '
@@ -439,7 +510,6 @@ class EngineArgs:
)
parallel_group.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp', 'uni', 'external_launcher'],
**parallel_kwargs["distributed_executor_backend"])
parallel_group.add_argument(
'--pipeline-parallel-size', '-pp',
@@ -450,18 +520,15 @@ class EngineArgs:
**parallel_kwargs["data_parallel_size"])
parallel_group.add_argument(
'--enable-expert-parallel',
action='store_true',
**parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument(
'--max-parallel-loading-workers',
**parallel_kwargs["max_parallel_loading_workers"])
parallel_group.add_argument(
'--ray-workers-use-nsight',
action='store_true',
**parallel_kwargs["ray_workers_use_nsight"])
parallel_group.add_argument(
'--disable-custom-all-reduce',
action='store_true',
**parallel_kwargs["disable_custom_all_reduce"])
# KV cache arguments
parser.add_argument('--block-size',
@@ -502,14 +569,6 @@ class EngineArgs:
'block manager v2) is now the default. '
'Setting this flag to True or False'
' has no effect on vLLM behavior.')
parser.add_argument(
'--num-lookahead-slots',
type=int,
default=EngineArgs.num_lookahead_slots,
help='Experimental scheduling config necessary for '
'speculative decoding. This will be replaced by '
'speculative config in the future; it is present '
'to enable correctness tests until then.')
parser.add_argument('--seed',
type=int,
@@ -552,36 +611,6 @@ class EngineArgs:
default=None,
help='If specified, ignore GPU profiling result and use this number'
' of GPU blocks. Used for testing preemption.')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='Maximum number of batched tokens per '
'iteration.')
parser.add_argument(
"--max-num-partial-prefills",
type=int,
default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, the max number of concurrent \
partial prefills.")
parser.add_argument(
"--max-long-partial-prefills",
type=int,
default=EngineArgs.max_long_partial_prefills,
help="For chunked prefill, the maximum number of prompts longer "
"than --long-prefill-token-threshold that will be prefilled "
"concurrently. Setting this less than --max-num-partial-prefills "
"will allow shorter prompts to jump the queue in front of longer "
"prompts in some cases, improving latency.")
parser.add_argument(
"--long-prefill-token-threshold",
type=float,
default=EngineArgs.long_prefill_token_threshold,
help="For chunked prefill, a request is considered long if the "
"prompt is longer than this number of tokens.")
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='Maximum number of sequences per iteration.')
parser.add_argument(
'--max-logprobs',
type=int,
@@ -594,7 +623,7 @@ class EngineArgs:
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=nullable_str,
type=optional_str,
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
@@ -658,7 +687,7 @@ class EngineArgs:
'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.')
parser.add_argument('--tokenizer-pool-extra-config',
type=nullable_str,
type=optional_str,
default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. '
'This should be a JSON string that will be '
@@ -721,7 +750,7 @@ class EngineArgs:
'base model dtype.'))
parser.add_argument(
'--long-lora-scaling-factors',
type=nullable_str,
type=optional_str,
default=EngineArgs.long_lora_scaling_factors,
help=('Specify multiple scaling factors (which can '
'be different from base model scaling factor '
@@ -766,28 +795,6 @@ class EngineArgs:
help=('Maximum number of forward steps per '
'scheduler call.'))
parser.add_argument(
'--multi-step-stream-outputs',
action=StoreBoolean,
default=EngineArgs.multi_step_stream_outputs,
nargs="?",
const="True",
help='If False, then multi-step will stream outputs at the end '
'of all steps')
parser.add_argument(
'--scheduler-delay-factor',
type=float,
default=EngineArgs.scheduler_delay_factor,
help='Apply a delay (of delay factor multiplied by previous '
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
action=StoreBoolean,
default=EngineArgs.enable_chunked_prefill,
nargs="?",
const="True",
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument('--speculative-config',
type=json.loads,
default=None,
@@ -863,22 +870,43 @@ class EngineArgs:
help="Disable async output processing. This may result in "
"lower performance.")
parser.add_argument(
'--scheduling-policy',
choices=['fcfs', 'priority'],
default="fcfs",
help='The scheduling policy to use. "fcfs" (first come first served'
', i.e. requests are handled in order of arrival; default) '
'or "priority" (requests are handled based on given '
'priority (lower value means earlier handling) and time of '
'arrival deciding any ties).')
parser.add_argument(
'--scheduler-cls',
default=EngineArgs.scheduler_cls,
help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
'is the default scheduler. Can be a class directly or the path to '
'a class of form "mod.custom_class".')
# Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig)
scheduler_group = parser.add_argument_group(
title="SchedulerConfig",
description=SchedulerConfig.__doc__,
)
scheduler_group.add_argument(
'--max-num-batched-tokens',
**scheduler_kwargs["max_num_batched_tokens"])
scheduler_group.add_argument('--max-num-seqs',
**scheduler_kwargs["max_num_seqs"])
scheduler_group.add_argument(
"--max-num-partial-prefills",
**scheduler_kwargs["max_num_partial_prefills"])
scheduler_group.add_argument(
"--max-long-partial-prefills",
**scheduler_kwargs["max_long_partial_prefills"])
scheduler_group.add_argument(
"--long-prefill-token-threshold",
**scheduler_kwargs["long_prefill_token_threshold"])
scheduler_group.add_argument('--num-lookahead-slots',
**scheduler_kwargs["num_lookahead_slots"])
scheduler_group.add_argument('--scheduler-delay-factor',
**scheduler_kwargs["delay_factor"])
scheduler_group.add_argument(
'--enable-chunked-prefill',
**scheduler_kwargs["enable_chunked_prefill"])
scheduler_group.add_argument(
'--multi-step-stream-outputs',
**scheduler_kwargs["multi_step_stream_outputs"])
scheduler_group.add_argument('--scheduling-policy',
**scheduler_kwargs["policy"])
scheduler_group.add_argument(
"--disable-chunked-mm-input",
**scheduler_kwargs["disable_chunked_mm_input"])
parser.add_argument('--scheduler-cls',
**scheduler_kwargs["scheduler_cls"])
parser.add_argument(
'--override-neuron-config',
@@ -930,7 +958,7 @@ class EngineArgs:
'class without changing the existing functions.')
parser.add_argument(
"--generation-config",
type=nullable_str,
type=optional_str,
default="auto",
help="The folder path to the generation config. "
"Defaults to 'auto', the generation config will be loaded from "
@@ -1003,20 +1031,6 @@ class EngineArgs:
"Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial.")
parser.add_argument(
"--disable-chunked-mm-input",
action=StoreBoolean,
default=EngineArgs.disable_chunked_mm_input,
nargs="?",
const="True",
help="Disable multimodal input chunking attention for V1. "
"If set to true and chunked prefill is enabled, we do not want to"
" partially schedule a multimodal item. This ensures that if a "
"request has a mixed prompt (like text tokens TTTT followed by "
"image tokens IIIIIIIIII) where only some image tokens can be "
"scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
"as TTTT in one step and IIIIIIIIII in the next.")
return parser
@classmethod
@@ -1370,7 +1384,7 @@ class EngineArgs:
recommend_to_remove=False)
return False
if self.preemption_mode != EngineArgs.preemption_mode:
if self.preemption_mode != SchedulerConfig.preemption_mode:
_raise_or_fallback(feature_name="--preemption-mode",
recommend_to_remove=True)
return False
@@ -1381,17 +1395,17 @@ class EngineArgs:
recommend_to_remove=True)
return False
if self.scheduling_policy != EngineArgs.scheduling_policy:
if self.scheduling_policy != SchedulerConfig.policy:
_raise_or_fallback(feature_name="--scheduling-policy",
recommend_to_remove=False)
return False
if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
_raise_or_fallback(feature_name="--num-scheduler-steps",
recommend_to_remove=True)
return False
if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor:
if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
_raise_or_fallback(feature_name="--scheduler-delay-factor",
recommend_to_remove=True)
return False
@@ -1475,9 +1489,9 @@ class EngineArgs:
# No Concurrent Partial Prefills so far.
if (self.max_num_partial_prefills
!= EngineArgs.max_num_partial_prefills
!= SchedulerConfig.max_num_partial_prefills
or self.max_long_partial_prefills
!= EngineArgs.max_long_partial_prefills):
!= SchedulerConfig.max_long_partial_prefills):
_raise_or_fallback(feature_name="Concurrent Partial Prefill",
recommend_to_remove=False)
return False