Validating Runai Model Streamer Integration with S3 Object Storage (#29320)
Signed-off-by: Noa Neria <noa@run.ai>
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.v1.executor import UniProcExecutor
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
|
||||
# This is a dummy executor for patching in test_runai_model_streamer_s3.py.
|
||||
# We cannot use vllm_runner fixture here, because it spawns worker process.
|
||||
# The worker process reimports the patched entities, and the patch is not applied.
|
||||
class RunaiDummyExecutor(UniProcExecutor):
    """Single-process executor for patching in test_runai_model_streamer_s3.py.

    The vllm_runner fixture cannot be used for that test because it spawns a
    worker process; the worker reimports the patched entities, so the
    monkeypatch is not applied there. Running the worker in-process keeps the
    patches visible.
    """

    def _init_executor(self) -> None:
        """Create and initialize a single in-process driver worker."""
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port()
        )

        local_rank = 0
        rank = 0
        is_driver_worker = True

        # Device string may carry an index (e.g. "cuda:1"); use it as the
        # local rank so the worker binds to the configured device.
        device_info = str(self.vllm_config.device_config.device).split(":")
        if len(device_info) > 1:
            local_rank = int(device_info[1])

        worker_rpc_kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )

        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config)

        # init_worker expects a list of per-worker kwargs; there is only one.
        self.collective_rpc("init_worker", args=([worker_rpc_kwargs],))
        self.collective_rpc("init_device")
|
||||
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from runai_model_streamer.safetensors_streamer.streamer_mock import StreamerPatcher
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
from .conftest import RunaiDummyExecutor
|
||||
|
||||
# Weight-loading backend under test, and the HF model whose weights are
# downloaded locally to populate the mocked S3 bucket.
load_format = "runai_streamer"
test_model = "openai-community/gpt2"
|
||||
|
||||
|
||||
def test_runai_model_loader_download_files_s3_mocked_with_patch(
    vllm_runner,
    tmp_path: Path,
    monkeypatch,
):
    """Load model weights through the runai streamer from a mocked S3 bucket.

    Real weights are fetched from HF into tmp_path, then the S3 listing,
    pulling, and streaming entry points are monkeypatched to serve from that
    local directory before the model is loaded in-process.
    """
    patcher = StreamerPatcher(str(tmp_path))

    test_mock_s3_model = "s3://my-mock-bucket/gpt2/"

    # Populate the mock bucket: download the model from HF into tmp_path.
    mock_model_dir = f"{tmp_path}/gpt2"
    snapshot_download(repo_id=test_model, local_dir=mock_model_dir)

    # Redirect the S3 helpers and the safetensors streamer to the mocks.
    patch_targets = (
        (
            "vllm.transformers_utils.runai_utils.runai_list_safetensors",
            patcher.shim_list_safetensors,
        ),
        (
            "vllm.transformers_utils.runai_utils.runai_pull_files",
            patcher.shim_pull_files,
        ),
        (
            "vllm.model_executor.model_loader.weight_utils.SafetensorsStreamer",
            patcher.create_mock_streamer,
        ),
    )
    for target, replacement in patch_targets:
        monkeypatch.setattr(target, replacement)

    engine_args = EngineArgs(
        model=test_mock_s3_model,
        load_format=load_format,
        tensor_parallel_size=1,
    )
    vllm_config = engine_args.create_engine_config()

    # Use the in-process dummy executor so the patches remain visible to the
    # worker, then trigger the actual weight load.
    executor = RunaiDummyExecutor(vllm_config)
    executor.driver_worker.load_model()
|
||||
Reference in New Issue
Block a user