[Feature] Pydantic validation for scheduler.py and structured_outputs.py (#26519)
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Committed by: GitHub
Parent: 9e5bd3076e
Commit: 5e8862e9e0
@@ -2,10 +2,11 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
from dataclasses import InitVar, field
|
from collections.abc import Callable
|
||||||
|
from dataclasses import InitVar
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import SkipValidation, model_validator
|
from pydantic import Field, field_validator, model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
@@ -31,28 +32,28 @@ class SchedulerConfig:
|
|||||||
runner_type: RunnerType = "generate"
|
runner_type: RunnerType = "generate"
|
||||||
"""The runner type to launch for the model."""
|
"""The runner type to launch for the model."""
|
||||||
|
|
||||||
max_num_batched_tokens: SkipValidation[int] = None # type: ignore
|
max_num_batched_tokens: int = Field(default=None, ge=1)
|
||||||
"""Maximum number of tokens to be processed in a single iteration.
|
"""Maximum number of tokens to be processed in a single iteration.
|
||||||
|
|
||||||
This config has no static default. If left unspecified by the user, it will
|
This config has no static default. If left unspecified by the user, it will
|
||||||
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
||||||
|
|
||||||
max_num_seqs: SkipValidation[int] = None # type: ignore
|
max_num_seqs: int = Field(default=None, ge=1)
|
||||||
"""Maximum number of sequences to be processed in a single iteration.
|
"""Maximum number of sequences to be processed in a single iteration.
|
||||||
|
|
||||||
This config has no static default. If left unspecified by the user, it will
|
This config has no static default. If left unspecified by the user, it will
|
||||||
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
be set in `EngineArgs.create_engine_config` based on the usage context."""
|
||||||
|
|
||||||
max_model_len: SkipValidation[int] = None # type: ignore
|
max_model_len: int = Field(default=None, ge=1)
|
||||||
"""Maximum length of a sequence (including prompt and generated text). This
|
"""Maximum length of a sequence (including prompt and generated text). This
|
||||||
is primarily set in `ModelConfig` and that value should be manually
|
is primarily set in `ModelConfig` and that value should be manually
|
||||||
duplicated here."""
|
duplicated here."""
|
||||||
|
|
||||||
max_num_partial_prefills: int = 1
|
max_num_partial_prefills: int = Field(default=1, ge=1)
|
||||||
"""For chunked prefill, the maximum number of sequences that can be
|
"""For chunked prefill, the maximum number of sequences that can be
|
||||||
partially prefilled concurrently."""
|
partially prefilled concurrently."""
|
||||||
|
|
||||||
max_long_partial_prefills: int = 1
|
max_long_partial_prefills: int = Field(default=1, ge=1)
|
||||||
"""For chunked prefill, the maximum number of prompts longer than
|
"""For chunked prefill, the maximum number of prompts longer than
|
||||||
long_prefill_token_threshold that will be prefilled concurrently. Setting
|
long_prefill_token_threshold that will be prefilled concurrently. Setting
|
||||||
this less than max_num_partial_prefills will allow shorter prompts to jump
|
this less than max_num_partial_prefills will allow shorter prompts to jump
|
||||||
@@ -62,7 +63,7 @@ class SchedulerConfig:
|
|||||||
"""For chunked prefill, a request is considered long if the prompt is
|
"""For chunked prefill, a request is considered long if the prompt is
|
||||||
longer than this number of tokens."""
|
longer than this number of tokens."""
|
||||||
|
|
||||||
num_lookahead_slots: int = 0
|
num_lookahead_slots: int = Field(default=0, ge=0)
|
||||||
"""The number of slots to allocate per sequence per
|
"""The number of slots to allocate per sequence per
|
||||||
step, beyond the known token ids. This is used in speculative
|
step, beyond the known token ids. This is used in speculative
|
||||||
decoding to store KV activations of tokens which may or may not be
|
decoding to store KV activations of tokens which may or may not be
|
||||||
@@ -71,7 +72,7 @@ class SchedulerConfig:
|
|||||||
NOTE: This will be replaced by speculative config in the future; it is
|
NOTE: This will be replaced by speculative config in the future; it is
|
||||||
present to enable correctness tests until then."""
|
present to enable correctness tests until then."""
|
||||||
|
|
||||||
enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
|
enable_chunked_prefill: bool = Field(default=None)
|
||||||
"""If True, prefill requests can be chunked based
|
"""If True, prefill requests can be chunked based
|
||||||
on the remaining max_num_batched_tokens."""
|
on the remaining max_num_batched_tokens."""
|
||||||
|
|
||||||
@@ -86,14 +87,14 @@ class SchedulerConfig:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO (ywang96): Make this configurable.
|
# TODO (ywang96): Make this configurable.
|
||||||
max_num_encoder_input_tokens: int = field(init=False)
|
max_num_encoder_input_tokens: int = Field(init=False)
|
||||||
"""Multimodal encoder compute budget, only used in V1.
|
"""Multimodal encoder compute budget, only used in V1.
|
||||||
|
|
||||||
NOTE: This is not currently configurable. It will be overridden by
|
NOTE: This is not currently configurable. It will be overridden by
|
||||||
max_num_batched_tokens in case max multimodal embedding size is larger."""
|
max_num_batched_tokens in case max multimodal embedding size is larger."""
|
||||||
|
|
||||||
# TODO (ywang96): Make this configurable.
|
# TODO (ywang96): Make this configurable.
|
||||||
encoder_cache_size: int = field(init=False)
|
encoder_cache_size: int = Field(init=False)
|
||||||
"""Multimodal encoder cache size, only used in V1.
|
"""Multimodal encoder cache size, only used in V1.
|
||||||
|
|
||||||
NOTE: This is not currently configurable. It will be overridden by
|
NOTE: This is not currently configurable. It will be overridden by
|
||||||
@@ -106,7 +107,7 @@ class SchedulerConfig:
|
|||||||
- "priority" means requests are handled based on given priority (lower
|
- "priority" means requests are handled based on given priority (lower
|
||||||
value means earlier handling) and time of arrival deciding any ties."""
|
value means earlier handling) and time of arrival deciding any ties."""
|
||||||
|
|
||||||
chunked_prefill_enabled: bool = field(init=False)
|
chunked_prefill_enabled: bool = Field(init=False)
|
||||||
"""True if chunked prefill is enabled."""
|
"""True if chunked prefill is enabled."""
|
||||||
|
|
||||||
disable_chunked_mm_input: bool = False
|
disable_chunked_mm_input: bool = False
|
||||||
@@ -155,6 +156,20 @@ class SchedulerConfig:
|
|||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
|
@field_validator(
|
||||||
|
"max_num_batched_tokens",
|
||||||
|
"max_num_seqs",
|
||||||
|
"max_model_len",
|
||||||
|
"enable_chunked_prefill",
|
||||||
|
mode="wrap",
|
||||||
|
)
|
||||||
|
@classmethod
|
||||||
|
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
|
||||||
|
"""Skip validation if the value is `None` when initialisation is delayed."""
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return handler(value)
|
||||||
|
|
||||||
def __post_init__(self, is_encoder_decoder: bool) -> None:
|
def __post_init__(self, is_encoder_decoder: bool) -> None:
|
||||||
if self.max_model_len is None:
|
if self.max_model_len is None:
|
||||||
self.max_model_len = 8192
|
self.max_model_len = 8192
|
||||||
@@ -260,19 +275,7 @@ class SchedulerConfig:
|
|||||||
self.max_num_seqs * self.max_model_len,
|
self.max_num_seqs * self.max_model_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.num_lookahead_slots < 0:
|
if self.max_num_partial_prefills > 1:
|
||||||
raise ValueError(
|
|
||||||
"num_lookahead_slots "
|
|
||||||
f"({self.num_lookahead_slots}) must be greater than or "
|
|
||||||
"equal to 0."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.max_num_partial_prefills < 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
|
|
||||||
"must be greater than or equal to 1."
|
|
||||||
)
|
|
||||||
elif self.max_num_partial_prefills > 1:
|
|
||||||
if not self.chunked_prefill_enabled:
|
if not self.chunked_prefill_enabled:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Chunked prefill must be enabled to set "
|
"Chunked prefill must be enabled to set "
|
||||||
@@ -286,13 +289,10 @@ class SchedulerConfig:
|
|||||||
f"than the max_model_len ({self.max_model_len})."
|
f"than the max_model_len ({self.max_model_len})."
|
||||||
)
|
)
|
||||||
|
|
||||||
if (self.max_long_partial_prefills < 1) or (
|
if self.max_long_partial_prefills > self.max_num_partial_prefills:
|
||||||
self.max_long_partial_prefills > self.max_num_partial_prefills
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
|
f"{self.max_long_partial_prefills=} must be less than or equal to "
|
||||||
"must be greater than or equal to 1 and less than or equal to "
|
f"{self.max_num_partial_prefills=}."
|
||||||
f"max_num_partial_prefills ({self.max_num_partial_prefills})."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -2,8 +2,9 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, Self
|
||||||
|
|
||||||
|
from pydantic import model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
@@ -56,7 +57,8 @@ class StructuredOutputsConfig:
|
|||||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def __post_init__(self):
|
@model_validator(mode="after")
|
||||||
|
def _validate_structured_output_config(self) -> Self:
|
||||||
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
|
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"disable_any_whitespace is only supported for "
|
"disable_any_whitespace is only supported for "
|
||||||
@@ -67,3 +69,4 @@ class StructuredOutputsConfig:
|
|||||||
"disable_additional_properties is only supported "
|
"disable_additional_properties is only supported "
|
||||||
"for the guidance backend."
|
"for the guidance backend."
|
||||||
)
|
)
|
||||||
|
return self
|
||||||
|
|||||||
@@ -1807,7 +1807,7 @@ class EngineArgs:
|
|||||||
incremental_prefill_supported = (
|
incremental_prefill_supported = (
|
||||||
pooling_type is not None
|
pooling_type is not None
|
||||||
and pooling_type.lower() == "last"
|
and pooling_type.lower() == "last"
|
||||||
and is_causal
|
and bool(is_causal)
|
||||||
)
|
)
|
||||||
|
|
||||||
action = "Enabling" if incremental_prefill_supported else "Disabling"
|
action = "Enabling" if incremental_prefill_supported else "Disabling"
|
||||||
|
|||||||
@@ -2,11 +2,12 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import regex as re
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
DeltaFunctionCall,
|
DeltaFunctionCall,
|
||||||
|
|||||||
Reference in New Issue
Block a user