[Frontend] [Core] Add Tensorizer support for V1, LoRA adapter serialization and deserialization (#17926)

Signed-off-by: Sanger Steel <sangersteel@gmail.com>

Author: Sanger Steel
Date: 2025-05-22 21:44:18 -04:00
Committed by: GitHub
Parent: c91fe7b1b9
Commit: c32e249a23

16 changed files with 606 additions and 197 deletions
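At a glance, the commit adds a meta-tensor initialization mode used during deserialization, a `tensorize_lora_adapter` helper for serializing LoRA adapters, and V1-engine routing for whole-model serialization. A minimal sketch of the new LoRA workflow (the adapter ID and output location below are placeholders):

```python
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig, tensorize_lora_adapter)

# Placeholder adapter ID and output location; local paths or S3 URIs work.
config = TensorizerConfig(lora_dir="s3://my-bucket/loras/my-adapter")
tensorize_lora_adapter("my-org/my-lora-adapter", config)
# Writes adapter_config.json and adapter_model.tensors under lora_dir.
```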


@@ -1,24 +1,28 @@
 # SPDX-License-Identifier: Apache-2.0
 import argparse
 import contextlib
+import contextvars
 import dataclasses
 import io
+import json
 import os
 import re
+import threading
 import time
 from collections.abc import Generator
 from dataclasses import dataclass
 from functools import partial
-from typing import BinaryIO, Optional, Union
+from typing import Any, BinaryIO, Optional, Union
 
 import torch
 from torch import nn
+from torch.utils._python_dispatch import TorchDispatchMode
 from transformers import PretrainedConfig
 
 import vllm.envs as envs
 from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
@@ -58,9 +62,79 @@ __all__ = [
 logger = init_logger(__name__)
 
 
+class MetaTensorMode(TorchDispatchMode):
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+        if func._schema.name == "aten::empty" and "device" not in kwargs:
+            kwargs["device"] = "meta"
+        return func(*args, **kwargs)
+
+
+def meta_tensor_mode(loading_code=None, ):
+    if loading_code is None:
+        return _NoInitOrTensorImpl.context_manager()
+    elif callable(loading_code):
+        with _NoInitOrTensorImpl.context_manager():
+            return loading_code()
+    else:
+        raise TypeError(
+            "expected a callable to evaluate,"
+            " or None if being used as a context manager;"
+            f' got an object of type "{type(loading_code).__name__}" instead.')
+
+
+class _NoInitOrTensorImpl:
+    _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
+    _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)
+
+    is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active",
+                                       default=False)
+    _count_active: int = 0
+    _count_active_lock = threading.Lock()
+
+    @classmethod
+    @contextlib.contextmanager
+    def context_manager(cls):
+        if cls.is_active.get():
+            yield
+            return
+
+        with cls._count_active_lock:
+            cls._count_active += 1
+            if cls._count_active == 1:
+                for mod in cls._MODULES:
+                    mod.reset_parameters = cls._disable(mod.reset_parameters)
+        reset_token = cls.is_active.set(True)
+        try:
+            with MetaTensorMode():
+                yield
+        finally:
+            cls.is_active.reset(reset_token)
+            with cls._count_active_lock:
+                cls._count_active -= 1
+                if cls._count_active == 0:
+                    for mod, original in cls._MODULE_ORIGINALS:
+                        mod.reset_parameters = original
+
+    @staticmethod
+    def _disable(func):
+
+        def wrapper(*args, **kwargs):
+            if not _NoInitOrTensorImpl.is_active.get():
+                return func(*args, **kwargs)
+
+        return wrapper
+
+
 @dataclass
 class TensorizerConfig:
-    tensorizer_uri: str
+    tensorizer_uri: Union[str, None] = None
     vllm_tensorized: Optional[bool] = False
     verify_hash: Optional[bool] = False
     num_readers: Optional[int] = None
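For orientation, a small sketch of how the new `meta_tensor_mode` behaves, both as a context manager and as a callable wrapper (the layer sizes are arbitrary):

```python
import torch.nn as nn

# As a context manager: aten::empty is redirected to the "meta" device and
# reset_parameters becomes a no-op, so no real weight memory is allocated
# or initialized; the deserializer fills weights in later.
with meta_tensor_mode():
    layer = nn.Linear(4096, 4096)
assert layer.weight.is_meta

# As a callable wrapper, equivalent to the context-manager form:
layer = meta_tensor_mode(lambda: nn.Linear(4096, 4096))
```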
@@ -71,12 +145,29 @@ class TensorizerConfig:
     model_class: Optional[type[torch.nn.Module]] = None
     hf_config: Optional[PretrainedConfig] = None
     dtype: Optional[Union[str, torch.dtype]] = None
+    lora_dir: Optional[str] = None
     _is_sharded: bool = False
 
     def __post_init__(self):
         # check if the configuration is for a sharded vLLM model
         self._is_sharded = isinstance(self.tensorizer_uri, str) \
             and re.search(r'%0\dd', self.tensorizer_uri) is not None
+        if not self.tensorizer_uri and not self.lora_dir:
+            raise ValueError("tensorizer_uri must be provided.")
+        if not self.tensorizer_uri and self.lora_dir:
+            self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
+        assert self.tensorizer_uri is not None, ("tensorizer_uri must be "
+                                                 "provided.")
+        self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
+        self.lora_dir = self.tensorizer_dir
+
+    @classmethod
+    def as_dict(cls, *args, **kwargs) -> dict[str, Any]:
+        cfg = TensorizerConfig(*args, **kwargs)
+        return dataclasses.asdict(cfg)
+
+    def to_dict(self) -> dict[str, Any]:
+        return dataclasses.asdict(self)
 
     def _construct_tensorizer_args(self) -> "TensorizerArgs":
         tensorizer_args = {
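The new `as_dict`/`to_dict` helpers make the config's dict round trip explicit, e.g. for sending it across an RPC boundary; a sketch with a placeholder URI:

```python
cfg = TensorizerConfig(tensorizer_uri="/tmp/model.tensors")
payload = cfg.to_dict()                  # plain dict via dataclasses.asdict
restored = TensorizerConfig(**payload)   # rebuild from the dict

# Or build the dict directly, without keeping an instance around:
payload = TensorizerConfig.as_dict(tensorizer_uri="/tmp/model.tensors")
```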
@@ -140,7 +231,9 @@ class TensorizerArgs:
     Args:
         tensorizer_uri: Path to serialized model tensors. Can be a local file
-            path or a S3 URI.
+            path or a S3 URI. This is a required field unless lora_dir is
+            provided and the config is meant to be used for the
+            `tensorize_lora_adapter` function.
         vllm_tensorized: If True, indicates that the serialized model is a
             vLLM model. This is used to determine the behavior of the
             TensorDeserializer when loading tensors from a serialized model.
@@ -296,10 +389,10 @@ class TensorizerAgent:
         model_args.torch_dtype = self.tensorizer_config.dtype
         assert self.tensorizer_config.model_class is not None
         # TODO: Do we need to consider old-style model class?
-        with no_init_or_tensor(), set_current_vllm_config(self.vllm_config,
-                                                          check_compile=True):
+        with meta_tensor_mode(), set_current_vllm_config(self.vllm_config,
+                                                         check_compile=True):
             return self.tensorizer_config.model_class(
-                vllm_config=self.vllm_config, )
+                vllm_config=self.vllm_config)
 
     def _resize_lora_embeddings(self):
         """Modify LoRA embedding layers to use bigger tensors
@@ -467,8 +560,73 @@ def tensorize_vllm_model(engine_args: EngineArgs,
         ) as stream:
             stream.write(encryption_params.key)
 
-    engine = LLMEngine.from_engine_args(engine_args)
-    engine.model_executor.collective_rpc(
-        "save_tensorized_model",
-        kwargs=dict(tensorizer_config=tensorizer_config),
-    )
+    from vllm import LLMEngine
+    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
+
+    if not envs.VLLM_USE_V1:
+        engine = LLMEngine.from_engine_args(engine_args)
+        engine.model_executor.collective_rpc(
+            "save_tensorized_model",
+            kwargs=dict(tensorizer_config=tensorizer_config),
+        )
+    else:
+        engine = V1LLMEngine.from_vllm_config(engine_config)
+        engine.collective_rpc(
+            "save_tensorized_model",
+            kwargs=dict(tensorizer_config=tensorizer_config),
+        )
+
+
+def tensorize_lora_adapter(lora_path: str,
+                           tensorizer_config: TensorizerConfig):
+    """
+    Uses tensorizer to serialize a LoRA adapter. Assumes that the files
+    needed to load a LoRA adapter are a safetensors-format file called
+    adapter_model.safetensors and a JSON config file called
+    adapter_config.json.
+
+    Serializes the files in the tensorizer_config.lora_dir
+    """
+    import safetensors.torch
+
+    from vllm.lora.utils import get_adapter_absolute_path
+
+    lora_dir = get_adapter_absolute_path(lora_path)
+    tensor_path = config_path = ""
+
+    for file in os.listdir(lora_dir):
+        if file.startswith("adapter_model"):
+            tensor_path = lora_dir + "/" + file
+        if file.startswith("adapter_config"):
+            config_path = lora_dir + "/" + file
+        if tensor_path and config_path:
+            break
+
+    if tensor_path.endswith(".safetensors"):
+        tensors = safetensors.torch.load_file(tensor_path)
+    elif tensor_path.endswith(".bin"):
+        tensors = torch.load(tensor_path)
+    else:
+        raise ValueError(f"Unsupported file: {tensor_path}")
+
+    with open(config_path) as f:
+        config = json.load(f)
+
+    tensorizer_args = tensorizer_config._construct_tensorizer_args()
+
+    with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json",
+                     mode="wb+",
+                     **tensorizer_args.stream_params) as f:
+        f.write(json.dumps(config).encode("utf-8"))
+
+    lora_uri = (f"{tensorizer_config.lora_dir}"
+                f"/adapter_model.tensors")
+    with open_stream(lora_uri, mode="wb+",
+                     **tensorizer_args.stream_params) as f:
+        serializer = TensorSerializer(f)
+        serializer.write_state_dict(tensors)
+        serializer.close()
+
+    logger.info("Successfully serialized LoRA files to %s",
+                str(tensorizer_config.lora_dir))
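Whole-model serialization keeps its original call shape; under V1 it now goes through `collective_rpc` on the V1 engine rather than the model executor. A sketch with a placeholder model and output path:

```python
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig, tensorize_vllm_model)

engine_args = EngineArgs(model="facebook/opt-125m")  # placeholder model
config = TensorizerConfig(tensorizer_uri="/tmp/opt-125m/model.tensors")
tensorize_vllm_model(engine_args, config)  # V0 vs. V1 chosen by VLLM_USE_V1
```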


@@ -2,6 +2,7 @@
 # ruff: noqa: SIM117
 import copy
 from collections.abc import Generator
+from typing import Union
 
 import torch
 from torch import nn
@@ -111,8 +112,10 @@ class TensorizerLoader(BaseModelLoader):
     @staticmethod
     def save_model(
         model: torch.nn.Module,
-        tensorizer_config: TensorizerConfig,
+        tensorizer_config: Union[TensorizerConfig, dict],
     ) -> None:
+        if isinstance(tensorizer_config, dict):
+            tensorizer_config = TensorizerConfig(**tensorizer_config)
         serialize_vllm_model(
             model=model,
             tensorizer_config=tensorizer_config,
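With this change, callers can hand `save_model` either the dataclass or its dict form, which matters when the config has been flattened for transport through `collective_rpc`. An illustrative sketch, assuming `model` is an in-scope `torch.nn.Module`:

```python
cfg = TensorizerConfig(tensorizer_uri="/tmp/model.tensors")  # placeholder path

TensorizerLoader.save_model(model, tensorizer_config=cfg)            # dataclass
TensorizerLoader.save_model(model, tensorizer_config=cfg.to_dict())  # plain dict
```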