[Frontend] [Core] Add Tensorizer support for V1, LoRA adapter serialization and deserialization (#17926)

Signed-off-by: Sanger Steel <sangersteel@gmail.com>
Author: Sanger Steel
Date: 2025-05-22 21:44:18 -04:00
Committed by: GitHub
Parent: c91fe7b1b9
Commit: c32e249a23
16 changed files with 606 additions and 197 deletions

examples/other/tensorize_vllm_model.py

@@ -6,11 +6,12 @@ import json
 import os
 import uuid
-from vllm import LLM
+from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
-from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
-                                                         TensorizerConfig,
-                                                         tensorize_vllm_model)
+from vllm.lora.request import LoRARequest
+from vllm.model_executor.model_loader.tensorizer import (
+    TensorizerArgs, TensorizerConfig, tensorize_lora_adapter,
+    tensorize_vllm_model)
 from vllm.utils import FlexibleArgumentParser
 
 # yapf conflicts with isort for this docstring
@@ -27,7 +28,7 @@ https://github.com/coreweave/tensorizer
 To serialize a model, install vLLM from source, then run something
 like this from the root level of this repository:
 
-python -m examples.other.tensorize_vllm_model \
+python examples/other/tensorize_vllm_model.py \
    --model facebook/opt-125m \
    serialize \
    --serialized-directory s3://my-bucket \
@@ -47,7 +48,7 @@ providing a `--keyfile` argument.
 To deserialize a model, you can run something like this from the root
 level of this repository:
 
-python -m examples.other.tensorize_vllm_model \
+python examples/other/tensorize_vllm_model.py \
    --model EleutherAI/gpt-j-6B \
    --dtype float16 \
    deserialize \
@@ -69,7 +70,7 @@ For more information on the available arguments for serializing, run
 Or for deserializing:
 
-`python -m examples.other.tensorize_vllm_model deserialize --help`.
+`python examples/other/tensorize_vllm_model.py deserialize --help`.
 
 Once a model is serialized, tensorizer can be invoked with the `LLM` class
 directly to load models:
@@ -90,11 +91,27 @@ TensorizerConfig arguments desired.
 In order to see all of the available arguments usable to configure
 loading with tensorizer that are given to `TensorizerConfig`, run:
 
-`python -m examples.other.tensorize_vllm_model deserialize --help`
+`python examples/other/tensorize_vllm_model.py deserialize --help`
 
 under the `tensorizer options` section. These can also be used for
 deserialization in this example script, although `--tensorizer-uri` and
 `--path-to-tensors` are functionally the same in this case.
+
+Tensorizer can also be used to save and load LoRA adapters. A LoRA adapter
+can be serialized directly with the path to the LoRA adapter on HF Hub and
+a TensorizerConfig object. In this script, passing an HF id for a LoRA adapter
+will serialize the LoRA adapter artifacts to `--serialized-directory`.
+
+You can then use the LoRA adapter with `vllm serve` by ensuring the LoRA
+artifacts are in your model artifacts directory and specifying
+`--enable-lora`. For instance:
+
+```
+vllm serve <model_path> \
+    --load-format tensorizer \
+    --model-loader-extra-config '{"tensorizer_uri": "<model_path>.tensors"}' \
+    --enable-lora
+```
 """
@@ -107,6 +124,19 @@ def parse_args():
"also supported, although libsodium must be installed to "
"use it.")
parser = EngineArgs.add_cli_args(parser)
parser.add_argument(
"--lora-path",
type=str,
required=False,
help="Path to a LoRA adapter to "
"serialize along with model tensors. This can then be deserialized "
"along with the model by passing a tensorizer_config kwarg to "
"LoRARequest with type TensorizerConfig. See the docstring for this "
"for a usage example."
)
subparsers = parser.add_subparsers(dest='command')
serialize_parser = subparsers.add_parser(
@@ -169,11 +199,42 @@ def parse_args():
 def deserialize():
-    llm = LLM(model=args.model,
-              load_format="tensorizer",
-              tensor_parallel_size=args.tensor_parallel_size,
-              model_loader_extra_config=tensorizer_config)
+    if args.lora_path:
+        tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
+        llm = LLM(model=args.model,
+                  load_format="tensorizer",
+                  tensor_parallel_size=args.tensor_parallel_size,
+                  model_loader_extra_config=tensorizer_config,
+                  enable_lora=True)
+        sampling_params = SamplingParams(temperature=0,
+                                         max_tokens=256,
+                                         stop=["[/assistant]"])
+
+        # The prompt is truncated here; the extra text isn't necessary
+        prompts = [
+            "[user] Write a SQL query to answer the question based on ..."
+        ]
+
+        # Test that the tensorized LoRA adapter loads
+        print(
+            llm.generate(prompts,
+                         sampling_params,
+                         lora_request=LoRARequest(
+                             "sql-lora",
+                             1,
+                             args.lora_path,
+                             tensorizer_config=tensorizer_config)))
+    else:
+        llm = LLM(model=args.model,
+                  load_format="tensorizer",
+                  tensor_parallel_size=args.tensor_parallel_size,
+                  model_loader_extra_config=tensorizer_config)
+
     return llm
@@ -197,7 +258,10 @@ if __name__ == '__main__':
     model_name = model_ref.split("/")[1]
 
-    keyfile = args.keyfile if args.keyfile else None
+    if args.command == "serialize" or args.command == "deserialize":
+        keyfile = args.keyfile
+    else:
+        keyfile = None
 
     if args.model_loader_extra_config:
         config = json.loads(args.model_loader_extra_config)
@@ -228,6 +292,10 @@ if __name__ == '__main__':
             encryption_keyfile=keyfile,
             **credentials)
 
+        if args.lora_path:
+            tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
+            tensorize_lora_adapter(args.lora_path, tensorizer_config)
+
         tensorize_vllm_model(engine_args, tensorizer_config)
+
     elif args.command == "deserialize":
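As a usage sketch for the deserialization side, the new `tensorizer_config` kwarg on `LoRARequest` can also be exercised outside this script, mirroring the `deserialize()` branch above. The URI, model, and adapter names below are illustrative:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

# Point tensorizer at the serialized model tensors. The adapter artifacts
# were serialized into the same directory, so lora_dir is set to it, just
# as deserialize() does above.
tensorizer_config = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/vllm/meta-llama/Llama-2-7b-hf/v1/model.tensors")
tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir

llm = LLM(model="meta-llama/Llama-2-7b-hf",
          load_format="tensorizer",
          model_loader_extra_config=tensorizer_config,
          enable_lora=True)

# Passing tensorizer_config tells LoRARequest to deserialize the adapter
# with tensorizer rather than reading its artifacts from local disk.
print(
    llm.generate(
        "[user] Write a SQL query to answer the question based on ...",
        SamplingParams(temperature=0, max_tokens=256),
        lora_request=LoRARequest("sql-lora", 1,
                                 "yard1/llama-2-7b-sql-lora-test",
                                 tensorizer_config=tensorizer_config)))
```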