Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
468 lines
16 KiB
Python
468 lines
16 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Utility functions for vLLM config dataclasses."""
|
|
|
|
import ast
|
|
import enum
|
|
import hashlib
|
|
import inspect
|
|
import json
|
|
import os
|
|
import pathlib
|
|
import textwrap
|
|
from collections.abc import Callable, Mapping, Sequence, Set
|
|
from dataclasses import MISSING, field, fields, is_dataclass
|
|
from itertools import pairwise
|
|
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast, overload
|
|
|
|
import torch
|
|
from pydantic import ConfigDict
|
|
from pydantic.dataclasses import dataclass
|
|
from pydantic.fields import Field as PydanticField
|
|
from pydantic.fields import FieldInfo
|
|
from typing_extensions import dataclass_transform, runtime_checkable
|
|
|
|
import vllm.envs as envs
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
# `DataclassInstance` only exists in typeshed, so it is importable solely under
# type checkers; at runtime it degrades to `Any` so the aliases below still work.
if TYPE_CHECKING:
    from _typeshed import DataclassInstance
else:
    DataclassInstance = Any

# A config class object itself (e.g. the `ModelConfig` type, not an instance).
ConfigType = type[DataclassInstance]
# A config instance; lets helpers below be generic over the concrete config type.
ConfigT = TypeVar("ConfigT", bound=DataclassInstance)
|
|
|
|
|
|
@overload
@dataclass_transform(field_specifiers=(PydanticField,))
def config(cls: type[ConfigT]) -> type[ConfigT]: ...


@overload
@dataclass_transform(field_specifiers=(PydanticField,))
def config(
    *, config: ConfigDict | None = None, **kwargs: Any
) -> Callable[[type[ConfigT]], type[ConfigT]]: ...


@dataclass_transform(field_specifiers=(PydanticField,))
def config(
    cls: type[ConfigT] | None = None,
    *,
    config: ConfigDict | None = None,
    **kwargs: Any,
) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]:
    """Decorator that turns a class into a pydantic dataclass whose default
    pydantic config forbids extra fields.

    All config classes in vLLM should use this decorator.

    Args:
        cls: The class to decorate.
        config: Optional pydantic `ConfigDict`; merged on top of the default.
        **kwargs: Extra arguments forwarded to `pydantic.dataclass`.
    """
    # Start from the vLLM-wide default and layer caller overrides on top.
    effective_config = ConfigDict(extra="forbid")
    if config is not None:
        effective_config.update(config)

    def wrap(target: type[ConfigT]) -> type[ConfigT]:
        return dataclass(target, config=effective_config, **kwargs)  # type: ignore[return-value]

    # Parameterized usage (`@config(config=...)`) must hand back the decorator;
    # bare usage (`@config`) receives the class directly.
    return wrap if cls is None else wrap(cls)
|
|
|
|
|
|
def get_field(cls: ConfigType, name: str) -> Any:
    """Reconstruct the `dataclasses.field` for attribute `name` of dataclass
    `cls`. Used for getting default factory fields in `EngineArgs`.

    Raises:
        TypeError: If `cls` is not a dataclass.
        ValueError: If `cls` has no field called `name`.
    """
    if not is_dataclass(cls):
        raise TypeError("The given class is not a dataclass.")

    try:
        named_field = next(f for f in fields(cls) if f.name == name)
    except StopIteration as e:
        raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e

    # The pieces carried over into the freshly built field.
    default, default_factory, init = (
        named_field.default,
        named_field.default_factory,
        named_field.init,
    )

    # Pydantic stores its metadata as a FieldInfo sitting in the default slot;
    # unpack it so an equivalent plain `dataclasses.field` can be rebuilt.
    if isinstance(default, FieldInfo):
        if default.init is not None:
            init = default.init
        if default.default_factory is not None:
            default_factory = cast(Callable[[], Any], default.default_factory)
            default = MISSING
        else:
            default = default.default

    # Neither a default nor a factory survived — flag it once for visibility.
    if default is MISSING and default_factory is MISSING:
        logger.warning_once(
            "%s.%s has no default or default factory.", cls.__name__, name
        )
    return field(default=default, default_factory=default_factory, init=init)
|
|
|
|
|
|
def is_init_field(cls: ConfigType, name: str) -> bool:
    """Whether field `name` of dataclass `cls` participates in `__init__`."""
    reconstructed = get_field(cls, name)
    return reconstructed.init
|
|
|
|
|
|
def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT:
    """Like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace),
    but compatible with Pydantic dataclasses which use `pydantic.fields.Field` instead
    of `dataclasses.field`"""
    cls = type(dataclass_instance)
    # Only fields accepted by __init__ may be handed back to the constructor.
    init_kwargs = {
        attr: value
        for attr, value in dataclass_instance.__dict__.items()
        if is_init_field(cls, attr)
    }
    init_kwargs.update(kwargs)
    return cls(**init_kwargs)
|
|
|
|
|
|
def getattr_iter(
|
|
object: object,
|
|
names: Sequence[str],
|
|
default: Any | None = None,
|
|
default_factory: Callable[[], Any] | None = None,
|
|
warn: bool = False,
|
|
) -> Any:
|
|
"""
|
|
A helper function that retrieves an attribute from an object which may
|
|
have multiple possible names. This is useful when fetching attributes from
|
|
arbitrary `transformers.PretrainedConfig` instances.
|
|
|
|
In the case where the first name in `names` is the preferred name, and
|
|
any other names are deprecated aliases, setting `warn=True` will log a
|
|
warning when a deprecated name is used.
|
|
"""
|
|
for i, name in enumerate(names):
|
|
if hasattr(object, name):
|
|
if warn and i > 0:
|
|
logger.warning_once(
|
|
"%s contains a deprecated attribute name '%s'. "
|
|
"Please use the preferred attribute name '%s' instead.",
|
|
type(object).__name__,
|
|
name,
|
|
names[0],
|
|
)
|
|
return getattr(object, name)
|
|
return default_factory() if default_factory is not None else default
|
|
|
|
|
|
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
    """
    Collect docstrings placed directly after attribute assignments in a class
    body, keyed by attribute name.

    Adapted from https://davidism.com/mit-license/

    Raises:
        TypeError: If `cls` does not parse to a class definition.
    """

    cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]

    if not isinstance(cls_node, ast.ClassDef):
        raise TypeError("Given object was not a class.")

    docs: dict[str, str] = {}

    # Walk adjacent statement pairs: an attribute docstring is a bare string
    # constant immediately following an (annotated) assignment.
    body = cls_node.body
    for stmt, trailer in zip(body, body[1:]):
        if not isinstance(stmt, (ast.Assign, ast.AnnAssign)):
            continue
        is_doc = (
            isinstance(trailer, ast.Expr)
            and isinstance(trailer.value, ast.Constant)
            and isinstance(trailer.value.value, str)
        )
        if not is_doc:
            continue

        text = inspect.cleandoc(trailer.value.value)

        # `a = b = v` has several targets; an annotated assignment has one.
        targets = stmt.targets if isinstance(stmt, ast.Assign) else [stmt.target]

        # Only plain-name targets get documented (skip e.g. attribute/subscript).
        docs.update(
            {target.id: text for target in targets if isinstance(target, ast.Name)}
        )

    return docs
|
|
|
|
|
|
@runtime_checkable
class SupportsHash(Protocol):
    """Structural type for objects that can contribute a hash string.

    Marked `@runtime_checkable`, so `isinstance(obj, SupportsHash)` works.
    """

    def compute_hash(self) -> str: ...
|
|
|
|
|
|
class SupportsMetricsInfo(Protocol):
    """Structural type for objects exposing a flat str-to-str metrics mapping.

    NOTE(review): unlike `SupportsHash`, this protocol is not marked
    `@runtime_checkable`, so `isinstance` checks against it will raise at
    runtime — confirm whether that asymmetry is intentional.
    """

    def metrics_info(self) -> dict[str, str]: ...
|
|
|
|
|
|
def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
    """Return a copy of `config` with `overrides` applied.

    Overrides aimed at a nested dataclass field may be supplied as plain
    dicts, in which case they are merged into the nested config recursively
    rather than replacing it wholesale.
    """
    merged: dict[str, Any] = {}
    for key, new_value in overrides.items():
        assert hasattr(config, key), (
            f"{type(config)} has no field `{key}`"
        )
        existing = getattr(config, key)
        if is_dataclass(existing) and not is_dataclass(new_value):
            # Dict-shaped override for a nested config: recurse instead of
            # clobbering the existing nested instance.
            assert isinstance(new_value, dict), (
                f"Overrides to {type(config)}.{key} must be a dict"
                f" or {type(existing)}, but got {type(new_value)}"
            )
            new_value = update_config(
                existing,  # type: ignore[type-var]
                new_value,
            )
        merged[key] = new_value
    return replace(config, **merged)
|
|
|
|
|
|
def normalize_value(x):
    """Return a stable, JSON-serializable canonical form for hashing.

    Checks run in a deliberate order: primitives first, then special types
    (Enum, classes, uuid-bearing objects, torch.dtype, bytes, Path,
    dataclasses), and finally generic containers (Mapping/Set/Sequence),
    which recurse into their elements.
    """
    # Primitives pass through untouched.
    if x is None or isinstance(x, (bool, int, float, str)):
        return x

    # Enums are tagged with their fully-qualified name so that e.g. an Enum
    # member with value 1 can never collide with the plain int 1.
    if isinstance(x, enum.Enum):
        enum_fqn = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
        return (enum_fqn, normalize_value(x.value))

    # Classes (types) canonicalize to their fully-qualified "module.qualname"
    # string. Instances are only accepted when they expose uuid(); anything
    # else is rejected to avoid under-hashing hidden object state.

    # Callables: classes are fine; plain functions/lambdas/methods are not.
    # Used by LogitsProcessor types and ModelConfig.hf_overrides.
    if isinstance(x, type):
        mod = getattr(x, "__module__", "")
        qualname = getattr(x, "__qualname__", getattr(x, "__name__", ""))
        return ".".join([part for part in (mod, qualname) if part]) or repr(x)

    # A stable uuid() identifier wins over the callable rejection below, so
    # callable instances such as InductorPass wrappers are still accepted.
    if hasattr(x, "uuid") and callable(getattr(x, "uuid", None)):
        return x.uuid()

    if callable(x):
        raise TypeError("normalize_value: function or callable instance unsupported")

    # Torch dtypes stringify (torch.float64 -> "torch.float64"). The string
    # form is relied on here; dtype-bearing fields needing further
    # disambiguation should encode that at the config layer.
    if isinstance(x, torch.dtype):
        return str(x)

    # Raw bytes become their hex encoding.
    if isinstance(x, (bytes, bytearray)):
        return x.hex()

    # Paths are canonicalized; fall back to the literal string on failure.
    if isinstance(x, pathlib.Path):
        try:
            return str(x.expanduser().resolve())
        except Exception:
            return str(x)

    # Dataclasses -> (FQN, tuple of (field, value) pairs sorted by name) for
    # a stable representation.
    if is_dataclass(x):
        fqn = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
        field_items = tuple(
            (f.name, normalize_value(getattr(x, f.name)))
            for f in sorted(fields(x), key=lambda f: f.name)
        )
        return (fqn, field_items)

    # Generic containers, recursing into their elements.
    if isinstance(x, Mapping):
        return tuple(sorted((str(k), normalize_value(v)) for k, v in x.items()))
    if isinstance(x, Set):
        return tuple(sorted(repr(normalize_value(v)) for v in x))
    if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
        return tuple(normalize_value(v) for v in x)

    # transformers.PretrainedConfig and friends.
    if hasattr(x, "to_json_string") and callable(x.to_json_string):
        try:
            return x.to_json_string()
        except (TypeError, ValueError):
            # to_json_string() can fail for trust-remote-code configs holding
            # non-JSON-serializable nested objects; fall back to normalizing
            # the dict representation recursively before giving up.
            if hasattr(x, "to_dict") and callable(x.to_dict):
                return normalize_value(x.to_dict())
            raise

    # Unsupported type: e.g., modules, generators, open files, or objects
    # lacking a stable JSON/UUID representation. Hard-error to avoid
    # under-hashing. If this triggers, either reshape the config to use
    # supported primitives and containers, or extend normalize_value with a
    # stable encoding (e.g., via uuid() or to_json_string()) for the type.
    raise TypeError(
        f"normalize_value: unsupported type '{type(x).__name__}'. "
        "Ensure config values use supported primitives/containers or add a "
        "stable representation for this type."
    )
|
|
|
|
|
|
def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, object]:
    """Collect the hashing factors of a config instance.

    Every dataclass field not listed in `ignored_factors` is included after
    normalization via `normalize_value`.

    Raises:
        TypeError: Naming the offending key, for values that cannot be
            normalized.
    """
    factors: dict[str, object] = {}
    for dc_field in fields(config):
        key = dc_field.name
        if key in ignored_factors:
            continue
        raw = getattr(config, key, None)
        try:
            factors[key] = normalize_value(raw)
        except TypeError as e:
            # Re-raise with the field name so the failing factor is obvious.
            raise TypeError(
                f"get_hash_factors: unsupported type for key '{key}' "
                f"({type(raw).__name__})"
            ) from e
    return factors
|
|
|
|
|
|
def hash_factors(items: dict[str, object]) -> str:
    """Return the SHA-256 hex digest of `items` serialized as canonical JSON
    (keys sorted, so the digest is independent of insertion order)."""
    canonical = json.dumps(items, sort_keys=True)
    return hashlib.sha256(canonical.encode()).hexdigest()
|
|
|
|
|
|
@dataclass
class Range:
    """
    An integer range, inclusive at both endpoints.
    """

    start: int
    end: int

    def is_single_size(self) -> bool:
        """True when the range covers exactly one value."""
        return self.start == self.end

    def __contains__(self, size: int) -> bool:
        # Both endpoints count as inside the range.
        return self.start <= size <= self.end

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Range):
            return False
        return (self.start, self.end) == (other.start, other.end)

    def __hash__(self) -> int:
        return hash((self.start, self.end))

    def __str__(self) -> str:
        return f"({self.start}, {self.end})"

    def __repr__(self) -> str:
        return str(self)
|
|
|
|
|
|
def handle_deprecated(
|
|
config: ConfigT,
|
|
old_name: str,
|
|
new_name_or_names: str | list[str],
|
|
removal_version: str,
|
|
) -> None:
|
|
old_val = getattr(config, old_name)
|
|
if old_val is None:
|
|
return
|
|
|
|
if isinstance(new_name_or_names, str):
|
|
new_names = [new_name_or_names]
|
|
else:
|
|
new_names = new_name_or_names
|
|
|
|
msg = (
|
|
f"{old_name} is deprecated and will be removed in {removal_version}. "
|
|
f"Use {', '.join(new_names)} instead."
|
|
)
|
|
logger.warning(msg)
|
|
|
|
for new_name in new_names:
|
|
setattr(config, new_name, old_val)
|
|
|
|
|
|
def get_from_deprecated_env_if_set(
    env_name: str,
    removal_version: str,
    field_name: str | None = None,
) -> str | None:
    """
    Read a deprecated environment variable, warning once when it is set.

    Args:
        env_name: Name of the deprecated environment variable.
        removal_version: Version in which the variable will be removed.
        field_name: Replacement field to suggest in the warning, if any.

    Returns:
        The variable's value when set, otherwise None.
    """
    # Unset variables are silently ignored — no warning, nothing to return.
    if not envs.is_set(env_name):
        return None

    value = os.environ.get(env_name)
    alt_msg = f" Please use {field_name} instead." if field_name else ""
    logger.warning_once(
        "Using %s environment variable is deprecated and will be removed in %s.%s",
        env_name,
        removal_version,
        alt_msg,
    )
    return value
|
|
|
|
|
|
def set_from_deprecated_env_if_set(
|
|
config: ConfigT,
|
|
env_name: str,
|
|
removal_version: str,
|
|
field_name: str,
|
|
to_bool: bool = False,
|
|
to_int: bool = False,
|
|
) -> None:
|
|
"""
|
|
Set object field from deprecated environment variable with warning.
|
|
|
|
Args:
|
|
config: Config object to set the field on
|
|
env_name: Name of the deprecated environment variable
|
|
removal_version: Version when the env var will be removed
|
|
field_name: Name of the field to set
|
|
to_bool: Whether to convert the environment variable value to boolean
|
|
to_int: Whether to convert the environment variable value to integer
|
|
Returns:
|
|
None
|
|
"""
|
|
if to_bool and to_int:
|
|
raise ValueError("Cannot convert to both boolean and integer.")
|
|
|
|
env_value = get_from_deprecated_env_if_set(env_name, removal_version, field_name)
|
|
if env_value is not None:
|
|
field_value: str | bool | int = env_value
|
|
if to_bool:
|
|
field_value = env_value.lower() in ("1", "true")
|
|
elif to_int:
|
|
field_value = int(env_value)
|
|
setattr(config, field_name, field_value)
|