[Feature] Use pydantic validation in lora.py and load.py configs (#26413)
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
This commit is contained in:
@@ -2,9 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
@@ -64,7 +64,7 @@ class LoadConfig:
|
||||
was quantized using torchao and saved using safetensors.
|
||||
Needs torchao >= 0.14.0
|
||||
"""
|
||||
model_loader_extra_config: Union[dict, TensorizerConfig] = field(
|
||||
model_loader_extra_config: Union[dict, TensorizerConfig] = Field(
|
||||
default_factory=dict
|
||||
)
|
||||
"""Extra config for model loader. This will be passed to the model loader
|
||||
@@ -72,7 +72,9 @@ class LoadConfig:
|
||||
device: Optional[str] = None
|
||||
"""Device to which model weights will be loaded, default to
|
||||
device_config.device"""
|
||||
ignore_patterns: Optional[Union[list[str], str]] = None
|
||||
ignore_patterns: Union[list[str], str] = Field(
|
||||
default_factory=lambda: ["original/**/*"]
|
||||
)
|
||||
"""The list of patterns to ignore when loading the model. Default to
|
||||
"original/**/*" to avoid repeated loading of llama's checkpoints."""
|
||||
use_tqdm_on_load: bool = True
|
||||
@@ -107,12 +109,18 @@ class LoadConfig:
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
self.load_format = self.load_format.lower()
|
||||
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
|
||||
@field_validator("load_format", mode="after")
|
||||
def _lowercase_load_format(cls, load_format: str) -> str:
|
||||
return load_format.lower()
|
||||
|
||||
@field_validator("ignore_patterns", mode="after")
|
||||
def _validate_ignore_patterns(
|
||||
cls, ignore_patterns: Union[list[str], str]
|
||||
) -> Union[list[str], str]:
|
||||
if ignore_patterns != ["original/**/*"] and len(ignore_patterns) > 0:
|
||||
logger.info(
|
||||
"Ignoring the following patterns when downloading weights: %s",
|
||||
self.ignore_patterns,
|
||||
ignore_patterns,
|
||||
)
|
||||
else:
|
||||
self.ignore_patterns = ["original/**/*"]
|
||||
|
||||
return ignore_patterns
|
||||
|
||||
Reference in New Issue
Block a user