[Transform] [Quantization] Add transforms to compressed tensors (#22486)

This commit is contained in:
Kyle Sayers
2025-08-28 02:43:48 -04:00
committed by GitHub
parent c8851a4723
commit 22feac8e95
9 changed files with 661 additions and 36 deletions

View File

@@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat,
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.transform import TransformConfig
from pydantic import BaseModel
import vllm.envs as envs
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod, get_linear_transform_schemes)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
@@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
transform_config: Optional[TransformConfig] = None,
):
super().__init__()
self.ignore = ignore
@@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig):
self.sparsity_ignore_list = sparsity_ignore_list
self.config = config
if transform_config is not None:
self.transform_config = TransformConfig.model_validate(
transform_config)
else:
self.transform_config = None
    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        """Return a linear method bound to this compressed-tensors config."""
        return CompressedTensorsLinearMethod(self)
@@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig):
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
if scheme is None:
return UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
input_tfms, output_tfms = get_linear_transform_schemes(
layer, prefix, self.transform_config,
self.packed_modules_mapping)
# choose quantization method
quant_method: LinearMethodBase = UnquantizedLinearMethod()
if quant_scheme is not None:
layer.scheme = quant_scheme
quant_method = CompressedTensorsLinearMethod(self)
# choose transform method
if any((input_tfms, output_tfms)):
return CompressedTensorsLinearTransformMethod.from_schemes(
quant_method, input_tfms, output_tfms)
else:
return quant_method
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
@@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig):
config=config)
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
config=config)
transform_config = config.get("transform_config")
return cls(
target_scheme_map=target_scheme_map,
@@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map=sparsity_scheme_map,
sparsity_ignore_list=sparsity_ignore_list,
config=config,
transform_config=transform_config,
)
@classmethod
@@ -537,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
# TODO (@kylesayrs): support ignore module names with ct matching utils
if should_ignore_layer(layer_name,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return None
# Will be empty for models with only sparsity
weight_quant = input_quant = None
@@ -722,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
layer input. See LinearMethodBase for param details
"""
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")

View File

@@ -0,0 +1,227 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Generator
from itertools import accumulate
from typing import Callable, Optional
import torch
from compressed_tensors.transform import (TransformArgs, TransformConfig,
TransformLocation, TransformScheme)
from compressed_tensors.utils import is_match
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
LinearMethodBase,
QKVCrossParallelLinear)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
HadamardTransform)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
TransformTuple)
class CompressedTensorsLinearTransformMethod(LinearMethodBase):
    """
    Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
    input and output transforms to either side of the original apply method.

    Transform schemes are keyed by the partition index of the (possibly
    fused) layer they apply to; see `get_linear_transform_schemes`.
    """

    @classmethod
    def from_schemes(
        cls, quant_method: LinearMethodBase, input_tfms: dict[int,
                                                              TransformTuple],
        output_tfms: dict[int, TransformTuple]
    ) -> "CompressedTensorsLinearTransformMethod":
        """Build a transform-wrapping method around `quant_method`.

        At least one of `input_tfms` / `output_tfms` must be non-empty;
        callers use the bare quant method when there are no transforms.
        """
        assert input_tfms or output_tfms

        # TODO (@ksayers): implement QutlassLinearMethodNvFP4
        # hadacore and fwht can be selected by Transform module
        return cls(quant_method, input_tfms, output_tfms)

    def __init__(self, quant_method: LinearMethodBase,
                 input_tfms: dict[int, TransformTuple],
                 output_tfms: dict[int, TransformTuple]):
        # wrapped method that performs the actual (possibly quantized) matmul
        self.quant_method = quant_method
        self.input_tfms = input_tfms
        self.output_tfms = output_tfms
        # populated in `create_weights` once partition sizes are known
        self.input_transform: Optional[HadamardTransform] = None
        self.output_transform: Optional[HadamardTransform] = None

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create the wrapped method's weights plus transform submodules.

        Registers at most one `HadamardTransform` submodule for the input
        side and one for the output side on `layer`, named
        ``f"{scheme_name}_{location}"`` so checkpoint weights map onto them.
        Also precomputes `self.partition_ranges` for slicing fused outputs.
        """
        # get weight loader for transforms
        weight_loader: Callable = extra_weight_attrs.get(
            "weight_loader")  # type: ignore[assignment]

        # HACK: UnquantizedLinearMethod does not support weight loader v2, but
        # transforms (specifically SharedWeightParameter) requires
        # weight loader v2. Until UnquantizedLinearMethod supports v2, we must
        # hack around this by getting weight loader v1 so ULM can load correctly
        quant_method_name = self.quant_method.__class__.__name__
        if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
            if isinstance(layer, QKVCrossParallelLinear):
                weight_loader_v1 = layer.weight_loader_v1
            else:
                weight_loader_v1 = layer.weight_loader
            extra_weight_attrs["weight_loader"] = weight_loader_v1

        self.quant_method.create_weights(
            layer=layer,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            input_size=input_size,
            output_size=output_size,
            params_dtype=params_dtype,
            **extra_weight_attrs)

        # validate schemes
        num_partitions = len(output_partition_sizes)
        self._validate_tfm_schemes(num_partitions)

        # create submodules for weight loading
        # (transforms use the original v2 weight loader, not the v1 fallback)
        if len(self.input_tfms) > 0:
            scheme_name = list(self.input_tfms.values())[0].scheme_name
            location = list(self.input_tfms.values())[0].args.location
            transform_name = f"{scheme_name}_{location}"
            transform = HadamardTransform(self.input_tfms, layer,
                                          weight_loader,
                                          input_size_per_partition,
                                          output_partition_sizes)
            layer.register_module(transform_name, transform)
            self.input_transform = transform

        if len(self.output_tfms) > 0:
            scheme_name = list(self.output_tfms.values())[0].scheme_name
            location = list(self.output_tfms.values())[0].args.location
            transform_name = f"{scheme_name}_{location}"
            transform = HadamardTransform(self.output_tfms, layer,
                                          weight_loader,
                                          input_size_per_partition,
                                          output_partition_sizes)
            layer.register_module(transform_name, transform)
            self.output_transform = transform

        # compute partition ranges for slicing activations
        starts = [0] + list(accumulate(output_partition_sizes))[:-1]
        self.partition_ranges = list(zip(starts, output_partition_sizes))

    def process_weights_after_loading(self, layer):
        """Finalize the wrapped method's weights and each transform module."""
        self.quant_method.process_weights_after_loading(layer)

        for submodule in layer.children():
            if isinstance(submodule, HadamardTransform):
                submodule.process_weights_after_loading()

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply input transform, wrapped matmul, then per-partition output
        transforms.

        NOTE(review): bias is rejected here — presumably because it would be
        (incorrectly) mixed into the transformed output; confirm with callers.
        """
        if self.input_transform is not None:
            x = self.input_transform(x)

        assert bias is None
        x = self.quant_method.apply(layer, x, bias)

        # TODO (@ksayers): Write a triton kernel to do this in parallel
        if self.output_transform is not None:
            # each fused partition may carry its own output transform
            for part_id, (start, length) in enumerate(self.partition_ranges):
                x[:, start:start + length] = self.output_transform(
                    x[:, start:start + length], part_id=part_id)

        return x

    def _validate_tfm_schemes(self, num_partitions: int):
        """Check transform schemes are consistent across fused partitions.

        Input transforms must be identical for every partition (a fused
        layer has a single input). Output transforms may differ per
        partition but must share one scheme name and location.

        NOTE(review): if `input_tfms` is non-empty but missing a partition
        index, the lookup below raises KeyError rather than the intended
        ValueError — confirm callers always populate every partition. The
        return value appears unused by `create_weights`.
        """
        if len(self.input_tfms) > 0:
            if 0 not in self.input_tfms:
                raise ValueError("Must have same input")

            for part_index in range(num_partitions):
                if self.input_tfms[part_index] != self.input_tfms[0]:
                    raise ValueError("Must have same input")

        if len(self.output_tfms) > 0:
            scheme_name = list(self.output_tfms.values())[0].scheme_name
            location = list(self.output_tfms.values())[0].args.location
            for tfm in self.output_tfms.values():
                if tfm.scheme_name != scheme_name:
                    raise ValueError("Must have same scheme name")

                if tfm.args.location != location:
                    raise ValueError("Must have same location")

        return self.input_tfms, self.output_tfms
def get_linear_transform_schemes(
    layer: torch.nn.Module, layer_name: str,
    transform_config: Optional[TransformConfig],
    packed_modules_mapping: dict[str, list[str]]
) -> tuple[dict[int, TransformTuple], dict[
        int, TransformTuple]]:  # [input_transform, [output_transform, ...]]
    """Collect online input/output transform schemes for one linear layer.

    Returns two dicts keyed by partition index of the (possibly fused)
    layer: input-side transforms and output-side transforms. There can
    only be one transform input scheme per (fused) module.
    """
    input_tfms: dict[int, TransformTuple] = {}
    output_tfms: dict[int, TransformTuple] = {}
    # route each matched transform to the dict for its location
    destinations = {
        TransformLocation.INPUT: input_tfms,
        TransformLocation.OUTPUT: output_tfms,
    }

    part_names = get_layer_partition_names(layer_name,
                                           packed_modules_mapping)

    for scheme_name, scheme, args in get_schemes_args(transform_config):
        for index, name in enumerate(part_names):
            # only online transforms applied to matching partitions
            if not (is_match(name, layer, args.targets, args.ignore)
                    and args.is_online()):
                continue

            if args.location not in destinations:
                raise ValueError(f"Cannot apply `{args.location}` "
                                 f"transform to `{layer_name}`")

            destinations[args.location][index] = TransformTuple(
                scheme_name, scheme, args)

    return (input_tfms, output_tfms)
def get_schemes_args(
    transform_config: Optional[TransformConfig]
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
    """Yield one (scheme_name, scheme, args) triple for every apply entry
    of every config group; yields nothing when the config is None."""
    if transform_config is not None:
        for name, scheme in transform_config.config_groups.items():
            yield from ((name, scheme, args) for args in scheme.apply)
def get_layer_partition_names(
        layer_name: str, packed_modules_mapping: dict[str,
                                                      list[str]]) -> list[str]:
    """
    Get all partition names associated with this layer.
    Names are returned in order of their partition indices.

    ```python
    mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}
    assert get_layer_partition_names(
        "mlp.gate_up_proj", mapping) == ["mlp.gate_proj", "mlp.up_proj"]
    assert get_layer_partition_names(
        "mlp.down_proj", mapping) == ["mlp.down_proj"]
    ```
    """
    # a fused layer name ends with a fused suffix; map it back to the
    # per-partition suffixes, preserving the common prefix
    for fused_suffix, part_suffixes in packed_modules_mapping.items():
        if layer_name.endswith(fused_suffix):
            return [
                layer_name.removesuffix(fused_suffix) + part_suffix
                for part_suffix in part_suffixes
            ]

    # unfused layers are their own single partition
    return [layer_name]

View File

@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Hashable
from typing import Callable, Optional
import torch
from compressed_tensors.transform import TransformLocation, TransformScheme
from torch import Tensor
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
TransformTuple)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parameter import SharedWeightParameter
class HadamardTransform(torch.nn.Module):
    """
    Class which handles weight loading, postprocessing, and application of
    transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
    and attention transforms method (not implemented yet)
    """
    transforms: dict[int, TransformTuple]  # info parsed from transforms config
    weight: SharedWeightParameter  # container for shared tensors
    kernel: Callable  # function used during application
    scales: dict[int, float]  # hadamard scale, usually sqrt(matrix.size(0))

    def __init__(self,
                 transforms: dict[int, TransformTuple],
                 layer: torch.nn.Module,
                 weight_loader: Callable,
                 input_size_per_partition: int,
                 output_partition_sizes: list[int],
                 kernel: Optional[Callable] = None):
        """Create shared transform weights for each transformed partition.

        :param transforms: transform scheme info keyed by partition index
        :param layer: layer the transform attaches to; its type and the
            transform location determine the transform weight size
        :param weight_loader: weight loader passed to SharedWeightParameter
        :param input_size_per_partition: layer input size on this rank
        :param output_partition_sizes: output size of each fused partition
        :param kernel: optional override for the gemm kernel; inferred from
            the schemes when None
        """
        super().__init__()
        self.transforms = transforms
        self.scales = {}

        if get_tensor_model_parallel_world_size() > 1:
            raise NotImplementedError("Online transforms with tensor "
                                      "parallelism is not supported")

        # Similar to row/col parallel params, but tensors are separate
        # to allow for loading with shared memory
        self.weight = SharedWeightParameter(weight_loader=weight_loader)

        # create shared partition data for each partition of the original weight
        input_size = input_size_per_partition
        for part_index, (_scheme_name, scheme,
                         args) in self.transforms.items():
            output_size = output_partition_sizes[part_index]
            weight_size = self._get_weight_size(layer, args.location,
                                                input_size, output_size)
            # partitions with the same scheme object and size share a tensor
            data_key = self._get_data_key(scheme, weight_size)
            self.weight.add_partition(
                part_index,
                data_key,
                size=(weight_size, weight_size),
                dtype=scheme.precision,
            )

        # validate that shared tensors and schemes are correct
        self._validate_input_transforms()

        # select kernel based on transform schemes
        self.kernel = self._infer_kernel() if kernel is None else kernel

    def process_weights_after_loading(self):
        """Finalize loaded transform weights and precompute per-partition
        scales."""
        for part_id in self.weight.partitions:
            data = self.weight.partitions[part_id].data

            # required by torch.compile
            # NOTE(review): invoked once per partition inside this loop —
            # confirm it is idempotent / intended to run repeatedly
            self.weight.process_weights_after_loading()

            # precompute scale as a runtime multiply, not division
            # do not fold into weight in order to utilize FWHT
            self.scales[part_id] = 1 / math.sqrt(data.size(0))

        # FUTURE: avoid runtime tranpose by processing weights
        # prior to apply

    def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
        """Apply the transform for `part_id` to `value`.

        Returns `value` unchanged when that partition has no transform.
        """
        if part_id not in self.weight.partitions:
            return value

        weight = self.weight.partitions[part_id]
        weight = weight if self.transforms[
            part_id].args.inverse else weight.T  # linear := x(W.T)
        scale = self.scales[part_id]

        return self.kernel(self, value.to(weight.dtype), weight, None).to(
            value.dtype) * scale

    def _get_data_key(self, scheme: TransformScheme,
                      weight_size: int) -> Hashable:
        # identity of the scheme object (not equality) keys tensor sharing
        return (id(scheme), weight_size)

    def _get_weight_size(self, layer: torch.nn.Module,
                         location: TransformLocation, input_size: int,
                         output_size: int) -> int:
        """Return the size of the (square) transform weight for `location`
        applied to `layer`.

        :raises ValueError: for unsupported layer/location combinations
        """
        if isinstance(layer, LinearBase):
            if location == TransformLocation.INPUT:
                return input_size

            elif location == TransformLocation.OUTPUT:
                return output_size

        elif isinstance(layer, VocabParallelEmbedding):
            # NOTE(review): sizes are swapped relative to LinearBase —
            # presumably because the embedding weight layout is transposed
            # relative to a linear weight; confirm
            if location == TransformLocation.INPUT:
                return output_size

            elif location == TransformLocation.OUTPUT:
                return input_size

        # NOTE(review): empty error message — consider naming the layer type
        # and location here
        raise ValueError()

    def _validate_input_transforms(self):
        """Check that INPUT transforms share one underlying tensor across
        all partitions (compared by data pointer)."""
        assert len(self.transforms) > 0
        location = list(self.transforms.values())[0].args.location

        if location == TransformLocation.INPUT:
            first_data = self.weight.partitions[0].data
            for partition in self.weight.partitions.values():
                if partition.data.data_ptr() != first_data.data_ptr():
                    # NOTE(review): empty error message
                    raise ValueError("")

    def _infer_kernel(self) -> Callable:
        """Select the gemm kernel used to apply transform weights."""
        # TODO (@ksayers): use fwht, hadacore
        return dispatch_unquantized_gemm()

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod)
# Because qutlass fuses hadamard with quantization, it cannot automatically be
# composed with kernels in the way CompressedTensorsLinearTransformMethod does.
# Therefore, a separate scheme must be created for each quantized dtype
class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod):
    """Placeholder for a linear method that fuses the hadamard transform
    with NvFP4 quantization (qutlass). Not implemented yet."""

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # fused hadamard quant linear method
        raise NotImplementedError()

View File

@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple
from compressed_tensors.transform import TransformArgs, TransformScheme
__all__ = ["TransformTuple"]
class TransformTuple(NamedTuple):
    """Transform scheme info parsed from a transform config for one layer
    partition."""
    # name of the config group this transform belongs to
    scheme_name: str
    # the transform scheme (provides e.g. `precision` and `apply` entries)
    scheme: TransformScheme
    # application args (provide e.g. `targets`, `location`, `inverse`)
    args: TransformArgs