[Core] Support model loader plugins (#21067)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@@ -26,13 +26,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DetailedTraceModules, Device, DeviceConfig,
                          DistributedExecutorBackend, GuidedDecodingBackend,
                          GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
-                         KVTransferConfig, LoadConfig, LoadFormat,
-                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
-                         ModelImpl, MultiModalConfig, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-                         get_field)
+                         KVTransferConfig, LoadConfig, LogprobsMode,
+                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
+                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
+                         PoolerConfig, PrefixCachingHashAlgo, SchedulerConfig,
+                         SchedulerPolicy, SpeculativeConfig, TaskOption,
+                         TokenizerMode, VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -47,10 +46,12 @@ from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
 if TYPE_CHECKING:
     from vllm.executor.executor_base import ExecutorBase
     from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.model_loader import LoadFormats
     from vllm.usage.usage_lib import UsageContext
 else:
     ExecutorBase = Any
     QuantizationMethods = Any
+    LoadFormats = Any
     UsageContext = Any
 
 logger = init_logger(__name__)
@@ -276,7 +277,7 @@ class EngineArgs:
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
-    load_format: str = LoadConfig.load_format
+    load_format: Union[str, LoadFormats] = LoadConfig.load_format
     config_format: str = ModelConfig.config_format
     dtype: ModelDType = ModelConfig.dtype
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
@@ -547,9 +548,7 @@ class EngineArgs:
             title="LoadConfig",
             description=LoadConfig.__doc__,
         )
-        load_group.add_argument("--load-format",
-                                choices=[f.value for f in LoadFormat],
-                                **load_kwargs["load_format"])
+        load_group.add_argument("--load-format", **load_kwargs["load_format"])
         load_group.add_argument("--download-dir",
                                 **load_kwargs["download_dir"])
         load_group.add_argument("--model-loader-extra-config",
@@ -864,10 +863,9 @@ class EngineArgs:
 
         # NOTE: This is to allow model loading from S3 in CI
         if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
-                and self.model in MODELS_ON_S3
-                and self.load_format == LoadFormat.AUTO): # noqa: E501
+                and self.model in MODELS_ON_S3 and self.load_format == "auto"):
             self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
-            self.load_format = LoadFormat.RUNAI_STREAMER
+            self.load_format = "runai_streamer"
 
         return ModelConfig(
             model=self.model,
@@ -1299,7 +1297,7 @@ class EngineArgs:
         #############################################################
         # Unsupported Feature Flags on V1.
 
-        if self.load_format == LoadFormat.SHARDED_STATE.value:
+        if self.load_format == "sharded_state":
             _raise_or_fallback(
                 feature_name=f"--load_format {self.load_format}",
                 recommend_to_remove=False)
Reference in New Issue
Block a user