Improve configs - LoadConfig (#16422)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-04-11 21:27:27 +01:00
committed by GitHub
parent 71b9cde010
commit cd77382ac1
3 changed files with 95 additions and 96 deletions

View File

@@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, Union)
Optional, Protocol, TypeVar, Union)
import torch
from pydantic import BaseModel, Field, PrivateAttr
@@ -45,6 +45,7 @@ from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
random_uuid, resolve_obj_by_qualname)
if TYPE_CHECKING:
from _typeshed import DataclassInstance
from ray.util.placement_group import PlacementGroup
from vllm.executor.executor_base import ExecutorBase
@@ -53,8 +54,11 @@ if TYPE_CHECKING:
from vllm.model_executor.model_loader.loader import BaseModelLoader
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
Config = TypeVar("Config", bound=DataclassInstance)
else:
QuantizationConfig = None
Config = TypeVar("Config")
logger = init_logger(__name__)
@@ -159,7 +163,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
return out
def config(cls: type[Any]) -> type[Any]:
def config(cls: type[Config]) -> type[Config]:
"""
A decorator that ensures all fields in a dataclass have default values
and that each field has a docstring.
@@ -1431,44 +1435,47 @@ class LoadFormat(str, enum.Enum):
FASTSAFETENSORS = "fastsafetensors"
@config
@dataclass
class LoadConfig:
"""
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
"sharded_state" will load weights from pre-sharded checkpoint files,
supporting efficient loading of tensor-parallel models.
"gguf" will load weights from GGUF format files.
"mistral" will load weights from consolidated safetensors files used
by Mistral models.
"runai_streamer" will load weights from RunAI streamer format files.
model_loader_extra_config: The extra config for the model loader.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
loading. Default to True
"""
"""Configuration for loading the model weights."""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
load_format: Union[str, LoadFormat,
"BaseModelLoader"] = LoadFormat.AUTO.value
"""The format of the model weights to load:\n
- "auto" will try to load the weights in the safetensors format and fall
back to the pytorch bin format if safetensors format is not available.\n
- "pt" will load the weights in the pytorch bin format.\n
- "safetensors" will load the weights in the safetensors format.\n
- "npcache" will load the weights in pytorch format and store a numpy cache
to speed up the loading.\n
- "dummy" will initialize the weights with random values, which is mainly
for profiling.\n
- "tensorizer" will use CoreWeave's tensorizer library for fast weight
loading. See the Tensorize vLLM Model script in the Examples section for
more information.\n
- "runai_streamer" will load the Safetensors weights using Run:ai Model
Streamer.\n
- "bitsandbytes" will load the weights using bitsandbytes quantization.\n
- "sharded_state" will load weights from pre-sharded checkpoint files,
supporting efficient loading of tensor-parallel models.\n
- "gguf" will load weights from GGUF format files (details specified in
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
- "mistral" will load weights from consolidated safetensors files used by
Mistral models."""
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(
default_factory=dict)
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
model_loader_extra_config: Optional[Union[str, dict]] = None
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format. This should be a JSON string that
will be parsed into a dictionary."""
ignore_patterns: Optional[Union[list[str], str]] = None
"""The list of patterns to ignore when loading the model. Default to
"original/**/*" to avoid repeated loading of llama's checkpoints."""
use_tqdm_on_load: bool = True
"""Whether to enable tqdm for showing progress bar when loading model
weights."""
def compute_hash(self) -> str:
"""