Turn @config into a dataclass_transform (#31541)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2026-02-03 17:40:59 +00:00
Committed by: GitHub
Parent: b1bb18de8d
Commit: 61e632aea1
32 changed files with 153 additions and 191 deletions
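The core of the change: `@config` used to be a validation marker stacked on top of a separate dataclass decorator; it now applies `pydantic.dataclasses.dataclass` itself (with `extra="forbid"` by default) and is marked with `typing_extensions.dataclass_transform` so type checkers still see the generated `__init__`. A minimal before/after sketch with a hypothetical config class and field:

```python
# Before: two decorators were required on every config class.
from pydantic.dataclasses import dataclass

@config
@dataclass
class MyConfig:
    block_size: int = 16
    """Example field (hypothetical)."""

# After: @config applies the pydantic dataclass itself, forbidding extra
# fields by default; @dataclass_transform tells type checkers about the
# generated __init__.
@config
class MyConfig:
    block_size: int = 16
    """Example field (hypothetical)."""
```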

View File

@@ -3,11 +3,11 @@
 import json
 from argparse import ArgumentError
-from contextlib import nullcontext
-from dataclasses import dataclass, field
+from contextlib import AbstractContextManager, nullcontext
 from typing import Annotated, Literal

 import pytest
+from pydantic import Field

 from vllm.config import AttentionConfig, CompilationConfig, config
 from vllm.engine.arg_utils import (
@@ -96,7 +96,7 @@ def test_get_type(type_hints, type, expected):
     ],
 )
 def test_literal_to_kwargs(type_hints, expected):
-    context = nullcontext()
+    context: AbstractContextManager[object] = nullcontext()
     if expected is Exception:
         context = pytest.raises(expected)
     with context:
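The `AbstractContextManager[object]` annotation is the usual way to type a variable that may hold either `nullcontext()` or `pytest.raises(...)`, since both are context managers. A self-contained sketch of the pattern (the test itself is hypothetical):

```python
from contextlib import AbstractContextManager, nullcontext

import pytest

@pytest.mark.parametrize(("x", "expected"), [(2, 4), ("a", Exception)])
def test_square(x, expected):
    # The shared annotation lets the variable be rebound from nullcontext()
    # to pytest.raises(...) without a type error.
    context: AbstractContextManager[object] = nullcontext()
    if expected is Exception:
        context = pytest.raises(Exception)
    with context:
        assert x**2 == expected
```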
@@ -104,14 +104,12 @@ def test_literal_to_kwargs(type_hints, expected):
 @config
-@dataclass
 class NestedConfig:
     field: int = 1
     """field"""


 @config
-@dataclass
 class DummyConfig:
     regular_bool: bool = True
     """Regular bool with default True"""
@@ -119,23 +117,23 @@ class DummyConfig:
     """Optional bool with default None"""
     optional_literal: Literal["x", "y"] | None = None
     """Optional literal with default None"""
-    tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
+    tuple_n: tuple[int, ...] = Field(default_factory=lambda: (1, 2, 3))
     """Tuple with variable length"""
-    tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
+    tuple_2: tuple[int, int] = Field(default_factory=lambda: (1, 2))
     """Tuple with fixed length"""
-    list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
+    list_n: list[int] = Field(default_factory=lambda: [1, 2, 3])
     """List with variable length"""
-    list_literal: list[Literal[1, 2]] = field(default_factory=list)
+    list_literal: list[Literal[1, 2]] = Field(default_factory=list)
     """List with literal choices"""
-    list_union: list[str | type[object]] = field(default_factory=list)
+    list_union: list[str | type[object]] = Field(default_factory=list)
     """List with union type"""
-    set_n: set[int] = field(default_factory=lambda: {1, 2, 3})
+    set_n: set[int] = Field(default_factory=lambda: {1, 2, 3})
     """Set with variable length"""
     literal_literal: Literal[Literal[1], Literal[2]] = 1
     """Literal of literals with default 1"""
-    json_tip: dict = field(default_factory=dict)
+    json_tip: dict = Field(default_factory=dict)
     """Dict which will be JSON in CLI"""
-    nested_config: NestedConfig = field(default_factory=NestedConfig)
+    nested_config: NestedConfig = Field(default_factory=NestedConfig)
     """Nested config"""
@@ -195,7 +193,7 @@ def test_get_kwargs():
     json_tip = "Should either be a valid JSON string or JSON keys"
     assert json_tip in kwargs["json_tip"]["help"]
     # nested config should construct the nested config
-    assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
+    assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)  # type: ignore[call-arg]


 @pytest.mark.parametrize(

View File

@@ -66,9 +66,6 @@ class _TestConfigFields:
 def test_get_field():
     with pytest.raises(ValueError):
         get_field(_TestConfigFields, "a")
-    b = get_field(_TestConfigFields, "b")
-    assert isinstance(b, Field)
-    assert b.default is MISSING
@@ -188,7 +185,7 @@ def test_get_pooling_config():
 )
 def test_get_pooling_config_from_args():
     model_id = "sentence-transformers/all-MiniLM-L12-v2"
-    pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
+    pooler_config = PoolerConfig(seq_pooling_type="CLS", use_activation=False)
     model_config = ModelConfig(model_id, pooler_config=pooler_config)
     assert asdict(model_config.pooler_config) == asdict(pooler_config)

View File

@@ -7,31 +7,22 @@ import pytest
 from tools.pre_commit.validate_config import validate_ast

-_TestConfig1 = """
+_TestConfig1 = '''
 @config
 class _TestConfig1:
-    pass
-"""
-
-_TestConfig2 = '''
-@config
-@dataclass
-class _TestConfig2:
     a: int
     """docstring"""
 '''

-_TestConfig3 = """
+_TestConfig2 = """
 @config
-@dataclass
-class _TestConfig3:
+class _TestConfig2:
     a: int = 1
 """

-_TestConfig4 = '''
+_TestConfig3 = '''
 @config
-@dataclass
-class _TestConfig4:
+class _TestConfig3:
     a: Union[Literal[1], Literal[2]] = 1
     """docstring"""
 '''
@@ -40,10 +31,9 @@ class _TestConfig4:
 @pytest.mark.parametrize(
     ("test_config", "expected_error"),
     [
-        (_TestConfig1, "must be a dataclass"),
-        (_TestConfig2, "must have a default"),
-        (_TestConfig3, "must have a docstring"),
-        (_TestConfig4, "must use a single Literal"),
+        (_TestConfig1, "must have a default"),
+        (_TestConfig2, "must have a docstring"),
+        (_TestConfig3, "must use a single Literal"),
     ],
 )
 def test_config(test_config, expected_error):
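The "must be a dataclass" case disappears because `@config` now makes the class a dataclass itself; the remaining three cases encode the rules the validator still enforces. A hypothetical class satisfying all of them:

```python
from typing import Literal

from vllm.config import config

@config
class GoodConfig:
    a: int = 1
    """Every field needs a default (or default factory) and a docstring."""
    b: Literal[1, 2] = 1
    """Constrained fields use a single Literal, not a Union of Literals."""
```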

View File

@@ -766,8 +766,8 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
         "max_model_len": args.max_model_len,
         "enforce_eager": enforce_eager,
         "draft_tensor_parallel_size": args.draft_tensor_parallel_size,
-        "max_num_seqs": 100,  # limit cudagraph capture runtime
     },
+    max_num_seqs=100,  # limit cudagraph capture runtime
     max_model_len=args.max_model_len,
     gpu_memory_utilization=args.gpu_memory_utilization,
     tensor_parallel_size=args.target_tensor_parallel_size,

View File

@@ -26,11 +26,8 @@ def test_backend_guidance_rollback_terminated():
     # guidance backend. In that case we are in a stopped state, but
     # it should be reverted in case EOS is not accepted by the target
     # model.
-    vllm_config = VllmConfig(
-        decoding_config=StructuredOutputsConfig(
-            backend="guidance",
-        )
-    )
+    structured_outputs_config = StructuredOutputsConfig(backend="guidance")
+    vllm_config = VllmConfig(structured_outputs_config=structured_outputs_config)
     tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
     backend = GuidanceBackend(

View File

@@ -54,24 +54,18 @@ class ConfigValidator(ast.NodeVisitor):
     def __init__(self): ...

     def visit_ClassDef(self, node):
-        # Validate class with both @config and @dataclass decorators
-        decorators = [
-            id
-            for d in node.decorator_list
-            if (
-                isinstance(d, ast.Name)
-                and ((id := d.id) == "config" or id == "dataclass")
-            )
-            or (
-                isinstance(d, ast.Call)
-                and (isinstance(d.func, ast.Name) and (id := d.func.id) == "dataclass")
-            )
-        ]
+        # Validate classes with a @config decorator
+        decorators = set()
+        for decorator in node.decorator_list:
+            if isinstance(decorator, ast.Call):
+                decorator = decorator.func
+            if isinstance(decorator, ast.Name) and decorator.id == "config":
+                decorators.add(decorator.id)

-        if set(decorators) == {"config", "dataclass"}:
+        if decorators == {"config"}:
             validate_class(node)
-        elif set(decorators) == {"config"}:
-            fail(f"Class {node.name} with config decorator must be a dataclass.", node)
+        elif "config" in decorators:
+            fail(f"config decorator for {node.name} should be used alone", node)

         self.generic_visit(node)
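The `ast.Call` unwrapping is what lets the validator treat a parameterized `@config(...)` the same as a bare `@config`. A small standalone check of that mechanism (the snippet being parsed is hypothetical):

```python
import ast

tree = ast.parse("@config(config=ConfigDict())\nclass C:\n    pass")
node = tree.body[0]
assert isinstance(node, ast.ClassDef)

decorator = node.decorator_list[0]
if isinstance(decorator, ast.Call):
    decorator = decorator.func  # @config(...) -> the underlying Name node
assert isinstance(decorator, ast.Name) and decorator.id == "config"
```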

View File

@@ -36,6 +36,7 @@ from vllm.config.utils import (
     config,
     get_attr_docs,
     is_init_field,
+    replace,
     update_config,
 )
 from vllm.config.vllm import (
@@ -101,6 +102,7 @@ __all__ = [
     "config",
     "get_attr_docs",
     "is_init_field",
+    "replace",
     "update_config",
     # From vllm.config.vllm
     "VllmConfig",

View File

@@ -4,14 +4,12 @@
 from typing import Any, Literal
 from pydantic import field_validator
-from pydantic.dataclasses import dataclass
 from vllm.config.utils import config
 from vllm.v1.attention.backends.registry import AttentionBackendEnum


 @config
-@dataclass
 class AttentionConfig:
     """Configuration for attention mechanisms in vLLM."""

View File

@@ -6,7 +6,6 @@ from dataclasses import field
 from typing import TYPE_CHECKING, Any, Literal
 from pydantic import Field, SkipValidation, field_validator
-from pydantic.dataclasses import dataclass
 from vllm.config.utils import config
 from vllm.logger import init_logger
@@ -37,7 +36,6 @@ KVOffloadingBackend = Literal["native", "lmcache"]
 @config
-@dataclass
 class CacheConfig:
     """Configuration for the KV cache."""

View File

@@ -8,8 +8,7 @@ from dataclasses import field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar, Literal
-from pydantic import ConfigDict, Field, TypeAdapter, field_validator
-from pydantic.dataclasses import dataclass
+from pydantic import Field, TypeAdapter, field_validator
 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
@@ -96,7 +95,6 @@ class CUDAGraphMode(enum.Enum):
 @config
-@dataclass(config=ConfigDict(extra="forbid"))
 class PassConfig:
     """Configuration for custom Inductor passes.
@@ -267,7 +265,6 @@ class DynamicShapesType(str, enum.Enum):
 @config
-@dataclass(config=ConfigDict(extra="forbid"))
 class DynamicShapesConfig:
     """Configuration to control/debug torch compile dynamic shapes."""
@@ -311,7 +308,6 @@ class DynamicShapesConfig:
 @config
-@dataclass(config=ConfigDict(extra="forbid"))
 class CompilationConfig:
     """Configuration for compilation.

View File

@@ -6,7 +6,6 @@ from typing import Any, Literal
 import torch
 from pydantic import ConfigDict, SkipValidation
-from pydantic.dataclasses import dataclass
 from vllm.config.utils import config
 from vllm.utils.hashing import safe_hash
@@ -14,8 +13,7 @@ from vllm.utils.hashing import safe_hash
 Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]


-@config
-@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
+@config(config=ConfigDict(arbitrary_types_allowed=True))
 class DeviceConfig:
     """Configuration for the device to use for vLLM execution."""

View File

@@ -5,8 +5,6 @@ import uuid
 from dataclasses import field
 from typing import Any, Literal, get_args
-from pydantic.dataclasses import dataclass
-
 from vllm.config.utils import config
 ECProducer = Literal["ec_producer"]
@@ -15,7 +13,6 @@ ECRole = Literal[ECProducer, ECConsumer]
 @config
-@dataclass
 class ECTransferConfig:
     """Configuration for distributed EC cache transfer."""

View File

@@ -5,13 +5,11 @@
 from typing import Literal
 from pydantic import Field
-from pydantic.dataclasses import dataclass
 from vllm.config.utils import config


 @config
-@dataclass
 class KVEventsConfig:
     """Configuration for KV event publishing."""

View File

@@ -5,8 +5,6 @@ import uuid
 from dataclasses import field
 from typing import Any, Literal, get_args
-from pydantic.dataclasses import dataclass
-
 from vllm.config.utils import config
 from vllm.utils.hashing import safe_hash
@@ -16,7 +14,6 @@ KVRole = Literal[KVProducer, KVConsumer]
 @config
-@dataclass
 class KVTransferConfig:
     """Configuration for distributed KV cache transfer."""

View File

@@ -4,7 +4,6 @@
 from typing import TYPE_CHECKING, Any
 from pydantic import Field, field_validator
-from pydantic.dataclasses import dataclass
 from vllm.config.utils import config
 from vllm.logger import init_logger
@@ -21,7 +20,6 @@ logger = init_logger(__name__)
 @config
-@dataclass
 class LoadConfig:
     """Configuration for loading the model weights."""

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Literal
 import torch
 from pydantic import ConfigDict, Field, model_validator
-from pydantic.dataclasses import dataclass
 from typing_extensions import Self
 from vllm.config.utils import config
@@ -26,8 +25,7 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
 LoRAExtraVocabSize = Literal[256, 512]


-@config
-@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
+@config(config=ConfigDict(arbitrary_types_allowed=True))
 class LoRAConfig:
     """Configuration for LoRA."""

View File

@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args
 import torch
 from pydantic import ConfigDict, Field, field_validator, model_validator
-from pydantic.dataclasses import dataclass
 import vllm.envs as envs
 from vllm.config.model_arch import (
@@ -97,8 +96,7 @@ AttnTypeStr = Literal[
 ]


-@config
-@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
+@config(config=ConfigDict(arbitrary_types_allowed=True))
 class ModelConfig:
     """Configuration for the model."""

View File

@@ -51,7 +51,6 @@ DummyOptions: TypeAlias = (
 @config
-@dataclass
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""

View File

@@ -6,7 +6,6 @@ from typing import Any, Literal, cast
 from packaging.version import parse
 from pydantic import Field, field_validator, model_validator
-from pydantic.dataclasses import dataclass
 from vllm import version
 from vllm.config.utils import config
@@ -16,7 +15,6 @@ DetailedTraceModules = Literal["model", "worker", "all"]
 @config
-@dataclass
 class ObservabilityConfig:
     """Configuration for observability - metrics and tracing."""

View File

@@ -3,12 +3,10 @@
 import os
 from collections.abc import Callable
-from dataclasses import replace
 from typing import TYPE_CHECKING, Any, Literal
 import torch
 from pydantic import Field, field_validator, model_validator
-from pydantic.dataclasses import dataclass
 from torch.distributed import ProcessGroup, ReduceOp
 from typing_extensions import Self
@@ -50,7 +48,6 @@ All2AllBackend = Literal[
 @config
-@dataclass
 class EPLBConfig:
     """Configuration for Expert Parallel Load Balancing (EP)."""
@@ -94,7 +91,6 @@ class EPLBConfig:
 @config
-@dataclass
 class ParallelConfig:
     """Configuration for the distributed execution."""
@@ -715,6 +711,3 @@ class ParallelConfig:
         )
         return self
-
-    def replace(self, **kwargs) -> Self:
-        return replace(self, **kwargs)

View File

@@ -3,8 +3,6 @@
 from typing import Any, Literal, get_args
-from pydantic.dataclasses import dataclass
-
 from vllm.config.utils import config
 from vllm.logger import init_logger
 from vllm.utils.hashing import safe_hash
@@ -19,7 +17,6 @@ TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
 @config
-@dataclass
 class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""

View File

@@ -5,7 +5,6 @@ import os
 from typing import Any, Literal
 from pydantic import Field, model_validator
-from pydantic.dataclasses import dataclass
 from typing_extensions import Self
 from vllm.config.utils import config
@@ -32,7 +31,6 @@ def _is_uri_path(path: str) -> bool:
 @config
-@dataclass
 class ProfilerConfig:
     """Dataclass which contains profiler config for the engine."""

View File

@@ -6,7 +6,6 @@ from dataclasses import InitVar
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
 from pydantic import Field, field_validator
-from pydantic.dataclasses import dataclass
 from typing_extensions import Self
 from vllm.config.utils import config
@@ -24,7 +23,6 @@ SchedulerPolicy = Literal["fcfs", "priority"]
 @config
-@dataclass
 class SchedulerConfig:
     """Scheduler configuration."""

View File

@@ -5,7 +5,6 @@ import ast
 from typing import TYPE_CHECKING, Any, Literal, get_args
 from pydantic import Field, SkipValidation, model_validator
-from pydantic.dataclasses import dataclass
 from typing_extensions import Self
 from vllm.config.model import ModelConfig
@@ -55,7 +54,6 @@ SpeculativeMethod = Literal[
 @config
-@dataclass
 class SpeculativeConfig:
     """Configuration for speculative decoding."""

View File

@@ -2,13 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from pydantic.dataclasses import dataclass
-
 from vllm.config.utils import config


 @config
-@dataclass
 class SpeechToTextConfig:
     """Configuration for speech-to-text models."""

View File

@@ -4,7 +4,6 @@
 from typing import Any, Literal
 from pydantic import model_validator
-from pydantic.dataclasses import dataclass
 from typing_extensions import Self
 from vllm.config.utils import config
@@ -16,7 +15,6 @@ StructuredOutputsBackend = Literal[
 @config
-@dataclass
 class StructuredOutputsConfig:
     """Dataclass which contains structured outputs config for the engine."""

View File

@@ -10,14 +10,17 @@ import json
 import pathlib
 import textwrap
 from collections.abc import Callable, Mapping, Sequence, Set
-from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
+from dataclasses import MISSING, Field, field, fields, is_dataclass
 from itertools import pairwise
-from typing import TYPE_CHECKING, Any, Protocol, TypeVar
+from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
 import regex as re
 import torch
+from pydantic import ConfigDict
+from pydantic.dataclasses import dataclass
+from pydantic.fields import Field as PydanticField
 from pydantic.fields import FieldInfo
-from typing_extensions import runtime_checkable
+from typing_extensions import dataclass_transform, runtime_checkable
 from vllm.logger import init_logger
@@ -29,23 +32,39 @@ else:
     DataclassInstance = Any

 ConfigType = type[DataclassInstance]
-ConfigT = TypeVar("ConfigT", bound=ConfigType)
+ConfigT = TypeVar("ConfigT", bound=DataclassInstance)


-def config(cls: ConfigT) -> ConfigT:
-    """
-    A decorator that ensures all fields in a dataclass have default values
-    and that each field has a docstring.
+@dataclass_transform(field_specifiers=(PydanticField,))
+def config(
+    cls: type[ConfigT] | None = None,
+    *,
+    config: ConfigDict | None = None,
+    **kwargs: Any,
+) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]:
+    """Decorator to create a pydantic dataclass with default config. The default config
+    for the dataclass forbids extra fields.

     If a `ConfigT` is used as a CLI argument itself, the `type` keyword argument
     provided by `get_kwargs` will be
     `pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)` which treats the
     `cli_arg` as a JSON string which gets validated by `pydantic`.

     All config classes in vLLM should use this decorator.

     Config validation is performed by the tools/pre_commit/validate_config.py
     script, which is invoked during the pre-commit checks.
-    """
-    return cls
+
+    Args:
+        cls: The class to decorate
+        config: The pydantic ConfigDict to use. If provided, it will be merged with
+            the default config.
+        **kwargs: Additional arguments to pass to pydantic.dataclass."""
+    # Extra fields are forbidden by default
+    merged_config = ConfigDict(extra="forbid")
+    if config is not None:
+        merged_config.update(config)
+
+    def decorator(cls):
+        return dataclass(cls, config=merged_config, **kwargs)
+
+    # Called with arguments: @config(config=...)
+    if cls is None:
+        return decorator
+
+    # Called without arguments: @config
+    return decorator(cls)


 def get_field(cls: ConfigType, name: str) -> Field:
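Both call forms shown in the diff now produce a pydantic dataclass that forbids extra fields, and `@dataclass_transform` lets type checkers see the generated `__init__`. A minimal sketch with hypothetical classes:

```python
from pydantic import ConfigDict

from vllm.config import config

@config
class SimpleConfig:
    x: int = 0
    """x"""

@config(config=ConfigDict(arbitrary_types_allowed=True))
class FlexibleConfig:
    y: int = 0
    """y"""

SimpleConfig(x=1)  # validated by pydantic at construction time
# SimpleConfig(z=1) would raise a ValidationError: extra fields are forbidden.
```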
@@ -53,24 +72,46 @@ def get_field(cls: ConfigType, name: str) -> Field:
     default factory fields in `EngineArgs`."""
     if not is_dataclass(cls):
         raise TypeError("The given class is not a dataclass.")

-    cls_fields = {f.name: f for f in fields(cls)}
-    if name not in cls_fields:
-        raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
-
-    named_field: Field = cls_fields[name]
-    if (default_factory := named_field.default_factory) is not MISSING:
-        return field(default_factory=default_factory)
-    if (default := named_field.default) is not MISSING:
-        if isinstance(default, FieldInfo):
-            # Handle pydantic.Field defaults
-            if default.default_factory is not None:
-                return field(default_factory=default.default_factory)
-            else:
-                default = default.default
-        return field(default=default)
+    try:
+        named_field = next(f for f in fields(cls) if f.name == name)
+    except StopIteration as e:
+        raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e

-    raise ValueError(
-        f"{cls.__name__}.{name} must have a default value or default factory."
-    )
+    # The arguments to copy to the new field
+    default = named_field.default
+    default_factory = named_field.default_factory
+    init = named_field.init
+
+    # Handle pydantic.Field
+    if isinstance(default, FieldInfo):
+        if default.init is not None:
+            init = default.init
+        if default.default_factory is not None:
+            default_factory = cast(Callable[[], Any], default.default_factory)
+            default = MISSING
+        else:
+            default = default.default
+
+    if default is MISSING and default_factory is MISSING:
+        logger.warning_once(
+            "%s.%s has no default or default factory.", cls.__name__, name
+        )
+
+    return field(default=default, default_factory=default_factory, init=init)
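The rework means `get_field` now copies `default`, `default_factory`, and `init` out of a pydantic `FieldInfo` default instead of only handling plain dataclass fields, and a missing default is a warning rather than an error. A hypothetical round-trip:

```python
from dataclasses import MISSING

from pydantic import Field

from vllm.config import config
from vllm.config.utils import get_field

@config
class C:
    xs: list[int] = Field(default_factory=list)
    """xs"""

f = get_field(C, "xs")       # a plain dataclasses field
assert f.default is MISSING  # the factory, not a value, was copied
assert f.default_factory() == []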
+def is_init_field(cls: ConfigType, name: str) -> bool:
+    return get_field(cls, name).init
+
+
+def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT:
+    """Like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace),
+    but compatible with Pydantic dataclasses which use `pydantic.fields.Field` instead
+    of `dataclasses.field`"""
+    cls = type(dataclass_instance)
+    dataclass_dict = dataclass_instance.__dict__
+    dataclass_dict = {k: v for k, v in dataclass_dict.items() if is_init_field(cls, k)}
+    dataclass_dict.update(kwargs)
+    return cls(**dataclass_dict)


 def getattr_iter(
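Because `replace` rebuilds the instance through `cls(**...)`, pydantic validation and `__post_init__` run again on the copy, which is exactly the behavior the removed `VllmConfig.replace` and `ParallelConfig.replace` methods provided. A sketch of intended usage (assumes `ParallelConfig` is default-constructible):

```python
from vllm.config import ParallelConfig, replace

base = ParallelConfig()
updated = replace(base, rank=1)  # new, re-validated instance
assert base is not updated       # `base` itself is untouched
```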
@@ -172,10 +213,6 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
     return out

-def is_init_field(cls: ConfigType, name: str) -> bool:
-    return next(f for f in fields(cls) if f.name == name).init
-
 @runtime_checkable
 class SupportsHash(Protocol):
     def compute_hash(self) -> str: ...

View File

@@ -9,7 +9,7 @@ import tempfile
 import threading
 import time
 from contextlib import contextmanager
-from dataclasses import is_dataclass, replace
+from dataclasses import is_dataclass
 from datetime import datetime
 from enum import IntEnum
 from functools import lru_cache
@@ -18,10 +18,8 @@ from typing import TYPE_CHECKING, Any, TypeVar, get_args
 import torch
 from pydantic import ConfigDict, Field, model_validator
-from pydantic.dataclasses import dataclass
 import vllm.envs as envs
-from vllm.config.speculative import EagleModelTypes
 from vllm.logger import enable_trace_function_call, init_logger
 from vllm.transformers_utils.runai_utils import is_runai_obj_uri
 from vllm.utils import random_uuid
@@ -41,9 +39,9 @@ from .observability import ObservabilityConfig
 from .parallel import ParallelConfig
 from .profiler import ProfilerConfig
 from .scheduler import SchedulerConfig
-from .speculative import SpeculativeConfig
+from .speculative import EagleModelTypes, SpeculativeConfig
 from .structured_outputs import StructuredOutputsConfig
-from .utils import SupportsHash, config
+from .utils import SupportsHash, config, replace

 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -187,8 +185,7 @@ OPTIMIZATION_LEVEL_TO_CONFIG = {
 }


-@config
-@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
+@config(config=ConfigDict(arbitrary_types_allowed=True))
 class VllmConfig:
     """Dataclass which contains all vllm-related configuration. This
     simplifies passing around the distinct configurations in the codebase.
@@ -1395,14 +1392,6 @@ class VllmConfig:
         path = self.compilation_config.debug_dump_path / append_path
         return path

-    def replace(self, **kwargs):
-        """
-        Replace attributes of the config, and 'recompute' the config.
-        dataclass.replace() calls __init__() and __post_init__(), source:
-        https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
-        """
-        return replace(self, **kwargs)
-
     def __str__(self):
         return (
             f"model={self.model_config.model!r}, "

View File

@@ -13,8 +13,6 @@ from collections.abc import Sequence
 from dataclasses import field
 from typing import Any, Literal
-from pydantic.dataclasses import dataclass
-
 import vllm.envs as envs
 from vllm.config import config
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
@@ -69,7 +67,6 @@ class LoRAParserAction(argparse.Action):
 @config
-@dataclass
 class FrontendArgs:
     """Arguments for the OpenAI-compatible frontend server."""

View File

@@ -13,9 +13,10 @@ def register_speculator(name):

 @register_speculator("eagle3")
-def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
+def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None:
     """
-    Apply Eagle-3 specific configuration transformations.
+    Apply Eagle-3 specific configuration transformations to the `dict` used to
+    construct the Transformers PreTrainedConfig.

     Eagle-3 specific fields:
     - draft_vocab_size: Size of the draft model's vocabulary
@@ -27,12 +28,14 @@ def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None:
         predictions. This is the standard field used in Eagle3 checkpoints.
     """
-    vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
+    pre_trained_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
     if config_dict.get("target_hidden_size") is not None:
-        vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
-    vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
-    vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
+        pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"]
+    pre_trained_config["norm_before_residual"] = config_dict.get(
+        "norm_before_residual", True
+    )
+    pre_trained_config["architectures"] = ["Eagle3LlamaForCausalLM"]
     if config_dict.get("eagle_aux_hidden_state_layer_ids"):
-        vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
+        pre_trained_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
             "eagle_aux_hidden_state_layer_ids"
         ]
View File

@@ -24,13 +24,16 @@ class SpeculatorsConfig(PretrainedConfig):
         """Load speculators Eagle config and convert to vLLM format."""
         config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-        vllm_config = cls.extract_vllm_speculative_config(config_dict)
+        vllm_config = cls.extract_transformers_pre_trained_config(config_dict)
         return cls(**vllm_config)

     @classmethod
-    def extract_vllm_speculative_config(
+    def extract_transformers_pre_trained_config(
         cls, config_dict: dict[str, Any]
     ) -> dict[str, Any]:
+        """
+        Extract standard Transformers PreTrainedConfig config from speculators config.
+        """
         speculators_model_type = config_dict.get("speculators_model_type")
         if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
             raise ValueError(
@@ -38,15 +41,23 @@ class SpeculatorsConfig(PretrainedConfig):
                 "Please ensure you're loading a speculators-format model."
             )

+        # Start with transformer layer configuration if present
+        pre_trained_config = config_dict.get("transformer_layer_config", {})
+        # Apply anything specific to the supported algorithm
+        algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
+        algo_updater(config_dict=config_dict, pre_trained_config=pre_trained_config)
+        return pre_trained_config
+
+    @classmethod
+    def extract_vllm_speculative_config(
+        cls, config_dict: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Extract vLLM speculative config from speculators config."""
         # validate fields
         # TODO: @dsikka - use speculators pydantic model to validate
         cls.validate_speculators_config(config_dict=config_dict)

         # Convert from speculators config -> format that can be ingested by vLLM
-        vllm_config = cls.build_vllm_speculative_config(config_dict=config_dict)
-
-        # Apply anything specific to the supported algorithm
-        algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
-        algo_updater(config_dict=config_dict, vllm_config=vllm_config)
-
-        return vllm_config
+        return cls.build_vllm_speculative_config(config_dict=config_dict)

     @classmethod
     def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
@@ -101,14 +112,7 @@ class SpeculatorsConfig(PretrainedConfig):
         )

         # Build base vLLM speculative configuration
-        vllm_config = {
+        return {
             "method": config_dict.get("speculators_model_type"),
             "num_speculative_tokens": num_speculative_tokens,
             "target_model": spec_config.get("verifier")["name_or_path"],
         }
-
-        # Merge transformer layer configuration if present
-        transformer_config = config_dict.get("transformer_layer_config", {})
-        vllm_config.update(transformer_config)
-        return vllm_config

View File

@@ -4,7 +4,7 @@ from typing import Any
 import torch

-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_layers_from_vllm_config, replace
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.model_loader import get_model
@@ -191,10 +191,12 @@ def create_vllm_config_for_draft_model(
     old = target_model_vllm_config
     assert old.speculative_config is not None, "speculative_config is not set"
     old_spec_config = old.speculative_config
-    new_parallel_config = old_spec_config.draft_parallel_config.replace(
-        rank=old.parallel_config.rank
+    new_parallel_config = replace(
+        old_spec_config.draft_parallel_config,
+        rank=old.parallel_config.rank,
     )
-    new: VllmConfig = old.replace(
+    new: VllmConfig = replace(
+        old,
         quant_config=None,  # quant_config is recomputed in __init__()
         model_config=old_spec_config.draft_model_config,
         parallel_config=new_parallel_config,