[Transform] [Quantization] Add transforms to compressed tensors (#22486)
@@ -1,13 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from collections.abc import Hashable
 from fractions import Fraction
 from typing import Callable, Optional, Union
+from weakref import WeakValueDictionary
 
 import torch
 from torch.nn import Parameter
 
-from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
 from vllm.model_executor.utils import _make_synced_weight_loader
 
@@ -27,7 +30,7 @@ class BasevLLMParameter(Parameter):
     into the parameter when the provided weight loader is called.
     """
 
-    def __new__(cls, data: torch.Tensor, **kwargs):
+    def __new__(cls, data: Optional[torch.Tensor], **kwargs):
 
         return super().__new__(cls, data=data, requires_grad=False)
 
@@ -81,6 +84,17 @@ class BasevLLMParameter(Parameter):
     def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
         self._assert_and_load(loaded_weight)
 
+    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        # if not int, assume shard_id for qkv
+        # map to int and return
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        assert isinstance(shard_id, str)
+        assert shard_id in qkv_idxs
+        return qkv_idxs[shard_id]
+
 
 class _ColumnvLLMParameter(BasevLLMParameter):
     """
@@ -113,6 +127,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
+        # TODO: move these to PackedColumnParameter and PackedvLLMParameter
         if isinstance(
                 self,
                 (PackedColumnParameter,
@@ -137,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         shard_id = kwargs.get("shard_id")
         num_heads = kwargs.get("num_heads")
 
+        # TODO: move these to PackedColumnParameter and PackedvLLMParameter
         if isinstance(
                 self,
                 (PackedColumnParameter,
@@ -224,19 +240,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
     """
 
-    def __init__(self, **kwargs):
-        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        super().__init__(**kwargs)
 
-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
 
-        # if not int, assume shard_id for qkv
-        # map to int and return
-        assert isinstance(shard_id, str)
-        assert shard_id in self.qkv_idxs
-        return self.qkv_idxs[shard_id]
 
     # For row parallel layers, no sharding needed
     # load weight into parameter as is
     def load_row_parallel_weight(self, *args, **kwargs):
@@ -373,6 +378,141 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
     pass
 
 
+class SharedWeightParameter(BasevLLMParameter):
+    """
+    Parameter for weights with many shared tensors across a model
+
+    For example, when applying transforms to the "gate" and "up" partitions of
+    `MergedColumnParallelLinear`, the transform weights must stay separate
+    tensors in order to allow for tensor memory sharing between layers.
+    """
+    # global registry for sharing tensors based on passed `data_key`
+    # this dict holds weaksrefs to avoid memory leak after model cleanup
+    tensors_registry: WeakValueDictionary = WeakValueDictionary()
+
+    # local container for strong references to shared tensors
+    # this set compensates for the fact that torch.nn.Parameter
+    # and Parameter subclasses do not hold reliable references to tensors
+    local_tensors: set[torch.Tensor]
+
+    # dictionary mapping partition indices to associated parameters
+    partitions: dict[int, Union[ModelWeightParameter, Parameter]]
+
+    def __new__(cls, **kwargs):
+        return super().__new__(cls, data=None, **kwargs)
+
+    def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs):
+        weight_loader: Callable = kwargs.get(
+            "weight_loader")  # type: ignore[assignment]
+        super().__init__(data=None, weight_loader=weight_loader)
+
+        self.local_tensors = set()
+        self.partitions = {}
+        self.kwargs = {
+            "input_dim": input_dim,
+            "output_dim": output_dim,
+            "weight_loader": self._fake_weight_loader
+        }
+
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        if self.tp_size > 1:
+            raise NotImplementedError(f"{self.__class__.__name__} does not "
+                                      "currently support tensor parallelism")
+
+    def add_partition(self, index: int, data_key: Hashable, *args, **kwargs):
+        """
+        Add a partition to the weight parameter. Partitions whose `data_key`
+        is the same will share tensor data
+
+        :param index: index of partition to add
+        :param data_key: hashable key used to key shared tensors
+        :param *args: arguments for `torch.empty`
+        :param **kwargs: keyword arguments for `torch.empty`
+        """
+        # load (shared) tensor using `data_key`
+        if data_key not in self.tensors_registry:
+            data = torch.empty(*args, **kwargs)
+            self.tensors_registry[data_key] = data
+        else:
+            data = self.tensors_registry[data_key]
+
+        # create associated model parameter
+        self.partitions[index] = ModelWeightParameter(
+            data=data, **self.kwargs)  # type: ignore[arg-type]
+
+        # hold local reference, since ModelWeightParameter does not
+        # see https://github.com/pytorch/pytorch/issues/75932
+        self.local_tensors.add(data)
+
+    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
+        assert len(self.partitions) == 1 and 0 in self.partitions
+        partition = self.partitions[0]
+
+        ModelWeightParameter.load_column_parallel_weight(
+            partition, loaded_weight)
+
+    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
+        assert len(self.partitions) == 1 and 0 in self.partitions
+        partition = self.partitions[0]
+
+        ModelWeightParameter.load_row_parallel_weight(partition, loaded_weight)
+
+    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        partition_id = kwargs.pop("shard_id")
+        partition_id = self._shard_id_as_int(partition_id)
+        partition = self.partitions[partition_id]
+
+        input_dim = self.kwargs.get("input_dim")
+        shard_size = partition.data.size(input_dim) // self.tp_size
+        shard_offset = self.tp_rank * shard_size
+
+        ModelWeightParameter.load_merged_column_weight(
+            partition,
+            loaded_weight,
+            shard_offset=shard_offset,
+            shard_size=shard_size)
+
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        partition_id = self._shard_id_as_int(kwargs.pop("shard_id"))
+        partition = self.partitions[partition_id]
+
+        input_dim = self.kwargs.get("input_dim")
+        shard_size = partition.data.size(input_dim) // self.tp_size
+        shard_offset = self.tp_rank * shard_size
+        shard_id = "q"  # fake first partition
+        num_heads = kwargs.get("num_heads")
+
+        ModelWeightParameter.load_qkv_weight(
+            partition,
+            loaded_weight,
+            shard_offset=shard_offset,
+            shard_size=shard_size,
+            shard_id=shard_id,
+            num_heads=num_heads,
+        )
+
+    def process_weights_after_loading(self):
+        for key in self.partitions:
+            self.partitions[key] = torch.nn.Parameter(
+                data=self.partitions[key].data, requires_grad=False)
+
+    @property
+    def data(self):
+        raise ValueError("Accessing `data` of a "
+                         "`PartitionedModelWeightParameter` is not allowed. "
+                         "Instead, use `get_partition` to get the weight of "
+                         "the particular partition you want to access")
+
+    def _fake_weight_loader(self, param: BasevLLMParameter,
+                            loaded_weight: torch.Tensor,
+                            loaded_weight_shard_id: Optional[Union[str, int]]):
+        raise ValueError("When loading partition weights of "
+                         f"{self.__class__.__name__}, use methods provided by "
+                         f"{self.__class__.__name__}, not partition loader")
+
+
 def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
                           output_dim: int, **kwargs) -> BasevLLMParameter:
     """
@@ -456,4 +596,4 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
             shard_offset=shard_offset,
             bitblas_tile_size=bitblas_tile_size)
 
-    return shard_size, shard_offset
+    return shard_size, shard_offset
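
The core mechanism behind the new SharedWeightParameter is its keyed tensor registry: partitions created with the same `data_key` reuse one underlying tensor, the registry itself holds only weak references, and each parameter keeps a set of strong references so the shared tensors stay alive while the model does. The standalone sketch below illustrates that pattern outside of vLLM; `SharedTensorPool` and `get_or_create` are illustrative names for this example only and are not part of this commit.

    # Illustrative sketch (not vLLM code) of the sharing pattern used by
    # SharedWeightParameter: a weak, globally keyed registry plus local
    # strong references.
    from weakref import WeakValueDictionary

    import torch


    class SharedTensorPool:
        # weak registry: an entry disappears once no strong reference remains
        registry: WeakValueDictionary = WeakValueDictionary()

        def __init__(self):
            # strong references keep shared tensors alive for this pool's lifetime
            self.local_tensors: set[torch.Tensor] = set()

        def get_or_create(self, data_key, *shape,
                          dtype: torch.dtype = torch.float32) -> torch.Tensor:
            # reuse the tensor registered under this key, or create a new one
            data = self.registry.get(data_key)
            if data is None:
                data = torch.empty(*shape, dtype=dtype)
                self.registry[data_key] = data
            self.local_tensors.add(data)
            return data


    pool = SharedTensorPool()
    a = pool.get_or_create(("transform", "up_gate"), 64, 64)
    b = pool.get_or_create(("transform", "up_gate"), 64, 64)  # same key -> same storage
    c = pool.get_or_create(("transform", "down"), 64, 64)     # new key -> new tensor
    assert a is b and a is not c

This mirrors how `add_partition` passes its extra arguments to `torch.empty` only when the `data_key` has not been seen before, so repeated keys across layers map to a single allocation.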
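A second standalone sketch (again illustrative only, not vLLM code) shows the cleanup property that the diff's `tensors_registry` comment describes: because the registry is a `WeakValueDictionary`, its entries are released once every strong reference is dropped, as happens when a model and its parameters are garbage collected. Exact collection timing depends on the Python garbage collector.

    # Illustrative sketch: weak registry entries vanish after model teardown.
    import gc
    from weakref import WeakValueDictionary

    import torch

    registry: WeakValueDictionary = WeakValueDictionary()
    strong_refs: set[torch.Tensor] = set()

    t = torch.empty(8, 8)
    registry["shared_transform"] = t
    strong_refs.add(t)
    assert "shared_transform" in registry

    # drop every strong reference, as model cleanup would
    strong_refs.clear()
    del t
    gc.collect()
    assert "shared_transform" not in registry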