[Core] Support model loader plugins (#21067)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@@ -13,7 +13,7 @@ from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import LoadConfig, LoadFormat, ModelConfig
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
@@ -104,19 +104,19 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
use_safetensors = False
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == LoadFormat.AUTO:
|
||||
if load_format == "auto":
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif (load_format == LoadFormat.SAFETENSORS
|
||||
or load_format == LoadFormat.FASTSAFETENSORS):
|
||||
elif (load_format == "safetensors"
|
||||
or load_format == "fastsafetensors"):
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == LoadFormat.MISTRAL:
|
||||
elif load_format == "mistral":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["consolidated*.safetensors"]
|
||||
index_file = "consolidated.safetensors.index.json"
|
||||
elif load_format == LoadFormat.PT:
|
||||
elif load_format == "pt":
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == LoadFormat.NPCACHE:
|
||||
elif load_format == "npcache":
|
||||
allow_patterns = ["*.bin"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
@@ -178,7 +178,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path, source.revision, source.fall_back_to_pt,
|
||||
source.allow_patterns_overrides)
|
||||
if self.load_config.load_format == LoadFormat.NPCACHE:
|
||||
if self.load_config.load_format == "npcache":
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
weights_iterator = np_cache_weights_iterator(
|
||||
@@ -189,7 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
elif use_safetensors:
|
||||
if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
|
||||
if self.load_config.load_format == "fastsafetensors":
|
||||
weights_iterator = fastsafetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
|
||||
Reference in New Issue
Block a user