Files
vllm/tests/runai_model_streamer_test/test_runai_utils.py

40 lines
1.3 KiB
Python
Raw Normal View History

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import glob
import os
import tempfile
import huggingface_hub.constants
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf)
from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
list_safetensors)
def test_is_runai_obj_uri():
assert is_runai_obj_uri("gs://some-gcs-bucket/path")
assert is_runai_obj_uri("s3://some-s3-bucket/path")
assert not is_runai_obj_uri("nfs://some-nfs-path")
def test_runai_list_safetensors_local():
with tempfile.TemporaryDirectory() as tmpdir:
huggingface_hub.constants.HF_HUB_OFFLINE = False
download_weights_from_hf("openai-community/gpt2",
allow_patterns=["*.safetensors", "*.json"],
cache_dir=tmpdir)
safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
assert len(safetensors) > 0
parentdir = [
os.path.dirname(safetensor) for safetensor in safetensors
][0]
files = list_safetensors(parentdir)
assert len(safetensors) == len(files)
if __name__ == "__main__":
test_is_runai_obj_uri()
test_runai_list_safetensors_local()