[Transform] [Quantization] Add transforms to compressed tensors (#22486)

Kyle Sayers
2025-08-28 02:43:48 -04:00
committed by GitHub
parent c8851a4723
commit 22feac8e95
9 changed files with 661 additions and 36 deletions


@@ -1,13 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Hashable
from fractions import Fraction
from typing import Callable, Optional, Union
from weakref import WeakValueDictionary
import torch
from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader
@@ -27,7 +30,7 @@ class BasevLLMParameter(Parameter):
into the parameter when the provided weight loader is called.
"""
def __new__(cls, data: torch.Tensor, **kwargs):
def __new__(cls, data: Optional[torch.Tensor], **kwargs):
return super().__new__(cls, data=data, requires_grad=False)
@@ -81,6 +84,17 @@ class BasevLLMParameter(Parameter):
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
self._assert_and_load(loaded_weight)
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
# otherwise assume a qkv shard id ("q"/"k"/"v")
# and map it to its integer index
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert isinstance(shard_id, str)
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]
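For illustration, a sketch of the hoisted helper's behavior; it assumes `BasevLLMParameter` can be constructed directly with a dummy tensor and a no-op loader:

import torch

param = BasevLLMParameter(data=torch.empty(1),
                          weight_loader=lambda *args, **kwargs: None)
assert param._shard_id_as_int("q") == 0    # qkv strings map to fixed positions
assert param._shard_id_as_int("v") == 2
assert param._shard_id_as_int(1) == 1      # integer ids pass through unchanged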
class _ColumnvLLMParameter(BasevLLMParameter):
"""
@@ -113,6 +127,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if isinstance(
self,
(PackedColumnParameter,
@@ -137,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_id = kwargs.get("shard_id")
num_heads = kwargs.get("num_heads")
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if isinstance(
self,
(PackedColumnParameter,
@@ -224,19 +240,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
"""
def __init__(self, **kwargs):
self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
super().__init__(**kwargs)
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
# if not int, assume shard_id for qkv
# map to int and return
assert isinstance(shard_id, str)
assert shard_id in self.qkv_idxs
return self.qkv_idxs[shard_id]
# For row parallel layers, no sharding needed
# load weight into parameter as is
def load_row_parallel_weight(self, *args, **kwargs):
@@ -373,6 +378,141 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
pass
class SharedWeightParameter(BasevLLMParameter):
"""
Parameter whose underlying tensor data can be shared across multiple
parameters in a model. For example, when applying transforms to the
"gate" and "up" partitions of `MergedColumnParallelLinear`, the transform
weights must remain separate tensors so that their memory can be shared
with the corresponding transforms in other layers.
"""
# global registry for sharing tensors based on passed `data_key`
# this dict holds weak refs to avoid memory leaks after model cleanup
tensors_registry: WeakValueDictionary = WeakValueDictionary()
# local container for strong references to shared tensors
# this set compensates for the fact that torch.nn.Parameter
# and Parameter subclasses do not hold reliable references to tensors
local_tensors: set[torch.Tensor]
# dictionary mapping partition indices to associated parameters
partitions: dict[int, Union[ModelWeightParameter, Parameter]]
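The two containers work together: `tensors_registry` deduplicates tensors by key but holds them only weakly, while `local_tensors` pins them for this parameter's lifetime. A standalone sketch of the eviction behavior the weak references provide (hypothetical key name):

import torch
from weakref import WeakValueDictionary

registry: WeakValueDictionary = WeakValueDictionary()
tensor = torch.empty(4, 4)
registry["layer0.transform"] = tensor
assert "layer0.transform" in registry      # alive while a strong ref exists
del tensor                                 # drop the last strong reference
assert "layer0.transform" not in registry  # evicted, so cleanup cannot leak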
def __new__(cls, **kwargs):
return super().__new__(cls, data=None, **kwargs)
def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs):
weight_loader: Callable = kwargs.get(
"weight_loader") # type: ignore[assignment]
super().__init__(data=None, weight_loader=weight_loader)
self.local_tensors = set()
self.partitions = {}
self.kwargs = {
"input_dim": input_dim,
"output_dim": output_dim,
"weight_loader": self._fake_weight_loader
}
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > 1:
raise NotImplementedError(f"{self.__class__.__name__} does not "
"currently support tensor parallelism")
def add_partition(self, index: int, data_key: Hashable, *args, **kwargs):
"""
Add a partition to the weight parameter. Partitions created with the
same `data_key` share their underlying tensor data
:param index: index of the partition to add
:param data_key: hashable key used to identify shared tensors
:param *args: arguments for `torch.empty`
:param **kwargs: keyword arguments for `torch.empty`
"""
# look up (or create) the shared tensor using `data_key`
if data_key not in self.tensors_registry:
data = torch.empty(*args, **kwargs)
self.tensors_registry[data_key] = data
else:
data = self.tensors_registry[data_key]
# create associated model parameter
self.partitions[index] = ModelWeightParameter(
data=data, **self.kwargs) # type: ignore[arg-type]
# hold local reference, since ModelWeightParameter does not
# see https://github.com/pytorch/pytorch/issues/75932
self.local_tensors.add(data)
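A minimal standalone sketch of the registry logic above (outside the class, with hypothetical key names) showing that equal `data_key`s alias a single tensor:

import torch
from weakref import WeakValueDictionary

tensors_registry: WeakValueDictionary = WeakValueDictionary()
local_tensors: set = set()
partitions: dict = {}

def add_partition(index, data_key, *args, **kwargs):
    # reuse the registered tensor when the key matches, else allocate one
    if data_key not in tensors_registry:
        data = torch.empty(*args, **kwargs)
        tensors_registry[data_key] = data
    else:
        data = tensors_registry[data_key]
    partitions[index] = torch.nn.Parameter(data, requires_grad=False)
    local_tensors.add(data)  # strong reference keeps the weak entry alive

add_partition(0, "transform.gate_up", 16, 16, dtype=torch.float16)
add_partition(1, "transform.gate_up", 16, 16, dtype=torch.float16)
# both partitions view the same storage
assert partitions[0].data_ptr() == partitions[1].data_ptr()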
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
assert len(self.partitions) == 1 and 0 in self.partitions
partition = self.partitions[0]
ModelWeightParameter.load_column_parallel_weight(
partition, loaded_weight)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
assert len(self.partitions) == 1 and 0 in self.partitions
partition = self.partitions[0]
ModelWeightParameter.load_row_parallel_weight(partition, loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
partition_id = kwargs.pop("shard_id")
partition_id = self._shard_id_as_int(partition_id)
partition = self.partitions[partition_id]
input_dim = self.kwargs.get("input_dim")
shard_size = partition.data.size(input_dim) // self.tp_size
shard_offset = self.tp_rank * shard_size
ModelWeightParameter.load_merged_column_weight(
partition,
loaded_weight,
shard_offset=shard_offset,
shard_size=shard_size)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
partition_id = self._shard_id_as_int(kwargs.pop("shard_id"))
partition = self.partitions[partition_id]
input_dim = self.kwargs.get("input_dim")
shard_size = partition.data.size(input_dim) // self.tp_size
shard_offset = self.tp_rank * shard_size
# each partition holds exactly one q/k/v tensor, so present it as the
# first ("q") shard and let the underlying loader copy it whole
shard_id = "q"
num_heads = kwargs.get("num_heads")
ModelWeightParameter.load_qkv_weight(
partition,
loaded_weight,
shard_offset=shard_offset,
shard_size=shard_size,
shard_id=shard_id,
num_heads=num_heads,
)
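With tp_size == 1, the offsets above degenerate so that the underlying loader copies the checkpoint tensor whole; a standalone sketch of the effective copy, with illustrative sizes:

import torch

output_dim, tp_rank, tp_size = 0, 0, 1
partition = torch.empty(64, 64)        # holds only this partition's tensor
loaded_weight = torch.randn(64, 64)    # checkpoint tensor for the same shard

shard_size = partition.size(output_dim) // tp_size  # full dim when tp_size == 1
shard_offset = tp_rank * shard_size                 # 0

# shard_id == "q" makes the loader index by tp_rank, which at rank 0 with a
# full-size shard reduces to copying the entire tensor into the partition
partition.narrow(output_dim, shard_offset, shard_size).copy_(
    loaded_weight.narrow(output_dim, tp_rank * shard_size, shard_size))
assert torch.equal(partition, loaded_weight)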
def process_weights_after_loading(self):
for key in self.partitions:
self.partitions[key] = torch.nn.Parameter(
data=self.partitions[key].data, requires_grad=False)
@property
def data(self):
raise ValueError("Accessing `data` of a "
f"`{self.__class__.__name__}` is not allowed. "
"Instead, access `partitions` to get the weight of "
"the particular partition you want to access")
def _fake_weight_loader(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_weight_shard_id: Optional[Union[str, int]]):
raise ValueError("When loading partition weights of "
f"{self.__class__.__name__}, use methods provided by "
f"{self.__class__.__name__}, not partition loader")
def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
output_dim: int, **kwargs) -> BasevLLMParameter:
"""
@@ -456,4 +596,4 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
shard_offset=shard_offset,
bitblas_tile_size=bitblas_tile_size)
return shard_size, shard_offset
return shard_size, shard_offset