diff --git a/tests/conftest.py b/tests/conftest.py index f8bfdfc8e..6052ada1c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +import math import os import tempfile from enum import Enum -from typing import Any, Callable, Optional, TypedDict, TypeVar, Union +from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast import numpy as np import pytest @@ -33,6 +34,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams +from vllm.sequence import Logprob from vllm.transformers_utils.utils import maybe_model_redirect logger = init_logger(__name__) @@ -602,7 +604,7 @@ class HfRunner: def _hidden_states_to_logprobs( self, hidden_states: tuple[tuple[torch.Tensor, ...], ...], - num_logprobs: int, + num_logprobs: Optional[int], ) -> tuple[list[dict[int, float]], int]: seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states) output_len = len(hidden_states) @@ -630,7 +632,7 @@ class HfRunner: self, prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, @@ -677,7 +679,7 @@ class HfRunner: self, encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> list[TokensTextLogprobs]: @@ -966,7 +968,7 @@ class VllmRunner: self, prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, @@ -991,11 +993,40 @@ class VllmRunner: videos=videos, **kwargs) + def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: + """ + Return the perplexity score associated with generating the prompts + + :param prompts: list of prompts to score + :return: perplexity score of each prompt + """ + outputs = self.generate_greedy_logprobs(prompts, + max_tokens=1, + num_logprobs=None, + num_prompt_logprobs=0) + + perplexities = [] + for output in outputs: + output = cast(TokensTextLogprobsPromptLogprobs, output) + token_datas = cast(list[Optional[dict[int, Logprob]]], output[3]) + assert token_datas[0] is None + token_log_probs = [] + for token_data in token_datas[1:]: + assert token_data is not None + assert len(token_data) == 1 + token_log_prob = list(token_data.values())[0].logprob + token_log_probs.append(token_log_prob) + + perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs)) + perplexities.append(perplexity) + + return perplexities + def generate_encoder_decoder_greedy_logprobs( self, encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], num_prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, ) -> Union[list[TokensTextLogprobs], diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index b9774b7ee..484f53246 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -719,3 +719,25 @@ def test_compressed_tensors_w4a8_fp8(vllm_runner, args): output = llm.generate_greedy("Hello my name is", max_tokens=20) print(output) assert output + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="This test is skipped on non-CUDA platform.") +@pytest.mark.parametrize("model,prompt,exp_perplexity", [ + ( + "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", + "Flat is better than nested.\nSparse is better than dense.", + 150.0, + ), + ( + "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", + "Flat is better than nested.\nSparse is better than dense.", + 150.0, + ), +]) +def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, + exp_perplexity): + with vllm_runner(model, enforce_eager=True) as llm: + perplexity = llm.generate_prompt_perplexity([prompt])[0] + print(perplexity) + assert perplexity <= exp_perplexity \ No newline at end of file diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c0fcacd1e..19ff63145 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -35,6 +35,7 @@ logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", + "CompressedTensorsLinearTransformMethod", "BitBLASLinearMethod", "GPTQBitBLASLinearMethod", "AWQMarlinLinearMethod", @@ -199,6 +200,7 @@ class UnquantizedLinearMethod(LinearMethodBase): set_weight_attrs(weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # special postprocessing for CPU SGL if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: from vllm.model_executor.layers.utils import check_cpu_sgl_kernel N, K = layer.weight.size() @@ -1470,7 +1472,7 @@ class QKVCrossParallelLinear(LinearBase): self.bias = torch.nn.Parameter() set_weight_attrs(self.bias, { "output_dim": 0, - "weight_loader": self.weight_loader, + "weight_loader": self.weight_loader_v1, }) else: self.bias = None @@ -1580,6 +1582,18 @@ class QKVCrossParallelLinear(LinearBase): k, v = kv_enc.split(self.kv_size, dim=-1) return q, k, v + def weight_loader_v1(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + # just like all other parameters, does not yet + # support loading bias with weight_loader_v2 + layer = (self.q_proj_decoder + if loaded_shard_id == "q" else self.kv_proj_encoder) + target_param = self.select_proj_params(layer, param) + shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () + layer.weight_loader(target_param, loaded_weight, *shard_id_args) + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 245cf122e..230572041 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat, from compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy, QuantizationType) +from compressed_tensors.transform import TransformConfig from pydantic import BaseModel import vllm.envs as envs @@ -30,6 +31,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 + CompressedTensorsLinearTransformMethod, get_linear_transform_schemes) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) @@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_ignore_list: list[str], kv_cache_scheme: Optional[dict[str, Any]] = None, config: Optional[dict[str, Any]] = None, + transform_config: Optional[TransformConfig] = None, ): super().__init__() self.ignore = ignore @@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig): self.sparsity_ignore_list = sparsity_ignore_list self.config = config + if transform_config is not None: + self.transform_config = TransformConfig.model_validate( + transform_config) + else: + self.transform_config = None + def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig): ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import - # Check if the layer is skipped for quantization. - # TODO (@robertgshaw2): support module names - if should_ignore_layer(prefix, - ignore=self.ignore, - fused_mapping=self.packed_modules_mapping): - return UnquantizedLinearMethod() if isinstance(layer, LinearBase): - scheme = self.get_scheme(layer=layer, layer_name=prefix) - if scheme is None: - return UnquantizedLinearMethod() - layer.scheme = scheme - return CompressedTensorsLinearMethod(self) + # collect schemes + quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) + input_tfms, output_tfms = get_linear_transform_schemes( + layer, prefix, self.transform_config, + self.packed_modules_mapping) + + # choose quantization method + quant_method: LinearMethodBase = UnquantizedLinearMethod() + if quant_scheme is not None: + layer.scheme = quant_scheme + quant_method = CompressedTensorsLinearMethod(self) + + # choose transform method + if any((input_tfms, output_tfms)): + return CompressedTensorsLinearTransformMethod.from_schemes( + quant_method, input_tfms, output_tfms) + + else: + return quant_method + if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): @@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig): config=config) sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( config=config) + transform_config = config.get("transform_config") return cls( target_scheme_map=target_scheme_map, @@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig): sparsity_scheme_map=sparsity_scheme_map, sparsity_ignore_list=sparsity_ignore_list, config=config, + transform_config=transform_config, ) @classmethod @@ -537,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig): # Find the "target" in the compressed-tensors config # that our layer conforms to. - # TODO (@robertgshaw): add compressed-tensors as dep - # so we do not have to re-write these functions - # need to make accelerate optional in ct to do this + # TODO (@kylesayrs): support ignore module names with ct matching utils + if should_ignore_layer(layer_name, + ignore=self.ignore, + fused_mapping=self.packed_modules_mapping): + return None # Will be empty for models with only sparsity weight_quant = input_quant = None @@ -722,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase): layer input. See LinearMethodBase for param details """ - scheme = layer.scheme if scheme is None: raise ValueError("A scheme must be defined for each layer") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py new file mode 100644 index 000000000..2fc94b3c2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Generator +from itertools import accumulate +from typing import Callable, Optional + +import torch +from compressed_tensors.transform import (TransformArgs, TransformConfig, + TransformLocation, TransformScheme) +from compressed_tensors.utils import is_match + +from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, + LinearMethodBase, + QKVCrossParallelLinear) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501 + HadamardTransform) +from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 + TransformTuple) + + +class CompressedTensorsLinearTransformMethod(LinearMethodBase): + """ + Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds + input and output transforms to either side of the original apply method + """ + + @classmethod + def from_schemes( + cls, quant_method: LinearMethodBase, input_tfms: dict[int, + TransformTuple], + output_tfms: dict[int, TransformTuple] + ) -> "CompressedTensorsLinearTransformMethod": + assert input_tfms or output_tfms + + # TODO (@ksayers): implement QutlassLinearMethodNvFP4 + # hadacore and fwht can be selected by Transform module + + return cls(quant_method, input_tfms, output_tfms) + + def __init__(self, quant_method: LinearMethodBase, + input_tfms: dict[int, TransformTuple], + output_tfms: dict[int, TransformTuple]): + self.quant_method = quant_method + self.input_tfms = input_tfms + self.output_tfms = output_tfms + + self.input_transform: Optional[HadamardTransform] = None + self.output_transform: Optional[HadamardTransform] = None + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + + # get weight loader for transforms + weight_loader: Callable = extra_weight_attrs.get( + "weight_loader") # type: ignore[assignment] + + # HACK: UnquantizedLinearMethod does not support weight loader v2, but + # transforms (specifically SharedWeightParameter) requires + # weight loader v2. Until UnquantizedLinearMethod supports v2, we must + # hack around this by getting weight loader v1 so ULM can load correctly + quant_method_name = self.quant_method.__class__.__name__ + if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED: + if isinstance(layer, QKVCrossParallelLinear): + weight_loader_v1 = layer.weight_loader_v1 + else: + weight_loader_v1 = layer.weight_loader + extra_weight_attrs["weight_loader"] = weight_loader_v1 + + self.quant_method.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + input_size=input_size, + output_size=output_size, + params_dtype=params_dtype, + **extra_weight_attrs) + + # validate schemes + num_partitions = len(output_partition_sizes) + self._validate_tfm_schemes(num_partitions) + + # create submodules for weight loading + if len(self.input_tfms) > 0: + scheme_name = list(self.input_tfms.values())[0].scheme_name + location = list(self.input_tfms.values())[0].args.location + transform_name = f"{scheme_name}_{location}" + + transform = HadamardTransform(self.input_tfms, layer, + weight_loader, + input_size_per_partition, + output_partition_sizes) + layer.register_module(transform_name, transform) + self.input_transform = transform + + if len(self.output_tfms) > 0: + scheme_name = list(self.output_tfms.values())[0].scheme_name + location = list(self.output_tfms.values())[0].args.location + transform_name = f"{scheme_name}_{location}" + + transform = HadamardTransform(self.output_tfms, layer, + weight_loader, + input_size_per_partition, + output_partition_sizes) + layer.register_module(transform_name, transform) + self.output_transform = transform + + # compute partition ranges for slicing activations + starts = [0] + list(accumulate(output_partition_sizes))[:-1] + self.partition_ranges = list(zip(starts, output_partition_sizes)) + + def process_weights_after_loading(self, layer): + self.quant_method.process_weights_after_loading(layer) + + for submodule in layer.children(): + if isinstance(submodule, HadamardTransform): + submodule.process_weights_after_loading() + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.input_transform is not None: + x = self.input_transform(x) + + assert bias is None + x = self.quant_method.apply(layer, x, bias) + + # TODO (@ksayers): Write a triton kernel to do this in parallel + if self.output_transform is not None: + for part_id, (start, length) in enumerate(self.partition_ranges): + x[:, start:start + length] = self.output_transform( + x[:, start:start + length], part_id=part_id) + + return x + + def _validate_tfm_schemes(self, num_partitions: int): + if len(self.input_tfms) > 0: + if 0 not in self.input_tfms: + raise ValueError("Must have same input") + + for part_index in range(num_partitions): + if self.input_tfms[part_index] != self.input_tfms[0]: + raise ValueError("Must have same input") + + if len(self.output_tfms) > 0: + scheme_name = list(self.output_tfms.values())[0].scheme_name + location = list(self.output_tfms.values())[0].args.location + + for tfm in self.output_tfms.values(): + if tfm.scheme_name != scheme_name: + raise ValueError("Must have same scheme name") + if tfm.args.location != location: + raise ValueError("Must have same location") + + return self.input_tfms, self.output_tfms + + +def get_linear_transform_schemes( + layer: torch.nn.Module, layer_name: str, + transform_config: Optional[TransformConfig], + packed_modules_mapping: dict[str, list[str]] +) -> tuple[dict[int, TransformTuple], dict[ + int, TransformTuple]]: # [input_transform, [output_transform, ...]] + # there can only be one transform input scheme per (fused) module + input_tfms = {} + output_tfms = {} + + partition_names = get_layer_partition_names(layer_name, + packed_modules_mapping) + + for scheme_name, scheme, args in get_schemes_args(transform_config): + for part_index, part_name in enumerate(partition_names): + if is_match(part_name, layer, args.targets, + args.ignore) and args.is_online(): + if args.location == TransformLocation.INPUT: + input_tfms[part_index] = TransformTuple( + scheme_name, scheme, args) + + elif args.location == TransformLocation.OUTPUT: + output_tfms[part_index] = TransformTuple( + scheme_name, scheme, args) + + else: + raise ValueError(f"Cannot apply `{args.location}` " + f"transform to `{layer_name}`") + + return (input_tfms, output_tfms) + + +def get_schemes_args( + transform_config: Optional[TransformConfig] +) -> Generator[tuple[str, TransformScheme, TransformArgs]]: + if transform_config is None: + return + + for scheme_name, scheme in transform_config.config_groups.items(): + for args in scheme.apply: + yield (scheme_name, scheme, args) + + +def get_layer_partition_names( + layer_name: str, packed_modules_mapping: dict[str, + list[str]]) -> list[str]: + """ + Get all partition names associated with this layer. + Names are returned in order of their partition indices. + + ```python + mapping = {"gate_up_proj", "gate_proj", "up_proj"} + + assert get_layer_partition_names( + "mlp.gate_up_proj", mapping) == ["gate_proj", "up_proj"] + assert get_layer_partition_names( + "mlp.down_proj", mapping) == ["down_proj"] + """ + for fused_suffix, part_suffixes in packed_modules_mapping.items(): + if layer_name.endswith(fused_suffix): + return [ + layer_name.removesuffix(fused_suffix) + part_suffix + for part_suffix in part_suffixes + ] + + return [layer_name] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py new file mode 100644 index 000000000..b3be25471 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Hashable +from typing import Callable, Optional + +import torch +from compressed_tensors.transform import TransformLocation, TransformScheme +from torch import Tensor + +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 + TransformTuple) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parameter import SharedWeightParameter + + +class HadamardTransform(torch.nn.Module): + """ + Class which handles weight loading, postprocessing, and application of + transforms. Meant to be used with `CompressedTensorsLinearTransformMethod` + and attention transforms method (not implemented yet) + """ + transforms: dict[int, TransformTuple] # info parsed from transforms config + weight: SharedWeightParameter # container for shared tensors + + kernel: Callable # function used during application + scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0)) + + def __init__(self, + transforms: dict[int, TransformTuple], + layer: torch.nn.Module, + weight_loader: Callable, + input_size_per_partition: int, + output_partition_sizes: list[int], + kernel: Optional[Callable] = None): + super().__init__() + self.transforms = transforms + self.scales = {} + + if get_tensor_model_parallel_world_size() > 1: + raise NotImplementedError("Online transforms with tensor " + "parallelism is not supported") + + # Similar to row/col parallel params, but tensors are separate + # to allow for loading with shared memory + self.weight = SharedWeightParameter(weight_loader=weight_loader) + + # create shared partition data for each partition of the original weight + input_size = input_size_per_partition + for part_index, (_scheme_name, scheme, + args) in self.transforms.items(): + output_size = output_partition_sizes[part_index] + weight_size = self._get_weight_size(layer, args.location, + input_size, output_size) + + data_key = self._get_data_key(scheme, weight_size) + self.weight.add_partition( + part_index, + data_key, + size=(weight_size, weight_size), + dtype=scheme.precision, + ) + + # validate that shared tensors and schemes are correct + self._validate_input_transforms() + + # select kernel based on transform schemes + self.kernel = self._infer_kernel() if kernel is None else kernel + + def process_weights_after_loading(self): + for part_id in self.weight.partitions: + data = self.weight.partitions[part_id].data + + # required by torch.compile + self.weight.process_weights_after_loading() + + # precompute scale as a runtime multiply, not division + # do not fold into weight in order to utilize FWHT + self.scales[part_id] = 1 / math.sqrt(data.size(0)) + + # FUTURE: avoid runtime tranpose by processing weights + # prior to apply + + def forward(self, value: Tensor, part_id: int = 0) -> Tensor: + if part_id not in self.weight.partitions: + return value + + weight = self.weight.partitions[part_id] + weight = weight if self.transforms[ + part_id].args.inverse else weight.T # linear := x(W.T) + scale = self.scales[part_id] + return self.kernel(self, value.to(weight.dtype), weight, None).to( + value.dtype) * scale + + def _get_data_key(self, scheme: TransformScheme, + weight_size: int) -> Hashable: + return (id(scheme), weight_size) + + def _get_weight_size(self, layer: torch.nn.Module, + location: TransformLocation, input_size: int, + output_size: int) -> int: + if isinstance(layer, LinearBase): + if location == TransformLocation.INPUT: + return input_size + + elif location == TransformLocation.OUTPUT: + return output_size + + elif isinstance(layer, VocabParallelEmbedding): + if location == TransformLocation.INPUT: + return output_size + + elif location == TransformLocation.OUTPUT: + return input_size + + raise ValueError() + + def _validate_input_transforms(self): + assert len(self.transforms) > 0 + location = list(self.transforms.values())[0].args.location + + if location == TransformLocation.INPUT: + first_data = self.weight.partitions[0].data + for partition in self.weight.partitions.values(): + if partition.data.data_ptr() != first_data.data_ptr(): + raise ValueError("") + + def _infer_kernel(self) -> Callable: + # TODO (@ksayers): use fwht, hadacore + return dispatch_unquantized_gemm() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py new file mode 100644 index 000000000..f42258f9f --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 + CompressedTensorsLinearTransformMethod) + + +# Because qutlass fuses hadamard with quantization, it cannot automatically be +# composed with kernels in the way CompressedTensorsLinearTransformMethod does. +# Therefore, a separate scheme must be created for each quantized dtype +class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod): + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # fused hadamard quant linear method + raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py new file mode 100644 index 000000000..2f353de1e --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import NamedTuple + +from compressed_tensors.transform import TransformArgs, TransformScheme + +__all__ = ["TransformTuple"] + + +class TransformTuple(NamedTuple): + scheme_name: str + scheme: TransformScheme + args: TransformArgs diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 750ee7850..9465308e9 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Hashable from fractions import Fraction from typing import Callable, Optional, Union +from weakref import WeakValueDictionary import torch from torch.nn import Parameter -from vllm.distributed import get_tensor_model_parallel_rank +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.model_executor.utils import _make_synced_weight_loader @@ -27,7 +30,7 @@ class BasevLLMParameter(Parameter): into the parameter when the provided weight loader is called. """ - def __new__(cls, data: torch.Tensor, **kwargs): + def __new__(cls, data: Optional[torch.Tensor], **kwargs): return super().__new__(cls, data=data, requires_grad=False) @@ -81,6 +84,17 @@ class BasevLLMParameter(Parameter): def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): self._assert_and_load(loaded_weight) + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + # if not int, assume shard_id for qkv + # map to int and return + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert isinstance(shard_id, str) + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + class _ColumnvLLMParameter(BasevLLMParameter): """ @@ -113,6 +127,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") + # TODO: move these to PackedColumnParameter and PackedvLLMParameter if isinstance( self, (PackedColumnParameter, @@ -137,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_id = kwargs.get("shard_id") num_heads = kwargs.get("num_heads") + # TODO: move these to PackedColumnParameter and PackedvLLMParameter if isinstance( self, (PackedColumnParameter, @@ -224,19 +240,8 @@ class PerTensorScaleParameter(BasevLLMParameter): """ def __init__(self, **kwargs): - self.qkv_idxs = {"q": 0, "k": 1, "v": 2} super().__init__(**kwargs) - def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: - if isinstance(shard_id, int): - return shard_id - - # if not int, assume shard_id for qkv - # map to int and return - assert isinstance(shard_id, str) - assert shard_id in self.qkv_idxs - return self.qkv_idxs[shard_id] - # For row parallel layers, no sharding needed # load weight into parameter as is def load_row_parallel_weight(self, *args, **kwargs): @@ -373,6 +378,141 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): pass +class SharedWeightParameter(BasevLLMParameter): + """ + Parameter for weights with many shared tensors across a model + + For example, when applying transforms to the "gate" and "up" partitions of + `MergedColumnParallelLinear`, the transform weights must stay separate + tensors in order to allow for tensor memory sharing between layers. + """ + # global registry for sharing tensors based on passed `data_key` + # this dict holds weaksrefs to avoid memory leak after model cleanup + tensors_registry: WeakValueDictionary = WeakValueDictionary() + + # local container for strong references to shared tensors + # this set compensates for the fact that torch.nn.Parameter + # and Parameter subclasses do not hold reliable references to tensors + local_tensors: set[torch.Tensor] + + # dictionary mapping partition indices to associated parameters + partitions: dict[int, Union[ModelWeightParameter, Parameter]] + + def __new__(cls, **kwargs): + return super().__new__(cls, data=None, **kwargs) + + def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs): + weight_loader: Callable = kwargs.get( + "weight_loader") # type: ignore[assignment] + super().__init__(data=None, weight_loader=weight_loader) + + self.local_tensors = set() + self.partitions = {} + self.kwargs = { + "input_dim": input_dim, + "output_dim": output_dim, + "weight_loader": self._fake_weight_loader + } + + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > 1: + raise NotImplementedError(f"{self.__class__.__name__} does not " + "currently support tensor parallelism") + + def add_partition(self, index: int, data_key: Hashable, *args, **kwargs): + """ + Add a partition to the weight parameter. Partitions whose `data_key` + is the same will share tensor data + + :param index: index of partition to add + :param data_key: hashable key used to key shared tensors + :param *args: arguments for `torch.empty` + :param **kwargs: keyword arguments for `torch.empty` + """ + # load (shared) tensor using `data_key` + if data_key not in self.tensors_registry: + data = torch.empty(*args, **kwargs) + self.tensors_registry[data_key] = data + else: + data = self.tensors_registry[data_key] + + # create associated model parameter + self.partitions[index] = ModelWeightParameter( + data=data, **self.kwargs) # type: ignore[arg-type] + + # hold local reference, since ModelWeightParameter does not + # see https://github.com/pytorch/pytorch/issues/75932 + self.local_tensors.add(data) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + assert len(self.partitions) == 1 and 0 in self.partitions + partition = self.partitions[0] + + ModelWeightParameter.load_column_parallel_weight( + partition, loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + assert len(self.partitions) == 1 and 0 in self.partitions + partition = self.partitions[0] + + ModelWeightParameter.load_row_parallel_weight(partition, loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + partition_id = kwargs.pop("shard_id") + partition_id = self._shard_id_as_int(partition_id) + partition = self.partitions[partition_id] + + input_dim = self.kwargs.get("input_dim") + shard_size = partition.data.size(input_dim) // self.tp_size + shard_offset = self.tp_rank * shard_size + + ModelWeightParameter.load_merged_column_weight( + partition, + loaded_weight, + shard_offset=shard_offset, + shard_size=shard_size) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + partition_id = self._shard_id_as_int(kwargs.pop("shard_id")) + partition = self.partitions[partition_id] + + input_dim = self.kwargs.get("input_dim") + shard_size = partition.data.size(input_dim) // self.tp_size + shard_offset = self.tp_rank * shard_size + shard_id = "q" # fake first partition + num_heads = kwargs.get("num_heads") + + ModelWeightParameter.load_qkv_weight( + partition, + loaded_weight, + shard_offset=shard_offset, + shard_size=shard_size, + shard_id=shard_id, + num_heads=num_heads, + ) + + def process_weights_after_loading(self): + for key in self.partitions: + self.partitions[key] = torch.nn.Parameter( + data=self.partitions[key].data, requires_grad=False) + + @property + def data(self): + raise ValueError("Accessing `data` of a " + "`PartitionedModelWeightParameter` is not allowed. " + "Instead, use `get_partition` to get the weight of " + "the particular partition you want to access") + + def _fake_weight_loader(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_weight_shard_id: Optional[Union[str, int]]): + raise ValueError("When loading partition weights of " + f"{self.__class__.__name__}, use methods provided by " + f"{self.__class__.__name__}, not partition loader") + + def permute_param_layout_(param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs) -> BasevLLMParameter: """ @@ -456,4 +596,4 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, shard_offset=shard_offset, bitblas_tile_size=bitblas_tile_size) - return shard_size, shard_offset \ No newline at end of file + return shard_size, shard_offset