[Core] Loading model from S3 using RunAI Model Streamer as optional loader (#10192)

Signed-off-by: OmerD <omer@run.ai>
This commit is contained in:
omer-dayan
2024-12-20 18:46:24 +02:00
committed by GitHub
parent 7c7aa37c69
commit 995f56236b
13 changed files with 457 additions and 3 deletions

View File

@@ -29,6 +29,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.transformers_utils.utils import is_s3
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname)
@@ -256,6 +257,8 @@ class ModelConfig:
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
@@ -357,6 +360,39 @@ class ModelConfig:
self._verify_cuda_graph()
self._verify_bnb_config()
def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
"""
Pull the model config or tokenizer to a temporary
directory in case of S3.
Args:
model: The model name or path.
tokenizer: The tokenizer name or path.
"""
if is_s3(model) or is_s3(tokenizer):
try:
from vllm.transformers_utils.s3_utils import S3Model
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency "
"to use the S3 capabilities. "
"You can install it with: pip install vllm[runai]"
) from err
if is_s3(model):
self.s3_model = S3Model()
self.s3_model.pull_files(model, allow_pattern=["*config.json"])
self.model_weights = self.model
self.model = self.s3_model.dir
if is_s3(tokenizer):
self.s3_tokenizer = S3Model()
self.s3_tokenizer.pull_files(
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
self.tokenizer = self.s3_tokenizer.dir
def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
) -> Optional["MultiModalConfig"]:
@@ -1099,6 +1135,7 @@ class LoadFormat(str, enum.Enum):
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer"
@dataclass