diff --git a/examples/others/tensorize_vllm_model.py b/examples/others/tensorize_vllm_model.py
index 112332295..64a6c42ae 100644
--- a/examples/others/tensorize_vllm_model.py
+++ b/examples/others/tensorize_vllm_model.py
@@ -4,6 +4,7 @@
 import argparse
 import dataclasses
 import json
+import logging
 import os
 import uuid
 
@@ -15,9 +16,13 @@ from vllm.model_executor.model_loader.tensorizer import (
     TensorizerConfig,
     tensorize_lora_adapter,
     tensorize_vllm_model,
+    tensorizer_kwargs_arg,
 )
 from vllm.utils import FlexibleArgumentParser
 
+logger = logging.getLogger()
+
+
 # yapf conflicts with isort for this docstring
 # yapf: disable
 """
@@ -119,7 +124,7 @@ vllm serve \
 """
 
 
-def parse_args():
+def get_parser():
     parser = FlexibleArgumentParser(
         description="An example script that can be used to serialize and "
         "deserialize vLLM models. These models "
@@ -135,13 +140,13 @@ def parse_args():
         required=False,
         help="Path to a LoRA adapter to "
         "serialize along with model tensors. This can then be deserialized "
-        "along with the model by passing a tensorizer_config kwarg to "
-        "LoRARequest with type TensorizerConfig. See the docstring for this "
-        "for a usage example."
-
+        "along with the model by instantiating a TensorizerConfig object, "
+        "creating a dict from it with TensorizerConfig.to_serializable(), "
+        "and passing it to LoRARequest's initializer with the kwarg "
+        "tensorizer_config_dict."
     )
 
-    subparsers = parser.add_subparsers(dest='command')
+    subparsers = parser.add_subparsers(dest='command', required=True)
 
     serialize_parser = subparsers.add_parser(
         'serialize', help="Serialize a model to `--serialized-directory`")
@@ -171,6 +176,14 @@ def parse_args():
         "where `suffix` is given by `--suffix` or a random UUID if not "
         "provided.")
 
+    serialize_parser.add_argument(
+        "--serialization-kwargs",
+        type=tensorizer_kwargs_arg,
+        required=False,
+        help=("A JSON string containing additional keyword arguments to "
+              "pass to Tensorizer's TensorSerializer during "
+              "serialization."))
+
     serialize_parser.add_argument(
         "--keyfile",
         type=str,
@@ -186,9 +199,17 @@ def parse_args():
     deserialize_parser.add_argument(
         "--path-to-tensors",
         type=str,
-        required=True,
+        required=False,
         help="The local path or S3 URI to the model tensors to deserialize. ")
 
+    deserialize_parser.add_argument(
+        "--serialized-directory",
+        type=str,
+        required=False,
+        help="Directory with model artifacts for loading. Assumes a "
+        "model.tensors file exists therein. Can supersede "
+        "--path-to-tensors.")
+
     deserialize_parser.add_argument(
         "--keyfile",
         type=str,
@@ -196,11 +217,27 @@ def parse_args():
         help=("Path to a binary key to use to decrypt the model weights,"
               " if the model was serialized with encryption"))
 
+    deserialize_parser.add_argument(
+        "--deserialization-kwargs",
+        type=tensorizer_kwargs_arg,
+        required=False,
+        help=("A JSON string containing additional keyword arguments to "
+              "pass to Tensorizer's `TensorDeserializer` during "
+              "deserialization."))
+
     TensorizerArgs.add_cli_args(deserialize_parser)
 
-    return parser.parse_args()
-
+    return parser
+def merge_extra_config_with_tensorizer_config(extra_cfg: dict,
+                                              cfg: TensorizerConfig):
+    for k, v in extra_cfg.items():
+        if hasattr(cfg, k):
+            setattr(cfg, k, v)
+            logger.info(
+                "Updating TensorizerConfig with %s from "
+                "--model-loader-extra-config", k
+            )
 
 
 def deserialize(args, tensorizer_config):
     if args.lora_path:
@@ -230,7 +267,8 @@ def deserialize(args, tensorizer_config):
                 lora_request=LoRARequest("sql-lora",
                                          1,
                                          args.lora_path,
-                                         tensorizer_config = tensorizer_config)
+                                         tensorizer_config_dict=tensorizer_config
+                                         .to_serializable())
             )
         )
     else:
@@ -243,7 +281,8 @@ def deserialize(args, tensorizer_config):
 
 
 def main():
-    args = parse_args()
+    parser = get_parser()
+    args = parser.parse_args()
 
     s3_access_key_id = (getattr(args, 's3_access_key_id', None)
                         or os.environ.get("S3_ACCESS_KEY_ID", None))
@@ -265,13 +304,24 @@ def main():
     else:
         keyfile = None
 
+    extra_config = {}
     if args.model_loader_extra_config:
-        config = json.loads(args.model_loader_extra_config)
-        tensorizer_args = \
-            TensorizerConfig(**config)._construct_tensorizer_args()
-        tensorizer_args.tensorizer_uri = args.path_to_tensors
-    else:
-        tensorizer_args = None
+        extra_config = json.loads(args.model_loader_extra_config)
+
+
+    tensorizer_dir = (args.serialized_directory or
+                      extra_config.get("tensorizer_dir"))
+    tensorizer_uri = (getattr(args, "path_to_tensors", None)
+                      or extra_config.get("tensorizer_uri"))
+
+    if tensorizer_dir and tensorizer_uri:
+        parser.error("--serialized-directory and --path-to-tensors "
+                     "cannot both be provided")
+
+    if not tensorizer_dir and not tensorizer_uri:
+        parser.error("Either --serialized-directory or --path-to-tensors "
+                     "must be provided")
+
     if args.command == "serialize":
         eng_args_dict = {f.name: getattr(args, f.name) for f in
@@ -281,7 +331,7 @@ def main():
             argparse.Namespace(**eng_args_dict)
         )
 
-        input_dir = args.serialized_directory.rstrip('/')
+        input_dir = tensorizer_dir.rstrip('/')
         suffix = args.suffix if args.suffix else uuid.uuid4().hex
         base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
         if engine_args.tensor_parallel_size > 1:
@@ -292,21 +342,29 @@ def main():
         tensorizer_config = TensorizerConfig(
             tensorizer_uri=model_path,
             encryption_keyfile=keyfile,
-            **credentials)
+            serialization_kwargs=args.serialization_kwargs or {},
+            **credentials
+        )
 
         if args.lora_path:
             tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
             tensorize_lora_adapter(args.lora_path, tensorizer_config)
 
+        merge_extra_config_with_tensorizer_config(extra_config,
+                                                  tensorizer_config)
         tensorize_vllm_model(engine_args, tensorizer_config)
 
     elif args.command == "deserialize":
-        if not tensorizer_args:
-            tensorizer_config = TensorizerConfig(
-                tensorizer_uri=args.path_to_tensors,
-                encryption_keyfile = keyfile,
-                **credentials
-            )
+        tensorizer_config = TensorizerConfig(
+            tensorizer_uri=args.path_to_tensors,
+            tensorizer_dir=args.serialized_directory,
+            encryption_keyfile=keyfile,
+            deserialization_kwargs=args.deserialization_kwargs or {},
+            **credentials
+        )
+
+        merge_extra_config_with_tensorizer_config(extra_config,
+                                                  tensorizer_config)
         deserialize(args, tensorizer_config)
     else:
         raise ValueError("Either serialize or deserialize must be specified.")
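For readers following the updated `--lora-path` help text and the `deserialize()` code above, here is a minimal sketch of the new deserialization flow, assuming a vLLM build with this patch applied; the model tag, adapter path, and S3 URI are illustrative placeholders, not values from this PR.

```python
# Sketch: load a tensorized model plus a co-serialized LoRA adapter.
from vllm import LLM
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

tensorizer_config = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/vllm/facebook/opt-125m/suffix/model.tensors")

llm = LLM(
    "facebook/opt-125m",
    load_format="tensorizer",
    model_loader_extra_config=tensorizer_config.to_serializable(),
    enable_lora=True,
)

# The adapter is now referenced through the tensorizer_config_dict kwarg
# (a plain dict from to_serializable()), replacing the old
# tensorizer_config kwarg that took a TensorizerConfig object.
output = llm.generate(
    "Hello, my name is",
    lora_request=LoRARequest(
        "sql-lora", 1, "/path/to/lora",
        tensorizer_config_dict=tensorizer_config.to_serializable()),
)
```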
diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt
index 0bade084f..d8bd031f1 100644
--- a/requirements/nightly_torch_test.txt
+++ b/requirements/nightly_torch_test.txt
@@ -1,6 +1,6 @@
 # testing
 pytest
-tensorizer>=2.9.0
+tensorizer==2.10.1
 pytest-forked
 pytest-asyncio
 pytest-rerunfailures
diff --git a/requirements/rocm.txt b/requirements/rocm.txt
index d33021fc7..988329c3a 100644
--- a/requirements/rocm.txt
+++ b/requirements/rocm.txt
@@ -11,7 +11,7 @@ datasets
 ray>=2.10.0,<2.45.0
 peft
 pytest-asyncio
-tensorizer>=2.9.0
+tensorizer==2.10.1
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
 setuptools-scm>=8
diff --git a/requirements/test.in b/requirements/test.in
index 5f8b97a0e..907d90201 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -1,6 +1,6 @@
 # testing
 pytest
-tensorizer>=2.9.0
+tensorizer==2.10.1
 pytest-forked
 pytest-asyncio
 pytest-rerunfailures
diff --git a/requirements/test.txt b/requirements/test.txt
index f6f599df7..2f3ccc4f6 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -739,7 +739,7 @@ tenacity==9.0.0
     # via
     #   lm-eval
     #   plotly
-tensorizer==2.9.0
+tensorizer==2.10.1
     # via -r requirements/test.in
 threadpoolctl==3.5.0
     # via scikit-learn
diff --git a/setup.py b/setup.py
index ea7cd0169..9200c6cef 100644
--- a/setup.py
+++ b/setup.py
@@ -689,7 +689,7 @@ setup(
     install_requires=get_requirements(),
     extras_require={
         "bench": ["pandas", "datasets"],
-        "tensorizer": ["tensorizer>=2.9.0"],
+        "tensorizer": ["tensorizer==2.10.1"],
         "fastsafetensors": ["fastsafetensors >= 0.1.10"],
         "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
         "audio": ["librosa", "soundfile"],  # Required for audio processing
diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py
index e14315035..4bf379850 100644
--- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py
+++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import gc
-import json
+import os
 import tempfile
 
 import openai
@@ -58,18 +58,20 @@ def tensorize_model_and_lora(tmp_dir, model_uri):
 
 @pytest.fixture(scope="module")
 def server(model_uri, tensorize_model_and_lora):
-    model_loader_extra_config = {
-        "tensorizer_uri": model_uri,
-    }
+    # Here, model_uri is a directory containing a model.tensors
+    # file and all necessary model artifacts, in particular a
+    # HF `config.json` file. Tensorizer can therefore infer the
+    # `TensorizerConfig`, so --model-loader-extra-config can be
+    # omitted entirely.
    ## Start OpenAI API server
     args = [
-        "--load-format", "tensorizer", "--device", "cuda",
-        "--model-loader-extra-config",
-        json.dumps(model_loader_extra_config), "--enable-lora"
+        "--load-format", "tensorizer", "--served-model-name", MODEL_NAME,
+        "--enable-lora"
     ]
 
-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+    model_dir = os.path.dirname(model_uri)
+    with RemoteOpenAIServer(model_dir, args) as remote_server:
         yield remote_server
diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py
index 3ac3b80ec..91afa42fa 100644
--- a/tests/lora/test_llama_tp.py
+++ b/tests/lora/test_llama_tp.py
@@ -169,7 +169,8 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
         f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
         MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size",
         str(tp_size), "serialize", "--serialized-directory",
-        str(tmp_path), "--suffix", suffix
+        str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
+        '{"limit_cpu_concurrency": 4}'
     ],
                            check=True,
                            capture_output=True,
@@ -195,7 +196,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
                                    tensor_parallel_size=2,
                                    max_loras=2)
 
-    tensorizer_config_dict = tensorizer_config.to_dict()
+    tensorizer_config_dict = tensorizer_config.to_serializable()
 
     print("lora adapter created")
     assert do_sample(loaded_vllm_model,
diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py
index cd59d579e..18aa4c88c 100644
--- a/tests/tensorizer_loader/conftest.py
+++ b/tests/tensorizer_loader/conftest.py
@@ -1,9 +1,28 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Callable
+
 import pytest
 
+from vllm import LLM, EngineArgs
 from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
+from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.v1.executor.abstract import UniProcExecutor
+from vllm.worker.worker_base import WorkerWrapperBase
+
+MODEL_REF = "facebook/opt-125m"
+
+
+@pytest.fixture()
+def model_ref():
+    return MODEL_REF
+
+
+@pytest.fixture(autouse=True)
+def allow_insecure_serialization(monkeypatch):
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
 
 
 @pytest.fixture(autouse=True)
@@ -11,7 +30,73 @@ def cleanup():
     cleanup_dist_env_and_memory(shutdown_ray=True)
 
 
+@pytest.fixture()
+def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path):
+
+    def noop(*args, **kwargs):
+        return None
+
+    args = EngineArgs(model=model_ref)
+    tc = TensorizerConfig(tensorizer_uri=f"{tmp_path}/model.tensors")
+
+    monkeypatch.setattr(tensorizer_mod, "serialize_extra_artifacts", noop)
+
+    tensorizer_mod.tensorize_vllm_model(args, tc)
+    yield tmp_path
+
+
 @pytest.fixture(autouse=True)
 def tensorizer_config():
     config = TensorizerConfig(tensorizer_uri="vllm")
     return config
+
+
+@pytest.fixture()
+def model_path(model_ref, tmp_path):
+    yield tmp_path / model_ref / "model.tensors"
+
+
+def assert_from_collective_rpc(engine: LLM, closure: Callable,
+                               closure_kwargs: dict):
+    res = engine.collective_rpc(method=closure, kwargs=closure_kwargs)
+    return all(res)
+
+
+# This is an object pulled from tests/v1/engine/test_engine_core.py,
+# modified to strip the `load_model` method from its `_init_executor`
+# method.
+# It's purely used as a dummy utility to run methods that test
+# Tensorizer functionality.
+class DummyExecutor(UniProcExecutor):
+
+    def _init_executor(self) -> None:
+        """Initialize the worker, without loading the model.
+        """
+        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
+                                               rpc_rank=0)
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+        local_rank = 0
+        # set local rank as the device index if specified
+        device_info = self.vllm_config.device_config.device.__str__().split(
+            ":")
+        if len(device_info) > 1:
+            local_rank = int(device_info[1])
+        rank = 0
+        is_driver_worker = True
+        kwargs = dict(
+            vllm_config=self.vllm_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            is_driver_worker=is_driver_worker,
+        )
+        self.collective_rpc("init_worker", args=([kwargs], ))
+        self.collective_rpc("init_device")
+
+    @property
+    def max_concurrent_batches(self) -> int:
+        return 2
+
+    def shutdown(self):
+        if hasattr(self, 'thread_pool'):
+            self.thread_pool.shutdown(wait=False)
diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py
index c97f5968d..9fe230512 100644
--- a/tests/tensorizer_loader/test_tensorizer.py
+++ b/tests/tensorizer_loader/test_tensorizer.py
@@ -1,36 +1,51 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
 import gc
+import json
 import os
 import pathlib
 import subprocess
+import sys
+from typing import Any
 
 import pytest
 import torch
 
+import vllm.model_executor.model_loader.tensorizer
-from vllm import SamplingParams
+from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
-# yapf conflicts with isort for this docstring
 # yapf: disable
 from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                          TensorSerializer,
                                                          is_vllm_tensorized,
                                                          open_stream,
                                                          tensorize_vllm_model)
+from vllm.model_executor.model_loader.tensorizer_loader import (
+    BLACKLISTED_TENSORIZER_ARGS)
 # yapf: enable
 from vllm.utils import PlaceholderModule
 
-from ..utils import VLLM_PATH
+from ..utils import VLLM_PATH, RemoteOpenAIServer
+from .conftest import DummyExecutor, assert_from_collective_rpc
 
 try:
+    import tensorizer
     from tensorizer import EncryptionParams
 except ImportError:
     tensorizer = PlaceholderModule("tensorizer")  # type: ignore[assignment]
     EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
 
+
+class TensorizerCaughtError(Exception):
+    pass
+
+
 EXAMPLES_PATH = VLLM_PATH / "examples"
 
+pytest_plugins = "pytest_asyncio",
+
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -40,9 +55,37 @@ prompts = [
 
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
 
-model_ref = "facebook/opt-125m"
-tensorize_model_for_testing_script = os.path.join(
-    os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
+
+def patch_init_and_catch_error(self, obj, method_name,
+                               expected_error: type[Exception]):
+    original = getattr(obj, method_name, None)
+    if original is None:
+        raise ValueError("Method '{}' not found.".format(method_name))
+
+    def wrapper(*args, **kwargs):
+        try:
+            return original(*args, **kwargs)
+        except expected_error as err:
+            raise TensorizerCaughtError from err
+
+    setattr(obj, method_name, wrapper)
+
+    self.load_model()
+
+
+def assert_specific_tensorizer_error_is_raised(
+    executor,
+    obj: Any,
+    method_name: str,
+    expected_error: type[Exception],
+):
+    with pytest.raises(TensorizerCaughtError):
+        executor.collective_rpc(patch_init_and_catch_error,
+                                args=(
+                                    obj,
+                                    method_name,
+                                    expected_error,
+                                ))
 
 
 def is_curl_installed():
@@ -81,11 +124,10 @@ def test_can_deserialize_s3(vllm_runner):
 
 @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
 def test_deserialized_encrypted_vllm_model_has_same_outputs(
-        vllm_runner, tmp_path):
+        model_ref, vllm_runner, tmp_path, model_path):
     args = EngineArgs(model=model_ref)
     with vllm_runner(model_ref) as vllm_model:
-        model_path = tmp_path / (model_ref + ".tensors")
-        key_path = tmp_path / (model_ref + ".key")
+        key_path = tmp_path / model_ref / "model.key"
         write_keyfile(key_path)
 
         outputs = vllm_model.generate(prompts, sampling_params)
@@ -111,9 +153,9 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
 
 def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
-                                                tmp_path):
+                                                tmp_path, model_ref,
+                                                model_path):
     with hf_runner(model_ref) as hf_model:
-        model_path = tmp_path / (model_ref + ".tensors")
         max_tokens = 50
         outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
         with open_stream(model_path, "wb+") as stream:
@@ -123,7 +165,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
     with vllm_runner(model_ref,
                      load_format="tensorizer",
                      model_loader_extra_config=TensorizerConfig(
-                         tensorizer_uri=model_path,
+                         tensorizer_uri=str(model_path),
                          num_readers=1,
                      )) as loaded_hf_model:
         deserialized_outputs = loaded_hf_model.generate_greedy(
@@ -132,7 +174,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
     assert outputs == deserialized_outputs
 
 
-def test_load_without_tensorizer_load_format(vllm_runner, capfd):
+def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref):
     model = None
     try:
         model = vllm_runner(
@@ -150,7 +192,8 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd):
         torch.cuda.empty_cache()
 
 
-def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd):
+def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd,
+                                                  model_ref):
     model = None
     try:
         model = vllm_runner(
@@ -208,7 +251,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
         outputs = base_model.generate(prompts, sampling_params)
 
     # load model with two shards and serialize with encryption
-    model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
+    model_path = str(tmp_path / model_ref / "model-%02d.tensors")
     key_path = tmp_path / (model_ref + ".key")
 
     tensorizer_config = TensorizerConfig(
@@ -242,13 +285,12 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(
 
 
 @pytest.mark.flaky(reruns=3)
-def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
+def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner,
+                                                tmp_path, model_path):
     gc.collect()
     torch.cuda.empty_cache()
-    model_ref = "facebook/opt-125m"
-    model_path = tmp_path / (model_ref + ".tensors")
     config = TensorizerConfig(tensorizer_uri=str(model_path))
-    args = EngineArgs(model=model_ref, device="cuda")
+    args = EngineArgs(model=model_ref)
 
     with vllm_runner(model_ref) as vllm_model:
         outputs = vllm_model.generate(prompts, sampling_params)
@@ -264,3 +306,243 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):  # noqa: E501
 
     assert outputs == deserialized_outputs
+
+
+def test_load_with_just_model_tensors(just_serialize_model_tensors, model_ref):
+    # For backwards compatibility, ensure a tensorized model can still be
+    # loaded for inference by passing the model reference name, not a
+    # local/S3 dir, along with the location of the model tensors
+
+    model_dir = just_serialize_model_tensors
+
+    extra_config = {"tensorizer_uri": f"{model_dir}/model.tensors"}
+
+    ## Start OpenAI API server
+    args = [
+        "--load-format",
+        "tensorizer",
+        "--model-loader-extra-config",
+        json.dumps(extra_config),
+    ]
+
+    with RemoteOpenAIServer(model_ref, args):
+        # This test only concerns itself with being able to load the model
+        # and successfully initialize the server
+        pass
+
+
+def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path):
+
+    serialization_params = {
+        "limit_cpu_concurrency": 2,
+    }
+    model_ref = "facebook/opt-125m"
+    model_path = tmp_path / (model_ref + ".tensors")
+    config = TensorizerConfig(tensorizer_uri=str(model_path),
+                              serialization_kwargs=serialization_params)
+    llm = LLM(model=model_ref)
+
+    def serialization_test(self, *args, **kwargs):
+        # This runs in the ephemeral worker process, so monkey-patching
+        # will actually work, and cleanup is guaranteed, so we don't
+        # need to reset anything afterwards
+
+        original_dict = serialization_params
+        to_compare = {}
+
+        original = tensorizer.serialization.TensorSerializer.__init__
+
+        def tensorizer_serializer_wrapper(self, *args, **kwargs):
+            nonlocal to_compare
+            to_compare = kwargs.copy()
+            return original(self, *args, **kwargs)
+
+        tensorizer.serialization.TensorSerializer.__init__ = (
+            tensorizer_serializer_wrapper)
+
+        tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"])
+        self.save_tensorized_model(tensorizer_config=tensorizer_config)
+        return to_compare | original_dict == to_compare
+
+    kwargs = {"tensorizer_config": config.to_serializable()}
+
+    assert assert_from_collective_rpc(llm, serialization_test, kwargs)
+
+
+def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(
+        tmp_path, capfd):
+
+    deserialization_kwargs = {
+        "num_readers": "bar",  # illegal value
+    }
+
+    serialization_params = {
+        "limit_cpu_concurrency": 2,
+    }
+
+    model_ref = "facebook/opt-125m"
+    model_path = tmp_path / (model_ref + ".tensors")
+    config = TensorizerConfig(tensorizer_uri=str(model_path),
+                              serialization_kwargs=serialization_params)
+
+    args = EngineArgs(model=model_ref)
+    tensorize_vllm_model(args, config)
+
+    loader_tc = TensorizerConfig(
+        tensorizer_uri=str(model_path),
+        deserialization_kwargs=deserialization_kwargs,
+    )
+
+    engine_args = EngineArgs(
+        model="facebook/opt-125m",
+        load_format="tensorizer",
+        model_loader_extra_config=loader_tc.to_serializable(),
+    )
+
+    vllm_config = engine_args.create_engine_config()
+    executor = DummyExecutor(vllm_config)
+
+    assert_specific_tensorizer_error_is_raised(
+        executor,
+        tensorizer.serialization.TensorDeserializer,
+        "__init__",
+        TypeError,
+    )
+
+
+def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd):
+
+    deserialization_kwargs = {
+        "num_readers": 1,
+    }
+
+    serialization_params = {
+        "limit_cpu_concurrency": 2,
+    }
+
+    model_ref = "facebook/opt-125m"
+    model_path = tmp_path / (model_ref + ".tensors")
+    config = TensorizerConfig(tensorizer_uri=str(model_path),
+                              serialization_kwargs=serialization_params)
+
+    args = EngineArgs(model=model_ref)
+    tensorize_vllm_model(args, config)
+
+    stream_kwargs = {"mode": "foo"}
+
+    loader_tc = TensorizerConfig(
+        tensorizer_uri=str(model_path),
+        deserialization_kwargs=deserialization_kwargs,
+        stream_kwargs=stream_kwargs,
+    )
+
+    engine_args = EngineArgs(
+        model="facebook/opt-125m",
+        load_format="tensorizer",
+        model_loader_extra_config=loader_tc.to_serializable(),
+    )
+
+    vllm_config = engine_args.create_engine_config()
+    executor = DummyExecutor(vllm_config)
+
+    assert_specific_tensorizer_error_is_raised(
+        executor,
+        vllm.model_executor.model_loader.tensorizer,
+        "open_stream",
+        ValueError,
+    )
+
+
+@pytest.mark.asyncio
+async def test_serialize_and_serve_entrypoints(tmp_path):
+    model_ref = "facebook/opt-125m"
+
+    suffix = "test"
+    try:
+        result = subprocess.run([
+            sys.executable,
+            f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model",
+            model_ref, "serialize", "--serialized-directory",
+            str(tmp_path), "--suffix", suffix, "--serialization-kwargs",
+            '{"limit_cpu_concurrency": 4}'
+        ],
+                                check=True,
+                                capture_output=True,
+                                text=True)
+    except subprocess.CalledProcessError as e:
+        print("Tensorizing failed.")
+        print("STDOUT:\n", e.stdout)
+        print("STDERR:\n", e.stderr)
+        raise
+
+    assert "Successfully serialized" in result.stdout
+
+    # Next, try to serve with vllm serve
+    model_uri = tmp_path / "vllm" / model_ref / suffix / "model.tensors"
+
+    model_loader_extra_config = {
+        "tensorizer_uri": str(model_uri),
+        "stream_kwargs": {
+            "force_http": False,
+        },
+        "deserialization_kwargs": {
+            "verify_hash": True,
+            "num_readers": 8,
+        }
+    }
+
+    cmd = [
+        "-m", "vllm.entrypoints.cli.main", "serve", "--host", "localhost",
+        "--load-format", "tensorizer", model_ref,
+        "--model-loader-extra-config",
+        json.dumps(model_loader_extra_config, indent=2)
+    ]
+
+    proc = await asyncio.create_subprocess_exec(
+        sys.executable,
+        *cmd,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.STDOUT,
+    )
+
+    assert proc.stdout is not None
+    fut = proc.stdout.readuntil(b"Application startup complete.")
+
+    try:
+        await asyncio.wait_for(fut, 180)
+    except asyncio.TimeoutError:
+        pytest.fail("Server did not start successfully")
+    finally:
+        proc.terminate()
+        await proc.communicate()
+
+
+@pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS)
+def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd,
+                                           illegal_value):
+
+    serialization_params = {
+        "limit_cpu_concurrency": 2,
+    }
+
+    model_ref = "facebook/opt-125m"
+    model_path = tmp_path / (model_ref + ".tensors")
+    config = TensorizerConfig(tensorizer_uri=str(model_path),
+                              serialization_kwargs=serialization_params)
+
+    args = EngineArgs(model=model_ref)
+    tensorize_vllm_model(args, config)
+
+    loader_tc = {"tensorizer_uri": str(model_path), illegal_value: "foo"}
+
+    try:
+        vllm_runner(
+            model_ref,
+            load_format="tensorizer",
+            model_loader_extra_config=loader_tc,
+        )
+    except RuntimeError:
+        out, err = capfd.readouterr()
+        combined_output = out + err
+        assert (f"ValueError: {illegal_value} is not an allowed "
+                f"Tensorizer argument.") in combined_output
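The last test above exercises the new argument blacklist end to end. Here is a hedged sketch of the behavior it asserts, with illustrative paths; depending on whether the engine core runs in-process, the worker-side ValueError may surface directly or wrapped in a RuntimeError, so both are caught.

```python
# Sketch: a blacklisted Tensorizer argument fails the load early.
from vllm import LLM

try:
    LLM(
        "facebook/opt-125m",
        load_format="tensorizer",
        model_loader_extra_config={
            "tensorizer_uri": "/tmp/model.tensors",  # illustrative path
            "device": "cuda:0",  # blacklisted: vLLM decides the device
        },
    )
except (ValueError, RuntimeError):
    # validate_config() raises ValueError("device is not an allowed
    # Tensorizer argument."); multiprocess engines re-raise it as a
    # RuntimeError during engine construction.
    pass
```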
diff --git a/vllm/config.py b/vllm/config.py
index bac18e817..90cf885a4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -686,8 +686,11 @@ class ModelConfig:
 
             # If tokenizer is same as model, download to same directory
             if model == tokenizer:
-                s3_model.pull_files(
-                    model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+                s3_model.pull_files(model,
+                                    ignore_pattern=[
+                                        "*.pt", "*.safetensors", "*.bin",
+                                        "*.tensors"
+                                    ])
                 self.tokenizer = s3_model.dir
                 return
 
@@ -695,7 +698,8 @@ class ModelConfig:
         if is_s3(tokenizer):
             s3_tokenizer = S3Model()
             s3_tokenizer.pull_files(
-                model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+                model,
+                ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"])
             self.tokenizer = s3_tokenizer.dir
 
     def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index a497e3c8e..0c4fae1dd 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -58,7 +58,8 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
 
     def _parse_type(val: str) -> T:
         try:
-            if return_type is json.loads and not re.match("^{.*}$", val):
+            if return_type is json.loads and not re.match(
+                    r"(?s)^\s*{.*}\s*$", val):
                 return cast(T, nullable_kvs(val))
             return return_type(val)
         except ValueError as e:
@@ -80,7 +81,7 @@ def optional_type(
 
 def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
-    if not re.match("^{.*}$", val):
+    if not re.match(r"(?s)^\s*{.*}\s*$", val):
         return str(val)
     return optional_type(json.loads)(val)
 
@@ -1001,11 +1002,42 @@ class EngineArgs:
             override_attention_dtype=self.override_attention_dtype,
         )
 
+    def valid_tensorizer_config_provided(self) -> bool:
+        """
+        Checks if a parseable TensorizerConfig was passed to
+        self.model_loader_extra_config. It first checks if the config passed
+        is a dict or a TensorizerConfig object directly, and if the latter is
+        true (by checking that the object has TensorizerConfig's
+        .to_serializable() method), converts it into a serializable dict
+        format
+        """
+        if self.model_loader_extra_config:
+            if hasattr(self.model_loader_extra_config, "to_serializable"):
+                self.model_loader_extra_config = (
+                    self.model_loader_extra_config.to_serializable())
+            for allowed_to_pass in ["tensorizer_uri", "tensorizer_dir"]:
+                try:
+                    self.model_loader_extra_config[allowed_to_pass]
+                    return False
+                except KeyError:
+                    pass
+        return True
+
     def create_load_config(self) -> LoadConfig:
 
         if self.quantization == "bitsandbytes":
             self.load_format = "bitsandbytes"
 
+        if (self.load_format == "tensorizer"
+                and self.valid_tensorizer_config_provided()):
+            logger.info("Inferring Tensorizer args from %s", self.model)
+            self.model_loader_extra_config = {"tensorizer_dir": self.model}
+        elif self.load_format == "tensorizer":
+            logger.info(
+                "Using Tensorizer args from --model-loader-extra-config. "
" + "Note that you can now simply pass the S3 directory in the " + "model tag instead of providing the JSON string.") + return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 9e1ed3a77..bff4e9125 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -245,9 +245,10 @@ class LoRAModel(AdapterModel): lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir, "adapter_model.tensors") tensorizer_args = tensorizer_config._construct_tensorizer_args() - tensors = TensorDeserializer(lora_tensor_path, - dtype=tensorizer_config.dtype, - **tensorizer_args.deserializer_params) + tensors = TensorDeserializer( + lora_tensor_path, + dtype=tensorizer_config.dtype, + **tensorizer_args.deserialization_kwargs) check_unexpected_modules(tensors) elif os.path.isfile(lora_tensor_path): diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index a20d73f0f..e748a4a88 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -106,7 +106,7 @@ class PEFTHelper: "adapter_config.json") with open_stream(lora_config_path, mode="rb", - **tensorizer_args.stream_params) as f: + **tensorizer_args.stream_kwargs) as f: config = json.load(f) logger.info("Successfully deserialized LoRA config from %s", diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 1c14d55fc..ff101b664 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -5,18 +5,18 @@ import argparse import contextlib import contextvars import dataclasses -import io import json import os +import tempfile import threading import time -from collections.abc import Generator -from dataclasses import dataclass -from functools import partial -from typing import TYPE_CHECKING, Any, BinaryIO, Optional, Union +from collections.abc import Generator, MutableMapping +from dataclasses import asdict, dataclass, field, fields +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union import regex as re import torch +from huggingface_hub import snapshot_download from torch import nn from torch.utils._python_dispatch import TorchDispatchMode from transformers import PretrainedConfig @@ -39,10 +39,6 @@ try: from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) - _read_stream, _write_stream = (partial( - open_stream, - mode=mode, - ) for mode in ("rb", "wb+")) except ImportError: tensorizer = PlaceholderModule("tensorizer") DecryptionParams = tensorizer.placeholder_attr("DecryptionParams") @@ -54,9 +50,6 @@ except ImportError: get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage") no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor") - _read_stream = tensorizer.placeholder_attr("_read_stream") - _write_stream = tensorizer.placeholder_attr("_write_stream") - __all__ = [ 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', @@ -66,6 +59,23 @@ __all__ = [ logger = init_logger(__name__) +def is_valid_deserialization_uri(uri: Optional[str]) -> bool: + if uri: + scheme = uri.lower().split("://")[0] + return scheme in {"s3", "http", "https"} or os.path.exists(uri) + return False + + +def tensorizer_kwargs_arg(value): + loaded = json.loads(value) + if not isinstance(loaded, dict): + raise argparse.ArgumentTypeError( + f"Not deserializable to dict: {value}. 
+            f"deserialization_kwargs must be "
+            f"deserializable from a JSON string to a dictionary.")
+    return loaded
+
+
 class MetaTensorMode(TorchDispatchMode):
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
@@ -137,101 +147,45 @@ class _NoInitOrTensorImpl:
 
 @dataclass
-class TensorizerConfig:
-    tensorizer_uri: Union[str, None] = None
-    vllm_tensorized: Optional[bool] = False
-    verify_hash: Optional[bool] = False
+class TensorizerConfig(MutableMapping):
+    tensorizer_uri: Optional[str] = None
+    tensorizer_dir: Optional[str] = None
+    vllm_tensorized: Optional[bool] = None
+    verify_hash: Optional[bool] = None
     num_readers: Optional[int] = None
     encryption_keyfile: Optional[str] = None
     s3_access_key_id: Optional[str] = None
     s3_secret_access_key: Optional[str] = None
     s3_endpoint: Optional[str] = None
-    model_class: Optional[type[torch.nn.Module]] = None
-    hf_config: Optional[PretrainedConfig] = None
-    dtype: Optional[Union[str, torch.dtype]] = None
     lora_dir: Optional[str] = None
-    _is_sharded: bool = False
-
-    def __post_init__(self):
-        # check if the configuration is for a sharded vLLM model
-        self._is_sharded = isinstance(self.tensorizer_uri, str) \
-            and re.search(r'%0\dd', self.tensorizer_uri) is not None
-        if not self.tensorizer_uri and not self.lora_dir:
-            raise ValueError("tensorizer_uri must be provided.")
-        if not self.tensorizer_uri and self.lora_dir:
-            self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
-        assert self.tensorizer_uri is not None, ("tensorizer_uri must be "
-                                                 "provided.")
-        self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
-        self.lora_dir = self.tensorizer_dir
-
-    @classmethod
-    def as_dict(cls, *args, **kwargs) -> dict[str, Any]:
-        cfg = TensorizerConfig(*args, **kwargs)
-        return dataclasses.asdict(cfg)
-
-    def to_dict(self) -> dict[str, Any]:
-        return dataclasses.asdict(self)
-
-    def _construct_tensorizer_args(self) -> "TensorizerArgs":
-        tensorizer_args = {
-            "tensorizer_uri": self.tensorizer_uri,
-            "vllm_tensorized": self.vllm_tensorized,
-            "verify_hash": self.verify_hash,
-            "num_readers": self.num_readers,
-            "encryption_keyfile": self.encryption_keyfile,
-            "s3_access_key_id": self.s3_access_key_id,
-            "s3_secret_access_key": self.s3_secret_access_key,
-            "s3_endpoint": self.s3_endpoint,
-        }
-        return TensorizerArgs(**tensorizer_args)  # type: ignore
-
-    def verify_with_parallel_config(
-        self,
-        parallel_config: "ParallelConfig",
-    ) -> None:
-        if parallel_config.tensor_parallel_size > 1 \
-            and not self._is_sharded:
-            raise ValueError(
-                "For a sharded model, tensorizer_uri should include a"
-                " string format template like '%04d' to be formatted"
-                " with the rank of the shard")
-
-    def verify_with_model_config(self, model_config: "ModelConfig") -> None:
-        if (model_config.quantization is not None
-                and self.tensorizer_uri is not None):
-            logger.warning(
-                "Loading a model using Tensorizer with quantization on vLLM"
-                " is unstable and may lead to errors.")
-
-    def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
-        if tensorizer_args is None:
-            tensorizer_args = self._construct_tensorizer_args()
-
-        return open_stream(self.tensorizer_uri,
-                           **tensorizer_args.stream_params)
-
-
-@dataclass
-class TensorizerArgs:
-    tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
-                          bytes, os.PathLike, int]
-    vllm_tensorized: Optional[bool] = False
-    verify_hash: Optional[bool] = False
-    num_readers: Optional[int] = None
-    encryption_keyfile: Optional[str] = None
-    s3_access_key_id: Optional[str] = None
-    s3_secret_access_key: Optional[str] = None
-    s3_endpoint: Optional[str] = None
+    stream_kwargs: Optional[dict[str, Any]] = None
+    serialization_kwargs: Optional[dict[str, Any]] = None
+    deserialization_kwargs: Optional[dict[str, Any]] = None
+    _extra_serialization_attrs: Optional[dict[str, Any]] = field(init=False,
+                                                                 default=None)
+    model_class: Optional[type[torch.nn.Module]] = field(init=False,
+                                                         default=None)
+    hf_config: Optional[PretrainedConfig] = field(init=False, default=None)
+    dtype: Optional[Union[str, torch.dtype]] = field(init=False, default=None)
+    _is_sharded: bool = field(init=False, default=False)
+    _fields: ClassVar[tuple[str, ...]]
+    _keys: ClassVar[frozenset[str]]
     """
-    Args for the TensorizerAgent class. These are used to configure the behavior
-    of the TensorDeserializer when loading tensors from a serialized model.
-
-    Args:
+    Args for the TensorizerConfig class. These are used to configure the
+    behavior of model serialization and deserialization using Tensorizer.
+
+    Args:
         tensorizer_uri: Path to serialized model tensors. Can be a local file
             path or a S3 URI. This is a required field unless lora_dir is
             provided and the config is meant to be used for the
-            `tensorize_lora_adapter` function.
+            `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
+            `lora_dir` is passed to this object's initializer, this is a
+            required argument.
+        tensorizer_dir: Path to a directory containing serialized model
+            tensors, and all other potential model artifacts to load the
+            model, such as configs and tokenizer files. Can be passed instead
+            of `tensorizer_uri`, in which case the `model.tensors` file is
+            assumed to be in this directory.
         vllm_tensorized: If True, indicates that the serialized model is a
             vLLM model. This is used to determine the behavior of the
            TensorDeserializer when loading tensors from a serialized model.
@@ -256,34 +210,174 @@ class TensorizerConfig(MutableMapping):
             be set via the S3_SECRET_ACCESS_KEY environment variable.
         s3_endpoint: The endpoint for the S3 bucket. Can also be set via
             the S3_ENDPOINT_URL environment variable.
+        lora_dir: Path to a directory containing LoRA adapter artifacts for
+            serialization or deserialization. When serializing LoRA adapters
+            this is the only necessary parameter to pass to this object's
+            initializer.
     """
 
     def __post_init__(self):
-        self.file_obj = self.tensorizer_uri
-        self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID
-        self.s3_secret_access_key = (self.s3_secret_access_key
+        # check if the configuration is for a sharded vLLM model
+        self._is_sharded = isinstance(self.tensorizer_uri, str) \
+            and re.search(r'%0\dd', self.tensorizer_uri) is not None
+
+        if self.tensorizer_dir and self.tensorizer_uri:
+            raise ValueError(
+                "Either tensorizer_dir or tensorizer_uri must be provided, "
+                "not both.")
+        if self.tensorizer_dir and self.lora_dir:
+            raise ValueError(
+                "Only one of tensorizer_dir or lora_dir may be specified. "
+                "Use lora_dir exclusively when serializing LoRA adapters, "
+                "and tensorizer_dir or tensorizer_uri otherwise.")
+        if not self.tensorizer_uri:
+            if self.lora_dir:
+                self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
+            elif self.tensorizer_dir:
+                self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors"
+            else:
+                raise ValueError("Unable to resolve tensorizer_uri. "
" + "A valid tensorizer_uri or tensorizer_dir " + "must be provided for deserialization, and a " + "valid tensorizer_uri, tensorizer_uri, or " + "lora_dir for serialization.") + else: + self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) + + if not self.serialization_kwargs: + self.serialization_kwargs = {} + if not self.deserialization_kwargs: + self.deserialization_kwargs = {} + + def to_serializable(self) -> dict[str, Any]: + # Due to TensorizerConfig needing to be msgpack-serializable, it needs + # support for morphing back and forth between itself and its dict + # representation + + # TensorizerConfig's representation as a dictionary is meant to be + # linked to TensorizerConfig in such a way that the following is + # technically initializable: + # TensorizerConfig(**my_tensorizer_cfg.to_serializable()) + + # This means the dict must not retain non-initializable parameters + # and post-init attribute states + + # Also don't want to retain private and unset parameters, so only retain + # not None values and public attributes + + raw_tc_dict = asdict(self) + blacklisted = [] + + if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict: + blacklisted.append("tensorizer_dir") + + if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict: + blacklisted.append("tensorizer_dir") + + tc_dict = {} + for k, v in raw_tc_dict.items(): + if (k not in blacklisted and k not in tc_dict + and not k.startswith("_") and v is not None): + tc_dict[k] = v + + return tc_dict + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + return TensorizerArgs(self) # type: ignore + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if parallel_config.tensor_parallel_size > 1 \ + and not self._is_sharded: + raise ValueError( + "For a sharded model, tensorizer_uri should include a" + " string format template like '%04d' to be formatted" + " with the rank of the shard") + + def verify_with_model_config(self, model_config: "ModelConfig") -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + logger.warning( + "Loading a model using Tensorizer with quantization on vLLM" + " is unstable and may lead to errors.") + + def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): + if tensorizer_args is None: + tensorizer_args = self._construct_tensorizer_args() + + return open_stream(self.tensorizer_uri, + **tensorizer_args.stream_kwargs) + + def keys(self): + return self._keys + + def __len__(self): + return len(fields(self)) + + def __iter__(self): + return iter(self._fields) + + def __getitem__(self, item: str) -> Any: + if item not in self.keys(): + raise KeyError(item) + return getattr(self, item) + + def __setitem__(self, key: str, value: Any) -> None: + if key not in self.keys(): + # Disallow modifying invalid keys + raise KeyError(key) + setattr(self, key, value) + + def __delitem__(self, key, /): + if key not in self.keys(): + raise KeyError(key) + delattr(self, key) + + +TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig)) +TensorizerConfig._keys = frozenset(TensorizerConfig._fields) + + +@dataclass +class TensorizerArgs: + tensorizer_uri: Optional[str] = None + tensorizer_dir: Optional[str] = None + encryption_keyfile: Optional[str] = None + + def __init__(self, tensorizer_config: TensorizerConfig): + for k, v in tensorizer_config.items(): + setattr(self, k, v) + self.file_obj = tensorizer_config.tensorizer_uri + self.s3_access_key_id = 
+        self.s3_access_key_id = (tensorizer_config.s3_access_key_id
+                                 or envs.S3_ACCESS_KEY_ID)
+        self.s3_secret_access_key = (tensorizer_config.s3_secret_access_key
                                      or envs.S3_SECRET_ACCESS_KEY)
-        self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL
-        self.stream_params = {
-            "s3_access_key_id": self.s3_access_key_id,
-            "s3_secret_access_key": self.s3_secret_access_key,
-            "s3_endpoint": self.s3_endpoint,
+        self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL
+
+        self.stream_kwargs = {
+            "s3_access_key_id": tensorizer_config.s3_access_key_id,
+            "s3_secret_access_key": tensorizer_config.s3_secret_access_key,
+            "s3_endpoint": tensorizer_config.s3_endpoint,
+            **(tensorizer_config.stream_kwargs or {})
         }
 
-        self.deserializer_params = {
-            "verify_hash": self.verify_hash,
-            "encryption": self.encryption_keyfile,
-            "num_readers": self.num_readers
+        self.deserialization_kwargs = {
+            "verify_hash": tensorizer_config.verify_hash,
+            "encryption": tensorizer_config.encryption_keyfile,
+            "num_readers": tensorizer_config.num_readers,
+            **(tensorizer_config.deserialization_kwargs or {})
         }
 
         if self.encryption_keyfile:
             with open_stream(
-                    self.encryption_keyfile,
-                    **self.stream_params,
+                    tensorizer_config.encryption_keyfile,
+                    **self.stream_kwargs,
             ) as stream:
                 key = stream.read()
                 decryption_params = DecryptionParams.from_key(key)
-                self.deserializer_params['encryption'] = decryption_params
+                self.deserialization_kwargs['encryption'] = decryption_params
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
@@ -405,15 +499,22 @@ def init_tensorizer_model(tensorizer_config: TensorizerConfig,
 def deserialize_tensorizer_model(model: nn.Module,
                                  tensorizer_config: TensorizerConfig) -> None:
     tensorizer_args = tensorizer_config._construct_tensorizer_args()
+    if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
+        raise ValueError(
+            f"{tensorizer_config.tensorizer_uri} is not a valid "
+            f"tensorizer URI. Please check that the URI is correct. "
" + f"It must either point to a local existing file, or have a " + f"S3, HTTP or HTTPS scheme.") before_mem = get_mem_usage() start = time.perf_counter() - with _read_stream( + with open_stream( tensorizer_config.tensorizer_uri, - **tensorizer_args.stream_params) as stream, TensorDeserializer( + mode="rb", + **tensorizer_args.stream_kwargs) as stream, TensorDeserializer( stream, dtype=tensorizer_config.dtype, - device=f'cuda:{torch.cuda.current_device()}', - **tensorizer_args.deserializer_params) as deserializer: + device=torch.device("cuda", torch.cuda.current_device()), + **tensorizer_args.deserialization_kwargs) as deserializer: deserializer.load_into_module(model) end = time.perf_counter() @@ -442,9 +543,9 @@ def tensorizer_weights_iterator( "examples/others/tensorize_vllm_model.py example script " "for serializing vLLM models.") - deserializer_args = tensorizer_args.deserializer_params - stream_params = tensorizer_args.stream_params - stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + deserializer_args = tensorizer_args.deserialization_kwargs + stream_kwargs = tensorizer_args.stream_kwargs + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs) with TensorDeserializer(stream, **deserializer_args, device="cpu") as state: yield from state.items() @@ -465,8 +566,8 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: """ tensorizer_args = tensorizer_config._construct_tensorizer_args() deserializer = TensorDeserializer(open_stream( - tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), - **tensorizer_args.deserializer_params, + tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs), + **tensorizer_args.deserialization_kwargs, lazy_load=True) if tensorizer_config.vllm_tensorized: logger.warning( @@ -477,13 +578,41 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: return ".vllm_tensorized_marker" in deserializer +def serialize_extra_artifacts( + tensorizer_args: TensorizerArgs, + served_model_name: Union[str, list[str], None]) -> None: + if not isinstance(served_model_name, str): + raise ValueError( + f"served_model_name must be a str for serialize_extra_artifacts, " + f"not {type(served_model_name)}.") + + with tempfile.TemporaryDirectory() as tmpdir: + snapshot_download(served_model_name, + local_dir=tmpdir, + ignore_patterns=[ + "*.pt", "*.safetensors", "*.bin", "*.cache", + "*.gitattributes", "*.md" + ]) + for artifact in os.scandir(tmpdir): + if not artifact.is_file(): + continue + with open(artifact.path, "rb") as f, open_stream( + f"{tensorizer_args.tensorizer_dir}/{artifact.name}", + mode="wb+", + **tensorizer_args.stream_kwargs) as stream: + logger.info("Writing artifact %s", artifact.name) + stream.write(f.read()) + + def serialize_vllm_model( model: nn.Module, tensorizer_config: TensorizerConfig, + model_config: "ModelConfig", ) -> nn.Module: model.register_parameter( "vllm_tensorized_marker", nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + tensorizer_args = tensorizer_config._construct_tensorizer_args() encryption_params = None @@ -497,10 +626,16 @@ def serialize_vllm_model( from vllm.distributed import get_tensor_model_parallel_rank output_file = output_file % get_tensor_model_parallel_rank() - with _write_stream(output_file, **tensorizer_args.stream_params) as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) + with open_stream(output_file, mode="wb+", + **tensorizer_args.stream_kwargs) as stream: + serializer = 
+        serializer = TensorSerializer(stream,
+                                      encryption=encryption_params,
+                                      **tensorizer_config.serialization_kwargs)
         serializer.write_module(model)
         serializer.close()
+
+    serialize_extra_artifacts(tensorizer_args, model_config.served_model_name)
+
     logger.info("Successfully serialized model to %s", str(output_file))
     return model
 
@@ -522,8 +657,9 @@ def tensorize_vllm_model(engine_args: "EngineArgs",
     if generate_keyfile and (keyfile :=
                              tensorizer_config.encryption_keyfile) is not None:
         encryption_params = EncryptionParams.random()
-        with _write_stream(
+        with open_stream(
                 keyfile,
+                mode="wb+",
                 s3_access_key_id=tensorizer_config.s3_access_key_id,
                 s3_secret_access_key=tensorizer_config.s3_secret_access_key,
                 s3_endpoint=tensorizer_config.s3_endpoint,
@@ -537,13 +673,13 @@ def tensorize_vllm_model(engine_args: "EngineArgs",
         engine = LLMEngine.from_engine_args(engine_args)
         engine.model_executor.collective_rpc(
             "save_tensorized_model",
-            kwargs=dict(tensorizer_config=tensorizer_config),
+            kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
         )
     else:
         engine = V1LLMEngine.from_vllm_config(engine_config)
         engine.collective_rpc(
             "save_tensorized_model",
-            kwargs=dict(tensorizer_config=tensorizer_config),
+            kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
        )
 
 
@@ -586,14 +722,14 @@ def tensorize_lora_adapter(lora_path: str,
 
     with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json",
                      mode="wb+",
-                     **tensorizer_args.stream_params) as f:
+                     **tensorizer_args.stream_kwargs) as f:
         f.write(json.dumps(config).encode("utf-8"))
 
     lora_uri = (f"{tensorizer_config.lora_dir}"
                 f"/adapter_model.tensors")
     with open_stream(lora_uri, mode="wb+",
-                     **tensorizer_args.stream_params) as f:
+                     **tensorizer_args.stream_kwargs) as f:
         serializer = TensorSerializer(f)
         serializer.write_state_dict(tensors)
         serializer.close()
diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py
index 0b62e744e..9ecc31893 100644
--- a/vllm/model_executor/model_loader/tensorizer_loader.py
+++ b/vllm/model_executor/model_loader/tensorizer_loader.py
@@ -20,6 +20,18 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
 
 logger = init_logger(__name__)
 
+BLACKLISTED_TENSORIZER_ARGS = {
+    "device",  # vLLM decides this
+    "dtype",  # vLLM decides this
+    "mode",  # Not meant to be configurable by the user
+}
+
+
+def validate_config(config: dict):
+    for k, v in config.items():
+        if v is not None and k in BLACKLISTED_TENSORIZER_ARGS:
+            raise ValueError(f"{k} is not an allowed Tensorizer argument.")
+
 
 class TensorizerLoader(BaseModelLoader):
     """Model loader using CoreWeave's tensorizer library."""
@@ -29,6 +41,7 @@ class TensorizerLoader(BaseModelLoader):
         if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
             self.tensorizer_config = load_config.model_loader_extra_config
         else:
+            validate_config(load_config.model_loader_extra_config)
            self.tensorizer_config = TensorizerConfig(
                 **load_config.model_loader_extra_config)
 
@@ -118,10 +131,12 @@ class TensorizerLoader(BaseModelLoader):
     def save_model(
         model: torch.nn.Module,
         tensorizer_config: Union[TensorizerConfig, dict],
+        model_config: ModelConfig,
     ) -> None:
         if isinstance(tensorizer_config, dict):
             tensorizer_config = TensorizerConfig(**tensorizer_config)
         serialize_vllm_model(
             model=model,
             tensorizer_config=tensorizer_config,
+            model_config=model_config,
         )
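As the comments in `to_serializable()` note, the dict form is designed to round-trip back into a `TensorizerConfig`, which is what the `collective_rpc` calls above rely on when shipping the config across process boundaries. A minimal sketch, with an illustrative path:

```python
# Sketch: the msgpack-friendly dict round-trip behind collective_rpc.
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

config = TensorizerConfig(tensorizer_uri="/tmp/model.tensors")
payload = config.to_serializable()   # plain dict, no private/unset fields
rebuilt = TensorizerConfig(**payload)
assert rebuilt.tensorizer_uri == config.tensorizer_uri
```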
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5a26e88db..8658d7d91 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1820,6 +1820,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             TensorizerLoader.save_model(
                 self.model,
                 tensorizer_config=tensorizer_config,
+                model_config=self.model_config,
             )
 
     def _get_prompt_logprobs_dict(
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 82db6617b..9d936f3db 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1246,6 +1246,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             TensorizerLoader.save_model(
                 self.model,
                 tensorizer_config=tensorizer_config,
+                model_config=self.model_config,
             )
 
     def get_max_block_per_batch(self) -> int:
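Taken together with the `create_load_config()` change above, loading a serialized model no longer requires a JSON config string. A minimal sketch, assuming a build with this patch applied; the S3 URI is an illustrative placeholder:

```python
# Sketch: pointing the model tag at a serialized directory is now enough
# for tensorizer loads; create_load_config() infers tensorizer_dir.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="s3://my-bucket/vllm/facebook/opt-125m/suffix",
    load_format="tensorizer",
)
load_config = engine_args.create_load_config()
# When --model-loader-extra-config is omitted, model_loader_extra_config
# now defaults to {"tensorizer_dir": engine_args.model}.
```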