[Bugfix] Update Run:AI Model Streamer Loading Integration (#23845)
Signed-off-by: Omer Dayan (SW-GPU) <omer@run.ai> Signed-off-by: Peter Schuurman <psch@google.com> Co-authored-by: Omer Dayan (SW-GPU) <omer@run.ai> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
@@ -15,8 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
runai_safetensors_weights_iterator)
|
||||
from vllm.transformers_utils.s3_utils import glob as s3_glob
|
||||
from vllm.transformers_utils.utils import is_s3
|
||||
from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
|
||||
list_safetensors)
|
||||
|
||||
|
||||
class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
@@ -53,27 +52,22 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
|
||||
is_s3_path = is_s3(model_name_or_path)
|
||||
is_object_storage_path = is_runai_obj_uri(model_name_or_path)
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
safetensors_pattern = "*.safetensors"
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
hf_folder = (model_name_or_path if
|
||||
(is_local or is_s3_path) else download_weights_from_hf(
|
||||
hf_folder = (model_name_or_path if (is_local or is_object_storage_path)
|
||||
else download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
[safetensors_pattern],
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
))
|
||||
if is_s3_path:
|
||||
hf_weights_files = s3_glob(path=hf_folder,
|
||||
allow_pattern=[safetensors_pattern])
|
||||
else:
|
||||
hf_weights_files = glob.glob(
|
||||
os.path.join(hf_folder, safetensors_pattern))
|
||||
hf_weights_files = list_safetensors(path=hf_folder)
|
||||
|
||||
if not is_local and not is_s3_path:
|
||||
if not is_local and not is_object_storage_path:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path, index_file, self.load_config.download_dir,
|
||||
revision)
|
||||
|
||||
Reference in New Issue
Block a user