[Core] Loading model from S3 using RunAI Model Streamer as optional loader (#10192)

Signed-off-by: OmerD <omer@run.ai>
This commit is contained in:
omer-dayan
2024-12-20 18:46:24 +02:00
committed by GitHub
parent 7c7aa37c69
commit 995f56236b
13 changed files with 457 additions and 3 deletions
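
End to end, the new loader is selected through the existing load_format
plumbing. A minimal usage sketch (bucket and model path are placeholders;
assumes the vllm[runai] extra is installed):

    from vllm import LLM

    # Stream safetensors weights straight from S3 instead of downloading first.
    llm = LLM(model="s3://my-bucket/my-model/", load_format="runai_streamer")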


@@ -29,6 +29,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.transformers_utils.utils import is_s3
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname)
@@ -256,6 +257,8 @@ class ModelConfig:
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
@@ -357,6 +360,39 @@ class ModelConfig:
self._verify_cuda_graph()
self._verify_bnb_config()
def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
"""
Pull the model config or tokenizer to a temporary
directory in case of S3.
Args:
model: The model name or path.
tokenizer: The tokenizer name or path.
"""
if is_s3(model) or is_s3(tokenizer):
try:
from vllm.transformers_utils.s3_utils import S3Model
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency "
"to use the S3 capabilities. "
"You can install it with: pip install vllm[runai]"
) from err
if is_s3(model):
self.s3_model = S3Model()
self.s3_model.pull_files(model, allow_pattern=["*config.json"])
self.model_weights = self.model
self.model = self.s3_model.dir
if is_s3(tokenizer):
self.s3_tokenizer = S3Model()
                self.s3_tokenizer.pull_files(
                    tokenizer,
                    ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
self.tokenizer = self.s3_tokenizer.dir
def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
) -> Optional["MultiModalConfig"]:
@@ -1099,6 +1135,7 @@ class LoadFormat(str, enum.Enum):
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer"
@dataclass


@@ -316,6 +316,8 @@ class EngineArgs:
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples '
'section for more information.\n'
'* "runai_streamer" will load the Safetensors weights using Run:ai'
'Model Streamer \n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
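
The same choice is reachable programmatically; a minimal sketch, assuming
the EngineArgs dataclass as defined at this commit:

    from vllm.engine.arg_utils import EngineArgs

    # Equivalent to passing --load-format runai_streamer on the command line.
    engine_args = EngineArgs(model="s3://my-bucket/my-model/",
                             load_format="runai_streamer")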


@@ -45,9 +45,10 @@ from vllm.model_executor.model_loader.weight_utils import (
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
    runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.utils import is_s3
from vllm.utils import is_pin_memory_available
@@ -1234,6 +1235,118 @@ class GGUFModelLoader(BaseModelLoader):
return model
class RunaiModelStreamerLoader(BaseModelLoader):
"""
    Model loader that can load safetensors
    files from a local file system or an S3 bucket.
"""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
extra_config = load_config.model_loader_extra_config
if ("concurrency" in extra_config
and isinstance(extra_config.get("concurrency"), int)):
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
extra_config.get("concurrency"))
if ("memory_limit" in extra_config
and isinstance(extra_config.get("memory_limit"), int)):
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
extra_config.get("memory_limit"))
runai_streamer_s3_endpoint = os.getenv(
'RUNAI_STREAMER_S3_ENDPOINT')
aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
if (runai_streamer_s3_endpoint is None
and aws_endpoint_url is not None):
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]) -> List[str]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_s3_path = is_s3(model_name_or_path)
if is_s3_path:
try:
from vllm.transformers_utils.s3_utils import glob as s3_glob
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency "
"to use the S3 capabilities. "
"You can install it with: pip install vllm[runai]"
) from err
is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
index_file = SAFE_WEIGHTS_INDEX_NAME
hf_folder = (model_name_or_path if
(is_local or is_s3_path) else download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[safetensors_pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
))
if is_s3_path:
hf_weights_files = s3_glob(path=hf_folder,
allow_pattern=[safetensors_pattern])
else:
hf_weights_files = glob.glob(
os.path.join(hf_folder, safetensors_pattern))
if not is_local and not is_s3_path:
download_safetensors_index_file_from_hf(
model_name_or_path, index_file, self.load_config.download_dir,
revision)
if not hf_weights_files:
raise RuntimeError(
f"Cannot find any safetensors model weights with "
f"`{model_name_or_path}`")
return hf_weights_files
    def _get_weights_iterator(
        self, model_or_path: str, revision: Optional[str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision)
return runai_safetensors_weights_iterator(hf_weights_files)
def download_model(self, model_config: ModelConfig) -> None:
"""Download model if necessary"""
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
"""Perform streaming of the model to destination"""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
model_weights = model_config.model
if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
model.load_weights(
self._get_weights_iterator(model_weights,
model_config.revision))
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model.eval()
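
The concurrency and memory_limit knobs consumed in __init__ arrive through
model_loader_extra_config; a hedged sketch of how a caller might supply them
(values are illustrative):

    from vllm import LLM

    # 16 concurrent object-store reads, 8 GiB CPU staging buffer.
    llm = LLM(model="s3://my-bucket/my-model/",
              load_format="runai_streamer",
              model_loader_extra_config={
                  "concurrency": 16,
                  "memory_limit": 8 * 1024 * 1024 * 1024,
              })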
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
@@ -1255,4 +1368,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
return RunaiModelStreamerLoader(load_config)
return DefaultModelLoader(load_config)
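
A sketch of the dispatch, assuming LoadConfig's fields at this commit:

    from vllm.config import LoadConfig, LoadFormat

    load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
    loader = get_model_loader(load_config)  # RunaiModelStreamerLoader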


@@ -410,6 +410,30 @@ def safetensors_weights_iterator(
yield name, param
def runai_safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
try:
from runai_model_streamer import SafetensorsStreamer
except ImportError as err:
        raise ImportError(
            "Please install the Run:ai optional dependency. "
            "You can install it with: pip install vllm[runai]") from err
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
with SafetensorsStreamer() as streamer:
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors using Runai Model Streamer",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
streamer.stream_file(st_file)
yield from streamer.get_tensors()
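
The new iterator yields (name, tensor) pairs exactly like
safetensors_weights_iterator, so callers can consume the two
interchangeably; a small sketch (the file list is illustrative, normally
produced by _prepare_weights):

    files = ["s3://my-bucket/my-model/model-00001-of-00002.safetensors"]
    for name, tensor in runai_safetensors_weights_iterator(files):
        print(name, tuple(tensor.shape))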
def pt_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:


@@ -0,0 +1,146 @@
import fnmatch
import os
import shutil
import signal
import tempfile
from pathlib import Path
from typing import Optional
import boto3
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
return [
path for path in paths if any(
fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
return [
path for path in paths
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
]
def glob(s3=None,
path: str = "",
allow_pattern: Optional[list[str]] = None) -> list[str]:
"""
List full file names from S3 path and filter by allow pattern.
Args:
s3: S3 client to use.
path: The S3 path to list from.
allow_pattern: A list of patterns of which files to pull.
Returns:
list[str]: List of full S3 paths allowed by the pattern
"""
if s3 is None:
s3 = boto3.client("s3")
bucket_name, _, paths = list_files(s3,
path=path,
allow_pattern=allow_pattern)
return [f"s3://{bucket_name}/{path}" for path in paths]
def list_files(
s3,
path: str,
allow_pattern: Optional[list[str]] = None,
ignore_pattern: Optional[list[str]] = None
) -> tuple[str, str, list[str]]:
"""
List files from S3 path and filter by pattern.
Args:
s3: S3 client to use.
path: The S3 path to list from.
allow_pattern: A list of patterns of which files to pull.
ignore_pattern: A list of patterns of which files not to pull.
Returns:
tuple[str, str, list[str]]: A tuple where:
            - The first element is the bucket name.
            - The second element is the prefix, represented as a
              directory-like string.
            - The third element is the list of object keys remaining
              after the allow and ignore patterns are applied.
"""
parts = path.removeprefix('s3://').split('/')
prefix = '/'.join(parts[1:])
bucket_name = parts[0]
objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
paths = [obj['Key'] for obj in objects.get('Contents', [])]
paths = _filter_ignore(paths, ["*/"])
if allow_pattern is not None:
paths = _filter_allow(paths, allow_pattern)
if ignore_pattern is not None:
paths = _filter_ignore(paths, ignore_pattern)
return bucket_name, prefix, paths
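
A usage sketch for the two listing helpers (bucket and keys are
placeholders; assumes AWS credentials are configured):

    import boto3

    s3 = boto3.client("s3")
    # Full s3:// URIs of every safetensors file under the prefix.
    weight_uris = glob(s3, "s3://my-bucket/my-model/",
                       allow_pattern=["*.safetensors"])
    # Raw listing: bucket name, prefix, and the keys that survive filtering.
    bucket, prefix, keys = list_files(s3, "s3://my-bucket/my-model/",
                                      ignore_pattern=["*.bin"])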
class S3Model:
"""
A class representing a S3 model mirrored into a temporary directory.
Attributes:
s3: S3 client.
dir: The temporary created directory.
Methods:
pull_files(): Pull model from S3 to the temporary directory.
"""
def __init__(self) -> None:
self.s3 = boto3.client('s3')
for sig in (signal.SIGINT, signal.SIGTERM):
existing_handler = signal.getsignal(sig)
signal.signal(sig, self._close_by_signal(existing_handler))
self.dir = tempfile.mkdtemp()
def __del__(self):
self._close()
def _close(self) -> None:
if os.path.exists(self.dir):
shutil.rmtree(self.dir)
def _close_by_signal(self, existing_handler=None):
def new_handler(signum, frame):
self._close()
if existing_handler:
existing_handler(signum, frame)
return new_handler
def pull_files(self,
s3_model_path: str = "",
allow_pattern: Optional[list[str]] = None,
ignore_pattern: Optional[list[str]] = None) -> None:
"""
Pull files from S3 storage into the temporary directory.
Args:
s3_model_path: The S3 path of the model.
allow_pattern: A list of patterns of which files to pull.
ignore_pattern: A list of patterns of which files not to pull.
"""
bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
allow_pattern,
ignore_pattern)
if len(files) == 0:
return
for file in files:
destination_file = self.dir + file.removeprefix(base_dir)
local_dir = Path(destination_file).parent
os.makedirs(local_dir, exist_ok=True)
self.s3.download_file(bucket_name, file, destination_file)
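
And a sketch of the mirror class in use (paths are placeholders):

    s3_model = S3Model()
    # Mirror everything except the weight files into the temporary directory.
    s3_model.pull_files("s3://my-bucket/my-model/",
                        ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
    local_dir = s3_model.dir  # e.g. /tmp/tmpabc123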


@@ -3,6 +3,10 @@ from pathlib import Path
from typing import Union
def is_s3(model_or_path: str) -> bool:
return model_or_path.lower().startswith('s3://')
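
The predicate is purely lexical, for example:

    assert is_s3("s3://my-bucket/my-model")
    assert not is_s3("/local/path/to/model")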
def check_gguf_file(model: Union[str, PathLike]) -> bool:
"""Check if the file is a GGUF model."""
model = Path(model)