[Perf] Validate @config in pre-commit instead of dynamically (#20200)

Signed-off-by: Lionel Villard <villard@us.ibm.com>
This commit is contained in:
Lionel Villard
2025-07-01 05:10:28 -04:00
committed by GitHub
parent 787b13389e
commit c05596f1a3
6 changed files with 220 additions and 57 deletions

View File

@@ -18,7 +18,7 @@ from functools import cached_property
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, get_origin)
Protocol, TypeVar, Union, cast, get_args)
import regex as re
import torch
@@ -193,28 +193,10 @@ def config(cls: ConfigT) -> ConfigT:
(i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT`
requires custom construction from CLI (e.g. `CompilationConfig`), it can
have a `from_cli` method, which will be called instead.
Config validation is performed by the tools/validate_config.py
script, which is invoked during the pre-commit checks.
"""
if not is_dataclass(cls):
raise TypeError("The decorated class must be a dataclass.")
attr_docs = get_attr_docs(cls)
for f in fields(cls):
if f.init and f.default is MISSING and f.default_factory is MISSING:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must have a default value."
)
if f.name not in attr_docs:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must have a docstring.")
if get_origin(f.type) is Union:
args = get_args(f.type)
literal_args = [arg for arg in args if get_origin(arg) is Literal]
if len(literal_args) > 1:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must use a single "
"Literal type. Please use 'Literal[Literal1, Literal2]' "
"instead of 'Union[Literal1, Literal2]'.")
return cls
@@ -1798,7 +1780,7 @@ class ParallelConfig:
eplb_step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `eplb_window_size` steps will be used for rearranging experts.
"""