[Misc] Add placeholder module (#11501)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -48,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
runai_safetensors_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.s3_utils import glob as s3_glob
|
||||
from vllm.transformers_utils.utils import is_s3
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
@@ -1269,16 +1270,6 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
is_s3_path = is_s3(model_name_or_path)
|
||||
if is_s3_path:
|
||||
try:
|
||||
from vllm.transformers_utils.s3_utils import glob as s3_glob
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install Run:ai optional dependency "
|
||||
"to use the S3 capabilities. "
|
||||
"You can install it with: pip install vllm[runai]"
|
||||
) from err
|
||||
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
safetensors_pattern = "*.safetensors"
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
@@ -19,9 +19,7 @@ from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
tensorizer_error_msg = None
|
||||
from vllm.utils import FlexibleArgumentParser, PlaceholderModule
|
||||
|
||||
try:
|
||||
from tensorizer import (DecryptionParams, EncryptionParams,
|
||||
@@ -34,8 +32,19 @@ try:
|
||||
open_stream,
|
||||
mode=mode,
|
||||
) for mode in ("rb", "wb+"))
|
||||
except ImportError as e:
|
||||
tensorizer_error_msg = str(e)
|
||||
except ImportError:
|
||||
tensorizer = PlaceholderModule("tensorizer")
|
||||
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
|
||||
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
|
||||
TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
|
||||
TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
|
||||
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
|
||||
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
|
||||
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
|
||||
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
|
||||
|
||||
_read_stream = tensorizer.placeholder_attr("_read_stream")
|
||||
_write_stream = tensorizer.placeholder_attr("_write_stream")
|
||||
|
||||
__all__ = [
|
||||
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
|
||||
@@ -267,12 +276,6 @@ class TensorizerAgent:
|
||||
"""
|
||||
|
||||
def __init__(self, tensorizer_config: TensorizerConfig, vllm_config):
|
||||
if tensorizer_error_msg is not None:
|
||||
raise ImportError(
|
||||
"Tensorizer is not installed. Please install tensorizer "
|
||||
"to use this feature with `pip install vllm[tensorizer]`. "
|
||||
"Error message: {}".format(tensorizer_error_msg))
|
||||
|
||||
self.tensorizer_config = tensorizer_config
|
||||
self.tensorizer_args = (
|
||||
self.tensorizer_config._construct_tensorizer_args())
|
||||
|
||||
@@ -25,7 +25,15 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
||||
get_quantization_config)
|
||||
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import print_warning_once
|
||||
from vllm.utils import PlaceholderModule, print_warning_once
|
||||
|
||||
try:
|
||||
from runai_model_streamer import SafetensorsStreamer
|
||||
except ImportError:
|
||||
runai_model_streamer = PlaceholderModule(
|
||||
"runai_model_streamer") # type: ignore[assignment]
|
||||
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
|
||||
"SafetensorsStreamer")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -414,13 +422,6 @@ def runai_safetensors_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
try:
|
||||
from runai_model_streamer import SafetensorsStreamer
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install Run:ai optional dependency."
|
||||
"You can install it with: pip install vllm[runai]") from err
|
||||
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
with SafetensorsStreamer() as streamer:
|
||||
|
||||
Reference in New Issue
Block a user