[Misc] Add placeholder module (#11501)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-26 21:12:51 +08:00
committed by GitHub
parent f57ee5650d
commit eec906d811
14 changed files with 143 additions and 100 deletions

View File

@@ -48,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import (
runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.s3_utils import glob as s3_glob
from vllm.transformers_utils.utils import is_s3
from vllm.utils import is_pin_memory_available
@@ -1269,16 +1270,6 @@ class RunaiModelStreamerLoader(BaseModelLoader):
If the model is not local, it will be downloaded."""
is_s3_path = is_s3(model_name_or_path)
if is_s3_path:
try:
from vllm.transformers_utils.s3_utils import glob as s3_glob
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency "
"to use the S3 capabilities. "
"You can install it with: pip install vllm[runai]"
) from err
is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
index_file = SAFE_WEIGHTS_INDEX_NAME

View File

@@ -19,9 +19,7 @@ from vllm.engine.llm_engine import LLMEngine
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.utils import FlexibleArgumentParser
tensorizer_error_msg = None
from vllm.utils import FlexibleArgumentParser, PlaceholderModule
try:
from tensorizer import (DecryptionParams, EncryptionParams,
@@ -34,8 +32,19 @@ try:
open_stream,
mode=mode,
) for mode in ("rb", "wb+"))
except ImportError as e:
tensorizer_error_msg = str(e)
except ImportError:
tensorizer = PlaceholderModule("tensorizer")
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
_read_stream = tensorizer.placeholder_attr("_read_stream")
_write_stream = tensorizer.placeholder_attr("_write_stream")
__all__ = [
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
@@ -267,12 +276,6 @@ class TensorizerAgent:
"""
def __init__(self, tensorizer_config: TensorizerConfig, vllm_config):
if tensorizer_error_msg is not None:
raise ImportError(
"Tensorizer is not installed. Please install tensorizer "
"to use this feature with `pip install vllm[tensorizer]`. "
"Error message: {}".format(tensorizer_error_msg))
self.tensorizer_config = tensorizer_config
self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args())

View File

@@ -25,7 +25,15 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
from vllm.utils import PlaceholderModule, print_warning_once
try:
from runai_model_streamer import SafetensorsStreamer
except ImportError:
runai_model_streamer = PlaceholderModule(
"runai_model_streamer") # type: ignore[assignment]
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
"SafetensorsStreamer")
logger = init_logger(__name__)
@@ -414,13 +422,6 @@ def runai_safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
try:
from runai_model_streamer import SafetensorsStreamer
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency."
"You can install it with: pip install vllm[runai]") from err
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
with SafetensorsStreamer() as streamer: