[Frontend] [Core] Add Tensorizer support for V1, LoRA adapter serialization and deserialization (#17926)
Signed-off-by: Sanger Steel <sangersteel@gmail.com>
@@ -6,11 +6,12 @@ import json
 import os
 import uuid
 
-from vllm import LLM
+from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
-from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
-                                                         TensorizerConfig,
-                                                         tensorize_vllm_model)
+from vllm.lora.request import LoRARequest
+from vllm.model_executor.model_loader.tensorizer import (
+    TensorizerArgs, TensorizerConfig, tensorize_lora_adapter,
+    tensorize_vllm_model)
 from vllm.utils import FlexibleArgumentParser
 
 # yapf conflicts with isort for this docstring
@@ -27,7 +28,7 @@ https://github.com/coreweave/tensorizer
 To serialize a model, install vLLM from source, then run something
 like this from the root level of this repository:
 
-python -m examples.other.tensorize_vllm_model \
+python examples/other/tensorize_vllm_model.py \
     --model facebook/opt-125m \
     serialize \
     --serialized-directory s3://my-bucket \
@@ -47,7 +48,7 @@ providing a `--keyfile` argument.
 To deserialize a model, you can run something like this from the root
 level of this repository:
 
-python -m examples.other.tensorize_vllm_model \
+python examples/other/tensorize_vllm_model.py \
     --model EleutherAI/gpt-j-6B \
     --dtype float16 \
     deserialize \
@@ -69,7 +70,7 @@ For more information on the available arguments for serializing, run
 
 Or for deserializing:
 
-`python -m examples.other.tensorize_vllm_model deserialize --help`.
+`python examples/other/tensorize_vllm_model.py deserialize --help`.
 
 Once a model is serialized, tensorizer can be invoked with the `LLM` class
 directly to load models:
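Loading a serialized model back through the `LLM` class, as the docstring context above describes, looks roughly like the sketch below; the model name and `tensorizer_uri` are illustrative placeholders rather than values taken from this commit:

```python
from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

# Sketch: load a model previously written out by the serialize subcommand.
# Point tensorizer_uri at wherever the .tensors file was actually saved.
llm = LLM(
    model="facebook/opt-125m",
    load_format="tensorizer",
    model_loader_extra_config=TensorizerConfig(
        tensorizer_uri="s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors"),
)
```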
@@ -90,11 +91,27 @@ TensorizerConfig arguments desired.
 In order to see all of the available arguments usable to configure
 loading with tensorizer that are given to `TensorizerConfig`, run:
 
-`python -m examples.other.tensorize_vllm_model deserialize --help`
+`python examples/other/tensorize_vllm_model.py deserialize --help`
 
 under the `tensorizer options` section. These can also be used for
 deserialization in this example script, although `--tensorizer-uri` and
 `--path-to-tensors` are functionally the same in this case.
+
+Tensorizer can also be used to save and load LoRA adapters. A LoRA adapter
+can be serialized directly with the path to the LoRA adapter on HF Hub and
+a TensorizerConfig object. In this script, passing a HF id to a LoRA adapter
+will serialize the LoRA adapter artifacts to `--serialized-directory`.
+
+You can then use the LoRA adapter with `vllm serve`, for instance, by ensuring
+the LoRA artifacts are in your model artifacts directory and specifying
+`--enable-lora`. For instance:
+
+```
+vllm serve <model_path> \
+    --load-format tensorizer \
+    --model-loader-extra-config '{"tensorizer_uri": "<model_path>.tensors"}' \
+    --enable-lora
+```
 """
 
 
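The LoRA workflow described in the new docstring text is exercised by `tensorize_lora_adapter` and the `tensorizer_config` kwarg on `LoRARequest`, both introduced later in this diff. A condensed sketch of the round trip, with the base model, adapter id, and URI as assumed placeholders:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig, tensorize_lora_adapter)

# Placeholder model, adapter, and URI; substitute your own.
tensorizer_config = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/vllm/my-base-model/v1/model.tensors")

# Serialize the adapter artifacts next to the model tensors, mirroring
# what the example script does when --lora-path is given.
tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
tensorize_lora_adapter("my-org/my-lora-adapter", tensorizer_config)

# Deserialize both: load the model with tensorizer and attach the adapter
# per request by handing the same TensorizerConfig to LoRARequest.
llm = LLM(model="my-org/my-base-model",
          load_format="tensorizer",
          model_loader_extra_config=tensorizer_config,
          enable_lora=True)
outputs = llm.generate(
    ["Write a SQL query to answer the question based on ..."],
    SamplingParams(temperature=0, max_tokens=256),
    lora_request=LoRARequest("my-lora", 1, "my-org/my-lora-adapter",
                             tensorizer_config=tensorizer_config))
```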
@@ -107,6 +124,19 @@ def parse_args():
         "also supported, although libsodium must be installed to "
         "use it.")
     parser = EngineArgs.add_cli_args(parser)
+
+    parser.add_argument(
+        "--lora-path",
+        type=str,
+        required=False,
+        help="Path to a LoRA adapter to "
+        "serialize along with model tensors. This can then be deserialized "
+        "along with the model by passing a tensorizer_config kwarg to "
+        "LoRARequest with type TensorizerConfig. See this script's docstring "
+        "for a usage example."
+
+    )
+
     subparsers = parser.add_subparsers(dest='command')
 
     serialize_parser = subparsers.add_parser(
@@ -169,11 +199,42 @@ def parse_args():
 
 
 def deserialize():
-    llm = LLM(model=args.model,
-              load_format="tensorizer",
-              tensor_parallel_size=args.tensor_parallel_size,
-              model_loader_extra_config=tensorizer_config
-    )
+    if args.lora_path:
+        tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
+        llm = LLM(model=args.model,
+                  load_format="tensorizer",
+                  tensor_parallel_size=args.tensor_parallel_size,
+                  model_loader_extra_config=tensorizer_config,
+                  enable_lora=True,
+        )
+        sampling_params = SamplingParams(
+            temperature=0,
+            max_tokens=256,
+            stop=["[/assistant]"]
+        )
+
+        # Truncating this as the extra text isn't necessary
+        prompts = [
+            "[user] Write a SQL query to answer the question based on ..."
+        ]
+
+        # Test LoRA load
+        print(
+            llm.generate(
+                prompts,
+                sampling_params,
+                lora_request=LoRARequest("sql-lora",
+                                         1,
+                                         args.lora_path,
+                                         tensorizer_config=tensorizer_config)
+            )
+        )
+    else:
+        llm = LLM(model=args.model,
+                  load_format="tensorizer",
+                  tensor_parallel_size=args.tensor_parallel_size,
+                  model_loader_extra_config=tensorizer_config
+        )
     return llm
 
 
@@ -197,7 +258,10 @@ if __name__ == '__main__':
 
     model_name = model_ref.split("/")[1]
 
-    keyfile = args.keyfile if args.keyfile else None
+    if args.command == "serialize" or args.command == "deserialize":
+        keyfile = args.keyfile
+    else:
+        keyfile = None
 
     if args.model_loader_extra_config:
         config = json.loads(args.model_loader_extra_config)
@@ -228,6 +292,10 @@ if __name__ == '__main__':
                                              encryption_keyfile=keyfile,
                                              **credentials)
 
+        if args.lora_path:
+            tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
+            tensorize_lora_adapter(args.lora_path, tensorizer_config)
+
         tensorize_vllm_model(engine_args, tensorizer_config)
 
     elif args.command == "deserialize":
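Taken together with the `--lora-path` flag added above, the serialize-side flow this hunk completes is roughly the following sketch; the engine arguments and URI are placeholders, not values from the script:

```python
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig, tensorize_lora_adapter, tensorize_vllm_model)

# Placeholders for illustration only.
engine_args = EngineArgs(model="facebook/opt-125m")
tensorizer_config = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/vllm/facebook/opt-125m/v1/model.tensors")

lora_path = "my-org/my-lora-adapter"  # would come from --lora-path
if lora_path:
    # Keep the adapter artifacts next to the model tensors so the
    # deserialize path can find them in the same directory.
    tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
    tensorize_lora_adapter(lora_path, tensorizer_config)

tensorize_vllm_model(engine_args, tensorizer_config)
```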