[torch.compile] caching of config fields should be opt-out by default (#26468)
Signed-off-by: vnadathur <glvikramn@gmail.com> Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com> Signed-off-by: Srreyansh Sethi <srreyansh.sethi@gmail.com> Signed-off-by: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com> Co-authored-by: WorldExplored <srreyansh.sethi@gmail.com> Co-authored-by: Srreyansh Sethi <107075589+worldexplored@users.noreply.github.com> Co-authored-by: vnadathur <236933696+vnadathur@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -3,14 +3,19 @@
|
||||
"""Utility functions for vLLM config dataclasses."""
|
||||
|
||||
import ast
|
||||
import enum
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import pathlib
|
||||
import textwrap
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Mapping, Sequence, Set
|
||||
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
|
||||
from itertools import pairwise
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
@@ -176,3 +181,115 @@ def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
|
||||
)
|
||||
processed_overrides[field_name] = value
|
||||
return replace(config, **processed_overrides)
|
||||
|
||||
|
||||
def normalize_value(x):
    """Return a stable, JSON-serializable canonical form of *x* for hashing.

    Resolution order: primitives first, then special types (Enum, classes,
    uuid()-bearing objects, torch.dtype, bytes, Path, dataclasses), then
    generic containers (Mapping / Set / Sequence), handled recursively.

    Raises:
        TypeError: for functions/lambdas/callable instances without uuid(),
            and for any other type with no stable representation, to avoid
            under-hashing object state.
    """
    # Fast path: JSON-native primitives pass through unchanged.
    if x is None or isinstance(x, (bool, int, float, str)):
        return x

    # Enums: tag with the class FQN to avoid primitive collisions.
    # Ex: Enum(1) vs int(1) -> ("module.QualName", value).
    if isinstance(x, enum.Enum):
        enum_type = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
        return (enum_type, normalize_value(x.value))

    # Classes (types) are accepted and canonicalized by their fully-qualified
    # name (module.qualname) for a stable identifier.
    # Used by LogitsProcessor types and ModelConfig.hf_overrides.
    if isinstance(x, type):
        module = getattr(x, "__module__", "")
        qual = getattr(x, "__qualname__", getattr(x, "__name__", ""))
        return ".".join([p for p in (module, qual) if p]) or repr(x)

    # Prefer stable uuid identifiers for objects that provide them, even if
    # they are callable instances (e.g., InductorPass wrappers).
    if hasattr(x, "uuid") and callable(getattr(x, "uuid", None)):
        return x.uuid()

    # Remaining callables (functions, lambdas, methods, callable instances
    # without uuid()) have no stable state representation: hard-error.
    if callable(x):
        raise TypeError("normalize_value: function or callable instance unsupported")

    # Torch dtype: stringify (torch.float64 -> "torch.float64").
    # We rely on the string form here; dtype-bearing fields that need
    # additional disambiguation should encode that at the config layer.
    if isinstance(x, torch.dtype):
        return str(x)

    # Bytes: hex-encode for a JSON-safe string.
    if isinstance(x, (bytes, bytearray)):
        return x.hex()

    # Paths: canonicalize; fall back to the literal path when resolution
    # fails (e.g., permission errors or unresolvable symlinks).
    if isinstance(x, pathlib.Path):
        try:
            return str(x.expanduser().resolve())
        except Exception:
            return str(x)

    # Dataclasses: represent as (FQN, sorted (field, value) tuple) for
    # stability across field-declaration order.
    if is_dataclass(x):
        type_fqn = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
        items = tuple(
            (f.name, normalize_value(getattr(x, f.name)))
            for f in sorted(fields(x), key=lambda f: f.name)
        )
        return (type_fqn, items)

    # Containers (generic), recursively normalized.
    if isinstance(x, Mapping):
        # Sort with an explicit total key instead of comparing the raw pairs:
        # comparing pairs falls through to comparing normalized *values*
        # whenever two distinct keys stringify identically (e.g. 1 vs "1"),
        # which raises TypeError for uncomparable value types. repr() gives a
        # deterministic tiebreaker for any value.
        return tuple(
            sorted(
                ((str(k), normalize_value(v)) for k, v in x.items()),
                key=lambda kv: (kv[0], repr(kv[1])),
            )
        )
    if isinstance(x, Set):
        # Sets are unordered: sort the reprs of normalized members to get a
        # deterministic, comparison-free ordering.
        return tuple(sorted(repr(normalize_value(v)) for v in x))
    if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
        return tuple(normalize_value(v) for v in x)

    # PretrainedConfig-style objects serialize via to_json_string().
    if hasattr(x, "to_json_string") and callable(x.to_json_string):
        return x.to_json_string()

    # Unsupported type: e.g., modules, generators, open files, or objects
    # without a stable JSON/UUID representation. Hard-error to avoid
    # under-hashing.
    # If you hit this, either reshape your config to use supported primitives
    # and containers, or extend normalize_value to provide a stable encoding
    # (e.g., via uuid() or to_json_string()) for this type.
    raise TypeError(
        f"normalize_value: unsupported type '{type(x).__name__}'. "
        "Ensure config values use supported primitives/containers or add a "
        "stable representation for this type."
    )
|
||||
|
||||
|
||||
def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, object]:
    """Gets the factors used for hashing a config class.

    - Includes all dataclass fields not in `ignored_factors`.
    - Errors on non-normalizable values.
    """
    # Filter the field names up front; normalization happens per factor below.
    included = (f.name for f in fields(config) if f.name not in ignored_factors)
    factors: dict[str, object] = {}
    for factor in included:
        value = getattr(config, factor, None)
        try:
            normalized = normalize_value(value)
        except TypeError as e:
            # Re-raise with the offending field name so the failure is
            # attributable to a specific config entry.
            raise TypeError(
                f"get_hash_factors: unsupported type for key '{factor}' "
                f"({type(value).__name__})"
            ) from e
        factors[factor] = normalized
    return factors
|
||||
|
||||
|
||||
def hash_factors(items: dict[str, object]) -> str:
    """Return a SHA-256 hex digest of the canonical items structure."""
    # Canonicalize first (sorted keys => key order never affects the digest),
    # then hash the UTF-8 encoding of that canonical JSON text.
    canonical = json.dumps(items, sort_keys=True)
    return hashlib.sha256(canonical.encode()).hexdigest()
|
||||
|
||||
Reference in New Issue
Block a user