Improve configs - LoadConfig (#16422)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
|
||||
Optional, Protocol, Union)
|
||||
Optional, Protocol, TypeVar, Union)
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
@@ -45,6 +45,7 @@ from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
|
||||
random_uuid, resolve_obj_by_qualname)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
@@ -53,8 +54,11 @@ if TYPE_CHECKING:
|
||||
from vllm.model_executor.model_loader.loader import BaseModelLoader
|
||||
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
||||
BaseTokenizerGroup)
|
||||
|
||||
Config = TypeVar("Config", bound=DataclassInstance)
|
||||
else:
|
||||
QuantizationConfig = None
|
||||
Config = TypeVar("Config")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -159,7 +163,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
|
||||
return out
|
||||
|
||||
|
||||
def config(cls: type[Any]) -> type[Any]:
|
||||
def config(cls: type[Config]) -> type[Config]:
|
||||
"""
|
||||
A decorator that ensures all fields in a dataclass have default values
|
||||
and that each field has a docstring.
|
||||
@@ -1431,44 +1435,47 @@ class LoadFormat(str, enum.Enum):
|
||||
FASTSAFETENSORS = "fastsafetensors"
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class LoadConfig:
|
||||
"""
|
||||
download_dir: Directory to download and load the weights, default to the
|
||||
default cache directory of huggingface.
|
||||
load_format: The format of the model weights to load:
|
||||
"auto" will try to load the weights in the safetensors format and
|
||||
fall back to the pytorch bin format if safetensors format is
|
||||
not available.
|
||||
"pt" will load the weights in the pytorch bin format.
|
||||
"safetensors" will load the weights in the safetensors format.
|
||||
"npcache" will load the weights in pytorch format and store
|
||||
a numpy cache to speed up the loading.
|
||||
"dummy" will initialize the weights with random values, which is
|
||||
mainly for profiling.
|
||||
"tensorizer" will use CoreWeave's tensorizer library for
|
||||
fast weight loading.
|
||||
"bitsandbytes" will load nf4 type weights.
|
||||
"sharded_state" will load weights from pre-sharded checkpoint files,
|
||||
supporting efficient loading of tensor-parallel models.
|
||||
"gguf" will load weights from GGUF format files.
|
||||
"mistral" will load weights from consolidated safetensors files used
|
||||
by Mistral models.
|
||||
"runai_streamer" will load weights from RunAI streamer format files.
|
||||
model_loader_extra_config: The extra config for the model loader.
|
||||
ignore_patterns: The list of patterns to ignore when loading the model.
|
||||
Default to "original/**/*" to avoid repeated loading of llama's
|
||||
checkpoints.
|
||||
use_tqdm_on_load: Whether to enable tqdm for showing progress bar during
|
||||
loading. Default to True
|
||||
"""
|
||||
"""Configuration for loading the model weights."""
|
||||
|
||||
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
|
||||
load_format: Union[str, LoadFormat,
|
||||
"BaseModelLoader"] = LoadFormat.AUTO.value
|
||||
"""The format of the model weights to load:\n
|
||||
- "auto" will try to load the weights in the safetensors format and fall
|
||||
back to the pytorch bin format if safetensors format is not available.\n
|
||||
- "pt" will load the weights in the pytorch bin format.\n
|
||||
- "safetensors" will load the weights in the safetensors format.\n
|
||||
- "npcache" will load the weights in pytorch format and store a numpy cache
|
||||
to speed up the loading.\n
|
||||
- "dummy" will initialize the weights with random values, which is mainly
|
||||
for profiling.\n
|
||||
- "tensorizer" will use CoreWeave's tensorizer library for fast weight
|
||||
loading. See the Tensorize vLLM Model script in the Examples section for
|
||||
more information.\n
|
||||
- "runai_streamer" will load the Safetensors weights using Run:ai Model
|
||||
Streamer.\n
|
||||
- "bitsandbytes" will load the weights using bitsandbytes quantization.\n
|
||||
- "sharded_state" will load weights from pre-sharded checkpoint files,
|
||||
supporting efficient loading of tensor-parallel models.\n
|
||||
- "gguf" will load weights from GGUF format files (details specified in
|
||||
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
|
||||
- "mistral" will load weights from consolidated safetensors files used by
|
||||
Mistral models."""
|
||||
download_dir: Optional[str] = None
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = field(
|
||||
default_factory=dict)
|
||||
"""Directory to download and load the weights, default to the default
|
||||
cache directory of Hugging Face."""
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = None
|
||||
"""Extra config for model loader. This will be passed to the model loader
|
||||
corresponding to the chosen load_format. This should be a JSON string that
|
||||
will be parsed into a dictionary."""
|
||||
ignore_patterns: Optional[Union[list[str], str]] = None
|
||||
"""The list of patterns to ignore when loading the model. Default to
|
||||
"original/**/*" to avoid repeated loading of llama's checkpoints."""
|
||||
use_tqdm_on_load: bool = True
|
||||
"""Whether to enable tqdm for showing progress bar when loading model
|
||||
weights."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user