[Feature] Use pydantic validation in lora.py and load.py configs (#26413)

Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
This commit is contained in:
Simon Danielsson
2025-10-09 11:38:33 +02:00
committed by GitHub
parent e6e898f95d
commit e4791438ed
4 changed files with 48 additions and 45 deletions

View File

@@ -2,9 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from dataclasses import field
from typing import TYPE_CHECKING, Any, Optional, Union
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
@@ -64,7 +64,7 @@ class LoadConfig:
was quantized using torchao and saved using safetensors.
Needs torchao >= 0.14.0
"""
model_loader_extra_config: Union[dict, TensorizerConfig] = field(
model_loader_extra_config: Union[dict, TensorizerConfig] = Field(
default_factory=dict
)
"""Extra config for model loader. This will be passed to the model loader
@@ -72,7 +72,9 @@ class LoadConfig:
device: Optional[str] = None
"""Device to which model weights will be loaded, default to
device_config.device"""
ignore_patterns: Optional[Union[list[str], str]] = None
ignore_patterns: Union[list[str], str] = Field(
default_factory=lambda: ["original/**/*"]
)
"""The list of patterns to ignore when loading the model. Default to
"original/**/*" to avoid repeated loading of llama's checkpoints."""
use_tqdm_on_load: bool = True
@@ -107,12 +109,18 @@ class LoadConfig:
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):
self.load_format = self.load_format.lower()
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
@field_validator("load_format", mode="after")
def _lowercase_load_format(cls, load_format: str) -> str:
return load_format.lower()
@field_validator("ignore_patterns", mode="after")
def _validate_ignore_patterns(
cls, ignore_patterns: Union[list[str], str]
) -> Union[list[str], str]:
if ignore_patterns != ["original/**/*"] and len(ignore_patterns) > 0:
logger.info(
"Ignoring the following patterns when downloading weights: %s",
self.ignore_patterns,
ignore_patterns,
)
else:
self.ignore_patterns = ["original/**/*"]
return ignore_patterns