[Feature] Use pydantic validation in lora.py and load.py configs (#26413)
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from itertools import pairwise
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
||||
|
||||
import regex as re
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -50,7 +51,14 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
if (default_factory := named_field.default_factory) is not MISSING:
|
||||
return field(default_factory=default_factory)
|
||||
if (default := named_field.default) is not MISSING:
|
||||
if isinstance(default, FieldInfo):
|
||||
# Handle pydantic.Field defaults
|
||||
if default.default_factory is not None:
|
||||
return field(default_factory=default.default_factory)
|
||||
else:
|
||||
default = default.default
|
||||
return field(default=default)
|
||||
|
||||
raise ValueError(
|
||||
f"{cls.__name__}.{name} must have a default value or default factory."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user