Turn @config into a dataclass_transform (#31541)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -3,11 +3,11 @@
|
||||
|
||||
import json
|
||||
from argparse import ArgumentError
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.config import AttentionConfig, CompilationConfig, config
|
||||
from vllm.engine.arg_utils import (
|
||||
@@ -96,7 +96,7 @@ def test_get_type(type_hints, type, expected):
|
||||
],
|
||||
)
|
||||
def test_literal_to_kwargs(type_hints, expected):
|
||||
context = nullcontext()
|
||||
context: AbstractContextManager[object] = nullcontext()
|
||||
if expected is Exception:
|
||||
context = pytest.raises(expected)
|
||||
with context:
|
||||
@@ -104,14 +104,12 @@ def test_literal_to_kwargs(type_hints, expected):
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class NestedConfig:
|
||||
field: int = 1
|
||||
"""field"""
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class DummyConfig:
|
||||
regular_bool: bool = True
|
||||
"""Regular bool with default True"""
|
||||
@@ -119,23 +117,23 @@ class DummyConfig:
|
||||
"""Optional bool with default None"""
|
||||
optional_literal: Literal["x", "y"] | None = None
|
||||
"""Optional literal with default None"""
|
||||
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
|
||||
tuple_n: tuple[int, ...] = Field(default_factory=lambda: (1, 2, 3))
|
||||
"""Tuple with variable length"""
|
||||
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
|
||||
tuple_2: tuple[int, int] = Field(default_factory=lambda: (1, 2))
|
||||
"""Tuple with fixed length"""
|
||||
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
|
||||
list_n: list[int] = Field(default_factory=lambda: [1, 2, 3])
|
||||
"""List with variable length"""
|
||||
list_literal: list[Literal[1, 2]] = field(default_factory=list)
|
||||
list_literal: list[Literal[1, 2]] = Field(default_factory=list)
|
||||
"""List with literal choices"""
|
||||
list_union: list[str | type[object]] = field(default_factory=list)
|
||||
list_union: list[str | type[object]] = Field(default_factory=list)
|
||||
"""List with union type"""
|
||||
set_n: set[int] = field(default_factory=lambda: {1, 2, 3})
|
||||
set_n: set[int] = Field(default_factory=lambda: {1, 2, 3})
|
||||
"""Set with variable length"""
|
||||
literal_literal: Literal[Literal[1], Literal[2]] = 1
|
||||
"""Literal of literals with default 1"""
|
||||
json_tip: dict = field(default_factory=dict)
|
||||
json_tip: dict = Field(default_factory=dict)
|
||||
"""Dict which will be JSON in CLI"""
|
||||
nested_config: NestedConfig = field(default_factory=NestedConfig)
|
||||
nested_config: NestedConfig = Field(default_factory=NestedConfig)
|
||||
"""Nested config"""
|
||||
|
||||
|
||||
@@ -195,7 +193,7 @@ def test_get_kwargs():
|
||||
json_tip = "Should either be a valid JSON string or JSON keys"
|
||||
assert json_tip in kwargs["json_tip"]["help"]
|
||||
# nested config should construct the nested config
|
||||
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
|
||||
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) # type: ignore[call-arg]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -66,9 +66,6 @@ class _TestConfigFields:
|
||||
|
||||
|
||||
def test_get_field():
|
||||
with pytest.raises(ValueError):
|
||||
get_field(_TestConfigFields, "a")
|
||||
|
||||
b = get_field(_TestConfigFields, "b")
|
||||
assert isinstance(b, Field)
|
||||
assert b.default is MISSING
|
||||
@@ -188,7 +185,7 @@ def test_get_pooling_config():
|
||||
)
|
||||
def test_get_pooling_config_from_args():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
|
||||
pooler_config = PoolerConfig(seq_pooling_type="CLS", use_activation=False)
|
||||
model_config = ModelConfig(model_id, pooler_config=pooler_config)
|
||||
|
||||
assert asdict(model_config.pooler_config) == asdict(pooler_config)
|
||||
|
||||
@@ -7,31 +7,22 @@ import pytest
|
||||
|
||||
from tools.pre_commit.validate_config import validate_ast
|
||||
|
||||
_TestConfig1 = """
|
||||
_TestConfig1 = '''
|
||||
@config
|
||||
class _TestConfig1:
|
||||
pass
|
||||
"""
|
||||
|
||||
_TestConfig2 = '''
|
||||
@config
|
||||
@dataclass
|
||||
class _TestConfig2:
|
||||
a: int
|
||||
"""docstring"""
|
||||
'''
|
||||
|
||||
_TestConfig3 = """
|
||||
_TestConfig2 = """
|
||||
@config
|
||||
@dataclass
|
||||
class _TestConfig3:
|
||||
class _TestConfig2:
|
||||
a: int = 1
|
||||
"""
|
||||
|
||||
_TestConfig4 = '''
|
||||
_TestConfig3 = '''
|
||||
@config
|
||||
@dataclass
|
||||
class _TestConfig4:
|
||||
class _TestConfig3:
|
||||
a: Union[Literal[1], Literal[2]] = 1
|
||||
"""docstring"""
|
||||
'''
|
||||
@@ -40,10 +31,9 @@ class _TestConfig4:
|
||||
@pytest.mark.parametrize(
|
||||
("test_config", "expected_error"),
|
||||
[
|
||||
(_TestConfig1, "must be a dataclass"),
|
||||
(_TestConfig2, "must have a default"),
|
||||
(_TestConfig3, "must have a docstring"),
|
||||
(_TestConfig4, "must use a single Literal"),
|
||||
(_TestConfig1, "must have a default"),
|
||||
(_TestConfig2, "must have a docstring"),
|
||||
(_TestConfig3, "must use a single Literal"),
|
||||
],
|
||||
)
|
||||
def test_config(test_config, expected_error):
|
||||
|
||||
@@ -766,8 +766,8 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
|
||||
"max_model_len": args.max_model_len,
|
||||
"enforce_eager": enforce_eager,
|
||||
"draft_tensor_parallel_size": args.draft_tensor_parallel_size,
|
||||
"max_num_seqs": 100, # limit cudagraph capture runtime
|
||||
},
|
||||
max_num_seqs=100, # limit cudagraph capture runtime
|
||||
max_model_len=args.max_model_len,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
tensor_parallel_size=args.target_tensor_parallel_size,
|
||||
|
||||
@@ -26,11 +26,8 @@ def test_backend_guidance_rollback_terminated():
|
||||
# guidance backend. In that case we are in a stopped state, but
|
||||
# it should be reverted in case EOS is not accepted by the target
|
||||
# model.
|
||||
vllm_config = VllmConfig(
|
||||
decoding_config=StructuredOutputsConfig(
|
||||
backend="guidance",
|
||||
)
|
||||
)
|
||||
structured_outputs_config = StructuredOutputsConfig(backend="guidance")
|
||||
vllm_config = VllmConfig(structured_outputs_config=structured_outputs_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
|
||||
|
||||
backend = GuidanceBackend(
|
||||
|
||||
@@ -54,24 +54,18 @@ class ConfigValidator(ast.NodeVisitor):
|
||||
def __init__(self): ...
|
||||
|
||||
def visit_ClassDef(self, node):
    """Validate a class that uses the ``config`` decorator, then keep walking.

    Args:
        node: The ``ast.ClassDef`` currently being visited.

    A class decorated only with ``config`` (bare or called, e.g.
    ``@config(config=...)``) is passed to ``validate_class``. A class that
    combines ``config`` with any other plain-name decorator is reported via
    ``fail``. Classes without ``config`` are ignored. Child nodes are always
    visited afterwards.
    """
    # Collect the name of every plain-name decorator. Calls are unwrapped so
    # that ``@config`` and ``@config(...)`` are treated the same way.
    # All names must be recorded (not just "config"); otherwise the
    # "should be used alone" branch below could never be reached.
    decorators = set()
    for decorator in node.decorator_list:
        if isinstance(decorator, ast.Call):
            decorator = decorator.func
        if isinstance(decorator, ast.Name):
            decorators.add(decorator.id)

    if decorators == {"config"}:
        validate_class(node)
    elif "config" in decorators:
        fail(f"config decorator for {node.name} should be used alone", node)

    self.generic_visit(node)
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ from vllm.config.utils import (
|
||||
config,
|
||||
get_attr_docs,
|
||||
is_init_field,
|
||||
replace,
|
||||
update_config,
|
||||
)
|
||||
from vllm.config.vllm import (
|
||||
@@ -101,6 +102,7 @@ __all__ = [
|
||||
"config",
|
||||
"get_attr_docs",
|
||||
"is_init_field",
|
||||
"replace",
|
||||
"update_config",
|
||||
# From vllm.config.vllm
|
||||
"VllmConfig",
|
||||
|
||||
@@ -4,14 +4,12 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class AttentionConfig:
|
||||
"""Configuration for attention mechanisms in vLLM."""
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from dataclasses import field
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import Field, SkipValidation, field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
@@ -37,7 +36,6 @@ KVOffloadingBackend = Literal["native", "lmcache"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class CacheConfig:
|
||||
"""Configuration for the KV cache."""
|
||||
|
||||
|
||||
@@ -8,8 +8,7 @@ from dataclasses import field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
||||
|
||||
from pydantic import ConfigDict, Field, TypeAdapter, field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic import Field, TypeAdapter, field_validator
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
@@ -96,7 +95,6 @@ class CUDAGraphMode(enum.Enum):
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
|
||||
class PassConfig:
|
||||
"""Configuration for custom Inductor passes.
|
||||
|
||||
@@ -267,7 +265,6 @@ class DynamicShapesType(str, enum.Enum):
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
|
||||
class DynamicShapesConfig:
|
||||
"""Configuration to control/debug torch compile dynamic shapes."""
|
||||
|
||||
@@ -311,7 +308,6 @@ class DynamicShapesConfig:
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
|
||||
class CompilationConfig:
|
||||
"""Configuration for compilation.
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, SkipValidation
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
@@ -14,8 +13,7 @@ from vllm.utils.hashing import safe_hash
|
||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
@config(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class DeviceConfig:
|
||||
"""Configuration for the device to use for vLLM execution."""
|
||||
|
||||
|
||||
@@ -5,8 +5,6 @@ import uuid
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
ECProducer = Literal["ec_producer"]
|
||||
@@ -15,7 +13,6 @@ ECRole = Literal[ECProducer, ECConsumer]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class ECTransferConfig:
|
||||
"""Configuration for distributed EC cache transfer."""
|
||||
|
||||
|
||||
@@ -5,13 +5,11 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class KVEventsConfig:
|
||||
"""Configuration for KV event publishing."""
|
||||
|
||||
|
||||
@@ -5,8 +5,6 @@ import uuid
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
@@ -16,7 +14,6 @@ KVRole = Literal[KVProducer, KVConsumer]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class KVTransferConfig:
|
||||
"""Configuration for distributed KV cache transfer."""
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
@@ -21,7 +20,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class LoadConfig:
|
||||
"""Configuration for loading the model weights."""
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
@@ -26,8 +25,7 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
|
||||
LoRAExtraVocabSize = Literal[256, 512]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
@config(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class LoRAConfig:
|
||||
"""Configuration for LoRA."""
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.model_arch import (
|
||||
@@ -97,8 +96,7 @@ AttnTypeStr = Literal[
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
@config(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class ModelConfig:
|
||||
"""Configuration for the model."""
|
||||
|
||||
|
||||
@@ -51,7 +51,6 @@ DummyOptions: TypeAlias = (
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class MultiModalConfig:
|
||||
"""Controls the behavior of multimodal models."""
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any, Literal, cast
|
||||
|
||||
from packaging.version import parse
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm import version
|
||||
from vllm.config.utils import config
|
||||
@@ -16,7 +15,6 @@ DetailedTraceModules = Literal["model", "worker", "all"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class ObservabilityConfig:
|
||||
"""Configuration for observability - metrics and tracing."""
|
||||
|
||||
|
||||
@@ -3,12 +3,10 @@
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -50,7 +48,6 @@ All2AllBackend = Literal[
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class EPLBConfig:
|
||||
"""Configuration for Expert Parallel Load Balancing (EP)."""
|
||||
|
||||
@@ -94,7 +91,6 @@ class EPLBConfig:
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class ParallelConfig:
|
||||
"""Configuration for the distributed execution."""
|
||||
|
||||
@@ -715,6 +711,3 @@ class ParallelConfig:
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def replace(self, **kwargs) -> Self:
|
||||
return replace(self, **kwargs)
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
@@ -19,7 +17,6 @@ TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class PoolerConfig:
|
||||
"""Controls the behavior of output pooling in pooling models."""
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import os
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
@@ -32,7 +31,6 @@ def _is_uri_path(path: str) -> bool:
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class ProfilerConfig:
|
||||
"""Dataclass which contains profiler config for the engine."""
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from dataclasses import InitVar
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
@@ -24,7 +23,6 @@ SchedulerPolicy = Literal["fcfs", "priority"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class SchedulerConfig:
|
||||
"""Scheduler configuration."""
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import ast
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.model import ModelConfig
|
||||
@@ -55,7 +54,6 @@ SpeculativeMethod = Literal[
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class SpeculativeConfig:
|
||||
"""Configuration for speculative decoding."""
|
||||
|
||||
|
||||
@@ -2,13 +2,10 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class SpeechToTextConfig:
|
||||
"""Configuration for speech-to-text models."""
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
@@ -16,7 +15,6 @@ StructuredOutputsBackend = Literal[
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class StructuredOutputsConfig:
|
||||
"""Dataclass which contains structured outputs config for the engine."""
|
||||
|
||||
|
||||
@@ -10,14 +10,17 @@ import json
|
||||
import pathlib
|
||||
import textwrap
|
||||
from collections.abc import Callable, Mapping, Sequence, Set
|
||||
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
|
||||
from dataclasses import MISSING, Field, field, fields, is_dataclass
|
||||
from itertools import pairwise
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.fields import Field as PydanticField
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import runtime_checkable
|
||||
from typing_extensions import dataclass_transform, runtime_checkable
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -29,23 +32,39 @@ else:
|
||||
DataclassInstance = Any
|
||||
|
||||
ConfigType = type[DataclassInstance]
|
||||
ConfigT = TypeVar("ConfigT", bound=ConfigType)
|
||||
ConfigT = TypeVar("ConfigT", bound=DataclassInstance)
|
||||
|
||||
|
||||
@dataclass_transform(field_specifiers=(PydanticField,))
def config(
    cls: type[ConfigT] | None = None,
    *,
    config: ConfigDict | None = None,
    **kwargs: Any,
) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]:
    """Decorator to create a pydantic dataclass with default config. The default config
    for the dataclass forbids extra fields.

    All config classes in vLLM should use this decorator.

    Args:
        cls: The class to decorate. ``None`` when the decorator is written with
            arguments, e.g. ``@config(config=...)``.
        config: The pydantic ConfigDict to use. If provided, it will be merged with
            the default config (its keys override the defaults).
        **kwargs: Additional arguments to pass to pydantic.dataclass.

    Returns:
        The decorated class when used as a bare ``@config``; otherwise a
        decorator that will produce it."""
    # Start from the default (extra fields forbidden), then apply overrides.
    effective_config = ConfigDict(extra="forbid")
    if config is not None:
        effective_config.update(config)

    def wrap(target):
        return dataclass(target, config=effective_config, **kwargs)

    # ``@config(...)`` has no class yet, so hand back the real decorator;
    # bare ``@config`` receives the class directly and decorates it now.
    return wrap if cls is None else wrap(cls)
|
||||
|
||||
|
||||
def get_field(cls: ConfigType, name: str) -> Field:
|
||||
@@ -53,24 +72,46 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
default factory fields in `EngineArgs`."""
|
||||
if not is_dataclass(cls):
|
||||
raise TypeError("The given class is not a dataclass.")
|
||||
cls_fields = {f.name: f for f in fields(cls)}
|
||||
if name not in cls_fields:
|
||||
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
|
||||
named_field: Field = cls_fields[name]
|
||||
if (default_factory := named_field.default_factory) is not MISSING:
|
||||
return field(default_factory=default_factory)
|
||||
if (default := named_field.default) is not MISSING:
|
||||
if isinstance(default, FieldInfo):
|
||||
# Handle pydantic.Field defaults
|
||||
if default.default_factory is not None:
|
||||
return field(default_factory=default.default_factory)
|
||||
else:
|
||||
default = default.default
|
||||
return field(default=default)
|
||||
try:
|
||||
named_field = next(f for f in fields(cls) if f.name == name)
|
||||
except StopIteration as e:
|
||||
raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e
|
||||
|
||||
raise ValueError(
|
||||
f"{cls.__name__}.{name} must have a default value or default factory."
|
||||
)
|
||||
# The arguments to copy to the new field
|
||||
default = named_field.default
|
||||
default_factory = named_field.default_factory
|
||||
init = named_field.init
|
||||
|
||||
# Handle pydantic.Field
|
||||
if isinstance(default, FieldInfo):
|
||||
if default.init is not None:
|
||||
init = default.init
|
||||
if default.default_factory is not None:
|
||||
default_factory = cast(Callable[[], Any], default.default_factory)
|
||||
default = MISSING
|
||||
else:
|
||||
default = default.default
|
||||
|
||||
if default is MISSING and default_factory is MISSING:
|
||||
logger.warning_once(
|
||||
"%s.%s has no default or default factory.", cls.__name__, name
|
||||
)
|
||||
return field(default=default, default_factory=default_factory, init=init)
|
||||
|
||||
|
||||
def is_init_field(cls: ConfigType, name: str) -> bool:
    """Return the ``init`` flag of the field called ``name`` on ``cls``.

    The lookup is delegated to ``get_field``, so whatever errors it raises
    for an unknown name or a non-dataclass ``cls`` propagate unchanged.
    """
    resolved_field = get_field(cls, name)
    return resolved_field.init
