[Core] Loading model from S3 using RunAI Model Streamer as optional loader (#10192)
Signed-off-by: OmerD <omer@run.ai>
This commit is contained in:
@@ -45,9 +45,10 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
||||
get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
|
||||
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
|
||||
safetensors_weights_iterator)
|
||||
runai_safetensors_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.utils import is_s3
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
@@ -1234,6 +1235,118 @@ class GGUFModelLoader(BaseModelLoader):
|
||||
return model
|
||||
|
||||
|
||||
class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
"""
|
||||
Model loader that can load safetensors
|
||||
files from local FS or S3 bucket.
|
||||
"""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
extra_config = load_config.model_loader_extra_config
|
||||
|
||||
if ("concurrency" in extra_config
|
||||
and isinstance(extra_config.get("concurrency"), int)):
|
||||
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
|
||||
extra_config.get("concurrency"))
|
||||
|
||||
if ("memory_limit" in extra_config
|
||||
and isinstance(extra_config.get("memory_limit"), int)):
|
||||
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
|
||||
extra_config.get("memory_limit"))
|
||||
|
||||
runai_streamer_s3_endpoint = os.getenv(
|
||||
'RUNAI_STREAMER_S3_ENDPOINT')
|
||||
aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
|
||||
if (runai_streamer_s3_endpoint is None
|
||||
and aws_endpoint_url is not None):
|
||||
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]) -> List[str]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
is_s3_path = is_s3(model_name_or_path)
|
||||
if is_s3_path:
|
||||
try:
|
||||
from vllm.transformers_utils.s3_utils import glob as s3_glob
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install Run:ai optional dependency "
|
||||
"to use the S3 capabilities. "
|
||||
"You can install it with: pip install vllm[runai]"
|
||||
) from err
|
||||
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
safetensors_pattern = "*.safetensors"
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
hf_folder = (model_name_or_path if
|
||||
(is_local or is_s3_path) else download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
[safetensors_pattern],
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
))
|
||||
|
||||
if is_s3_path:
|
||||
hf_weights_files = s3_glob(path=hf_folder,
|
||||
allow_pattern=[safetensors_pattern])
|
||||
else:
|
||||
hf_weights_files = glob.glob(
|
||||
os.path.join(hf_folder, safetensors_pattern))
|
||||
|
||||
if not is_local and not is_s3_path:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path, index_file, self.load_config.download_dir,
|
||||
revision)
|
||||
|
||||
if not hf_weights_files:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any safetensors model weights with "
|
||||
f"`{model_name_or_path}`")
|
||||
|
||||
return hf_weights_files
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_or_path: str,
|
||||
revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_weights_files = self._prepare_weights(model_or_path, revision)
|
||||
return runai_safetensors_weights_iterator(hf_weights_files)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
"""Download model if necessary"""
|
||||
self._prepare_weights(model_config.model, model_config.revision)
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
||||
"""Perform streaming of the model to destination"""
|
||||
device_config = vllm_config.device_config
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = _initialize_model(vllm_config=vllm_config)
|
||||
|
||||
model_weights = model_config.model
|
||||
if hasattr(model_config, "model_weights"):
|
||||
model_weights = model_config.model_weights
|
||||
model.load_weights(
|
||||
self._get_weights_iterator(model_weights,
|
||||
model_config.revision))
|
||||
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
with device_loading_context(module, target_device):
|
||||
quant_method.process_weights_after_loading(module)
|
||||
return model.eval()
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
"""Get a model loader based on the load format."""
|
||||
|
||||
@@ -1255,4 +1368,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
if load_config.load_format == LoadFormat.GGUF:
|
||||
return GGUFModelLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
|
||||
return RunaiModelStreamerLoader(load_config)
|
||||
|
||||
return DefaultModelLoader(load_config)
|
||||
|
||||
@@ -410,6 +410,30 @@ def safetensors_weights_iterator(
|
||||
yield name, param
|
||||
|
||||
|
||||
def runai_safetensors_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
try:
|
||||
from runai_model_streamer import SafetensorsStreamer
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install Run:ai optional dependency."
|
||||
"You can install it with: pip install vllm[runai]") from err
|
||||
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
with SafetensorsStreamer() as streamer:
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading safetensors using Runai Model Streamer",
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
streamer.stream_file(st_file)
|
||||
yield from streamer.get_tensors()
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: List[str]
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
|
||||
Reference in New Issue
Block a user