[Frontend] [Core] Add Tensorizer support for V1, LoRA adapter serialization and deserialization (#17926)
Signed-off-by: Sanger Steel <sangersteel@gmail.com>
This commit is contained in:
@@ -1,24 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import contextvars
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import BinaryIO, Optional, Union
|
||||
from typing import Any, BinaryIO, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
@@ -58,9 +62,79 @@ __all__ = [
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MetaTensorMode(TorchDispatchMode):
|
||||
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
if func._schema.name == "aten::empty" and "device" not in kwargs:
|
||||
kwargs["device"] = "meta"
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def meta_tensor_mode(loading_code=None, ):
|
||||
|
||||
if loading_code is None:
|
||||
return _NoInitOrTensorImpl.context_manager()
|
||||
elif callable(loading_code):
|
||||
with _NoInitOrTensorImpl.context_manager():
|
||||
return loading_code()
|
||||
else:
|
||||
raise TypeError(
|
||||
"expected a callable to evaluate,"
|
||||
" or None if being used as a context manager;"
|
||||
f' got an object of type "{type(loading_code).__name__}" instead.')
|
||||
|
||||
|
||||
class _NoInitOrTensorImpl:
|
||||
_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
|
||||
_MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)
|
||||
|
||||
is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active",
|
||||
default=False)
|
||||
_count_active: int = 0
|
||||
_count_active_lock = threading.Lock()
|
||||
|
||||
@classmethod
|
||||
@contextlib.contextmanager
|
||||
def context_manager(cls):
|
||||
if cls.is_active.get():
|
||||
yield
|
||||
return
|
||||
|
||||
with cls._count_active_lock:
|
||||
cls._count_active += 1
|
||||
if cls._count_active == 1:
|
||||
for mod in cls._MODULES:
|
||||
mod.reset_parameters = cls._disable(mod.reset_parameters)
|
||||
|
||||
reset_token = cls.is_active.set(True)
|
||||
|
||||
try:
|
||||
with MetaTensorMode():
|
||||
yield
|
||||
finally:
|
||||
cls.is_active.reset(reset_token)
|
||||
with cls._count_active_lock:
|
||||
cls._count_active -= 1
|
||||
if cls._count_active == 0:
|
||||
for mod, original in cls._MODULE_ORIGINALS:
|
||||
mod.reset_parameters = original
|
||||
|
||||
@staticmethod
|
||||
def _disable(func):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if not _NoInitOrTensorImpl.is_active.get():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerConfig:
|
||||
tensorizer_uri: str
|
||||
tensorizer_uri: Union[str, None] = None
|
||||
vllm_tensorized: Optional[bool] = False
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = None
|
||||
@@ -71,12 +145,29 @@ class TensorizerConfig:
|
||||
model_class: Optional[type[torch.nn.Module]] = None
|
||||
hf_config: Optional[PretrainedConfig] = None
|
||||
dtype: Optional[Union[str, torch.dtype]] = None
|
||||
lora_dir: Optional[str] = None
|
||||
_is_sharded: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# check if the configuration is for a sharded vLLM model
|
||||
self._is_sharded = isinstance(self.tensorizer_uri, str) \
|
||||
and re.search(r'%0\dd', self.tensorizer_uri) is not None
|
||||
if not self.tensorizer_uri and not self.lora_dir:
|
||||
raise ValueError("tensorizer_uri must be provided.")
|
||||
if not self.tensorizer_uri and self.lora_dir:
|
||||
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
|
||||
assert self.tensorizer_uri is not None, ("tensorizer_uri must be "
|
||||
"provided.")
|
||||
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
|
||||
self.lora_dir = self.tensorizer_dir
|
||||
|
||||
@classmethod
|
||||
def as_dict(cls, *args, **kwargs) -> dict[str, Any]:
|
||||
cfg = TensorizerConfig(*args, **kwargs)
|
||||
return dataclasses.asdict(cfg)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
|
||||
tensorizer_args = {
|
||||
@@ -140,7 +231,9 @@ class TensorizerArgs:
|
||||
|
||||
Args:
|
||||
tensorizer_uri: Path to serialized model tensors. Can be a local file
|
||||
path or a S3 URI.
|
||||
path or a S3 URI. This is a required field unless lora_dir is
|
||||
provided and the config is meant to be used for the
|
||||
`tensorize_lora_adapter` function.
|
||||
vllm_tensorized: If True, indicates that the serialized model is a
|
||||
vLLM model. This is used to determine the behavior of the
|
||||
TensorDeserializer when loading tensors from a serialized model.
|
||||
@@ -296,10 +389,10 @@ class TensorizerAgent:
|
||||
model_args.torch_dtype = self.tensorizer_config.dtype
|
||||
assert self.tensorizer_config.model_class is not None
|
||||
# TODO: Do we need to consider old-style model class?
|
||||
with no_init_or_tensor(), set_current_vllm_config(self.vllm_config,
|
||||
check_compile=True):
|
||||
with meta_tensor_mode(), set_current_vllm_config(self.vllm_config,
|
||||
check_compile=True):
|
||||
return self.tensorizer_config.model_class(
|
||||
vllm_config=self.vllm_config, )
|
||||
vllm_config=self.vllm_config)
|
||||
|
||||
def _resize_lora_embeddings(self):
|
||||
"""Modify LoRA embedding layers to use bigger tensors
|
||||
@@ -467,8 +560,73 @@ def tensorize_vllm_model(engine_args: EngineArgs,
|
||||
) as stream:
|
||||
stream.write(encryption_params.key)
|
||||
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
engine.model_executor.collective_rpc(
|
||||
"save_tensorized_model",
|
||||
kwargs=dict(tensorizer_config=tensorizer_config),
|
||||
)
|
||||
from vllm import LLMEngine
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
engine.model_executor.collective_rpc(
|
||||
"save_tensorized_model",
|
||||
kwargs=dict(tensorizer_config=tensorizer_config),
|
||||
)
|
||||
else:
|
||||
engine = V1LLMEngine.from_vllm_config(engine_config)
|
||||
engine.collective_rpc(
|
||||
"save_tensorized_model",
|
||||
kwargs=dict(tensorizer_config=tensorizer_config),
|
||||
)
|
||||
|
||||
|
||||
def tensorize_lora_adapter(lora_path: str,
|
||||
tensorizer_config: TensorizerConfig):
|
||||
"""
|
||||
Uses tensorizer to serialize a LoRA adapter. Assumes that the files
|
||||
needed to load a LoRA adapter are a safetensors-format file called
|
||||
adapter_model.safetensors and a json config file called adapter_config.json.
|
||||
|
||||
Serializes the files in the tensorizer_config.lora_dir
|
||||
"""
|
||||
import safetensors
|
||||
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
|
||||
lora_dir = get_adapter_absolute_path(lora_path)
|
||||
|
||||
tensor_path = config_path = ""
|
||||
|
||||
for file in os.listdir(lora_dir):
|
||||
if file.startswith("adapter_model"):
|
||||
tensor_path = lora_dir + "/" + file
|
||||
if file.startswith("adapter_config"):
|
||||
config_path = lora_dir + "/" + file
|
||||
if tensor_path and config_path:
|
||||
break
|
||||
|
||||
if tensor_path.endswith(".safetensors"):
|
||||
tensors = safetensors.torch.load_file(tensor_path)
|
||||
elif tensor_path.endswith(".bin"):
|
||||
tensors = torch.load(tensor_path)
|
||||
else:
|
||||
raise ValueError("Unsupported file: %s", tensor_path)
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
|
||||
with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json",
|
||||
mode="wb+",
|
||||
**tensorizer_args.stream_params) as f:
|
||||
|
||||
f.write(json.dumps(config).encode("utf-8"))
|
||||
|
||||
lora_uri = (f"{tensorizer_config.lora_dir}"
|
||||
f"/adapter_model.tensors")
|
||||
with open_stream(lora_uri, mode="wb+",
|
||||
**tensorizer_args.stream_params) as f:
|
||||
serializer = TensorSerializer(f)
|
||||
serializer.write_state_dict(tensors)
|
||||
serializer.close()
|
||||
|
||||
logger.info("Successfully serialized LoRA files to %s",
|
||||
str(tensorizer_config.lora_dir))
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# ruff: noqa: SIM117
|
||||
import copy
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -111,8 +112,10 @@ class TensorizerLoader(BaseModelLoader):
|
||||
@staticmethod
|
||||
def save_model(
|
||||
model: torch.nn.Module,
|
||||
tensorizer_config: TensorizerConfig,
|
||||
tensorizer_config: Union[TensorizerConfig, dict],
|
||||
) -> None:
|
||||
if isinstance(tensorizer_config, dict):
|
||||
tensorizer_config = TensorizerConfig(**tensorizer_config)
|
||||
serialize_vllm_model(
|
||||
model=model,
|
||||
tensorizer_config=tensorizer_config,
|
||||
|
||||
Reference in New Issue
Block a user