|
||||
|
||||
|
||||
def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT:
    """Like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace),
    but compatible with Pydantic dataclasses which use `pydantic.fields.Field` instead
    of `dataclasses.field`

    Args:
        dataclass_instance: The instance to copy (positional-only).
        **kwargs: Field values that override those of the instance.

    Returns:
        A new instance of the same class, built by calling its constructor with
        the current values of the init fields updated by `kwargs`."""
    cls = type(dataclass_instance)
    # Only declared fields can be resolved by `is_init_field`; it raises for any
    # other name. Skip instance attributes that are not dataclass fields rather
    # than letting one of them abort the copy.
    declared_names = {f.name for f in fields(cls)}
    init_kwargs = {
        key: value
        for key, value in dataclass_instance.__dict__.items()
        if key in declared_names and is_init_field(cls, key)
    }
    init_kwargs.update(kwargs)
    return cls(**init_kwargs)
|
||||
|
||||
|
||||
def getattr_iter(
|
||||
@@ -172,10 +213,6 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
|
||||
return out
|
||||
|
||||
|
||||
def is_init_field(cls: ConfigType, name: str) -> bool:
|
||||
return next(f for f in fields(cls) if f.name == name).init
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsHash(Protocol):
|
||||
def compute_hash(self) -> str: ...
|
||||
|
||||
@@ -9,7 +9,7 @@ import tempfile
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import is_dataclass, replace
|
||||
from dataclasses import is_dataclass
|
||||
from datetime import datetime
|
||||
from enum import IntEnum
|
||||
from functools import lru_cache
|
||||
@@ -18,10 +18,8 @@ from typing import TYPE_CHECKING, Any, TypeVar, get_args
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.speculative import EagleModelTypes
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
|
||||
from vllm.utils import random_uuid
|
||||
@@ -41,9 +39,9 @@ from .observability import ObservabilityConfig
|
||||
from .parallel import ParallelConfig
|
||||
from .profiler import ProfilerConfig
|
||||
from .scheduler import SchedulerConfig
|
||||
from .speculative import SpeculativeConfig
|
||||
from .speculative import EagleModelTypes, SpeculativeConfig
|
||||
from .structured_outputs import StructuredOutputsConfig
|
||||
from .utils import SupportsHash, config
|
||||
from .utils import SupportsHash, config, replace
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
@@ -187,8 +185,7 @@ OPTIMIZATION_LEVEL_TO_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
@config
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
@config(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class VllmConfig:
|
||||
"""Dataclass which contains all vllm-related configuration. This
|
||||
simplifies passing around the distinct configurations in the codebase.
|
||||
@@ -1395,14 +1392,6 @@ class VllmConfig:
|
||||
path = self.compilation_config.debug_dump_path / append_path
|
||||
return path
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""
|
||||
Replace attributes of the config, and 'recompute' the config.
|
||||
dataclass.replace() calls __init__() and __post_init__(), source:
|
||||
https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
|
||||
"""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"model={self.model_config.model!r}, "
|
||||
|
||||
@@ -13,8 +13,6 @@ from collections.abc import Sequence
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import config
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
@@ -69,7 +67,6 @@ class LoRAParserAction(argparse.Action):
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class FrontendArgs:
|
||||
"""Arguments for the OpenAI-compatible frontend server."""
|
||||
|
||||
|
||||
@@ -13,9 +13,10 @@ def register_speculator(name):
|
||||
|
||||
|
||||
@register_speculator("eagle3")
|
||||
def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
|
||||
def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None:
|
||||
"""
|
||||
Apply Eagle-3 specific configuration transformations.
|
||||
Apply Eagle-3 specific configuration transformations to the `dict` used to
|
||||
construct the Transformers PreTrainedConfig.
|
||||
|
||||
Eagle-3 specific fields:
|
||||
- draft_vocab_size: Size of the draft model's vocabulary
|
||||
@@ -27,12 +28,14 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
|
||||
predictions. This is the standard field used in Eagle3 checkpoints.
|
||||
"""
|
||||
|
||||
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
|
||||
pre_trained_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
|
||||
if config_dict.get("target_hidden_size") is not None:
|
||||
vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
|
||||
vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
|
||||
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
|
||||
pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"]
|
||||
pre_trained_config["norm_before_residual"] = config_dict.get(
|
||||
"norm_before_residual", True
|
||||
)
|
||||
pre_trained_config["architectures"] = ["Eagle3LlamaForCausalLM"]
|
||||
if config_dict.get("eagle_aux_hidden_state_layer_ids"):
|
||||
vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
|
||||
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
|
||||
"eagle_aux_hidden_state_layer_ids"
|
||||
]
|
||||
|
||||
@@ -24,13 +24,16 @@ class SpeculatorsConfig(PretrainedConfig):
|
||||
"""Load speculators Eagle config and convert to vLLM format."""
|
||||
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
vllm_config = cls.extract_vllm_speculative_config(config_dict)
|
||||
vllm_config = cls.extract_transformers_pre_trained_config(config_dict)
|
||||
return cls(**vllm_config)
|
||||
|
||||
@classmethod
|
||||
def extract_vllm_speculative_config(
|
||||
def extract_transformers_pre_trained_config(
|
||||
cls, config_dict: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract standard Transformers PreTrainedConfig config from speculators config.
|
||||
"""
|
||||
speculators_model_type = config_dict.get("speculators_model_type")
|
||||
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
|
||||
raise ValueError(
|
||||
@@ -38,15 +41,23 @@ class SpeculatorsConfig(PretrainedConfig):
|
||||
"Please ensure you're loading a speculators-format model."
|
||||
)
|
||||
|
||||
# Start with transformer layer configuration if present
|
||||
pre_trained_config = config_dict.get("transformer_layer_config", {})
|
||||
# Apply anything specific to the supported algorithm
|
||||
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
|
||||
algo_updater(config_dict=config_dict, pre_trained_config=pre_trained_config)
|
||||
return pre_trained_config
|
||||
|
||||
@classmethod
|
||||
def extract_vllm_speculative_config(
|
||||
cls, config_dict: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Extract vLLM speculative config from speculators config."""
|
||||
# validate fields
|
||||
# TODO: @dsikka - use speculators pydantic model to validate
|
||||
cls.validate_speculators_config(config_dict=config_dict)
|
||||
# Convert from speculators config -> format that can be ingested by vLLM
|
||||
vllm_config = cls.build_vllm_speculative_config(config_dict=config_dict)
|
||||
# Apply anything specific to the supported algorithm
|
||||
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
|
||||
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
|
||||
return vllm_config
|
||||
return cls.build_vllm_speculative_config(config_dict=config_dict)
|
||||
|
||||
@classmethod
|
||||
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
|
||||
@@ -101,14 +112,7 @@ class SpeculatorsConfig(PretrainedConfig):
|
||||
)
|
||||
|
||||
# Build base vLLM speculative configuration
|
||||
vllm_config = {
|
||||
return {
|
||||
"method": config_dict.get("speculators_model_type"),
|
||||
"num_speculative_tokens": num_speculative_tokens,
|
||||
"target_model": spec_config.get("verifier")["name_or_path"],
|
||||
}
|
||||
|
||||
# Merge transformer layer configuration if present
|
||||
transformer_config = config_dict.get("transformer_layer_config", {})
|
||||
vllm_config.update(transformer_config)
|
||||
|
||||
return vllm_config
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config, replace
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
@@ -191,10 +191,12 @@ def create_vllm_config_for_draft_model(
|
||||
old = target_model_vllm_config
|
||||
assert old.speculative_config is not None, "speculative_config is not set"
|
||||
old_spec_config = old.speculative_config
|
||||
new_parallel_config = old_spec_config.draft_parallel_config.replace(
|
||||
rank=old.parallel_config.rank
|
||||
new_parallel_config = replace(
|
||||
old_spec_config.draft_parallel_config,
|
||||
rank=old.parallel_config.rank,
|
||||
)
|
||||
new: VllmConfig = old.replace(
|
||||
new: VllmConfig = replace(
|
||||
old,
|
||||
quant_config=None, # quant_config is recomputed in __init__()
|
||||
model_config=old_spec_config.draft_model_config,
|
||||
parallel_config=new_parallel_config,
|
||||
|
||||
Reference in New Issue
Block a user