From a5bbbd2f24983f309f915a1de911b9301e1699e2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micka=C3=ABl=20Seznec?=
Date: Tue, 13 Jan 2026 13:56:01 +0100
Subject: [PATCH] [Quantization] fix: overflow with static per-tensor scaling
 (#29867)

Signed-off-by: Mickael Seznec
Signed-off-by: Michael Goin
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Michael Goin
---
 .../layers/quantization/utils/quant_utils.py | 63 +++++++++++++++++-
 vllm/v1/attention/backends/mla/common.py     | 64 +++----------------
 2 files changed, 71 insertions(+), 56 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index a6c4c702e..08262ed1a 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -5,7 +5,7 @@
 from collections.abc import Callable, Mapping
 from dataclasses import dataclass
 from types import MappingProxyType
-from typing import ClassVar, NamedTuple
+from typing import TYPE_CHECKING, ClassVar, NamedTuple
 
 import numpy
 import torch
@@ -15,6 +15,9 @@ from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.linear import LinearBase
+
 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
 
@@ -239,7 +242,7 @@ def scaled_dequantize(
     x_s: torch.Tensor,
     group_shape: GroupShape | None = None,
     out_dtype: torch.dtype = torch.float32,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     if group_shape is not None:
         group_shape = _normalize_quant_group_shape(x_q, group_shape)
 
@@ -270,6 +273,62 @@ def scaled_dequantize(
     return (x_q.to(torch.float32) * x_s).to(out_dtype)
 
 
+def get_attribute_fallback(obj, attributes: list[str]):
+    for attr in attributes:
+        if hasattr(obj, attr):
+            return getattr(obj, attr)
+    raise AttributeError(f"'{obj}' has no recognized attributes: {attributes}.")
+
+
+def get_and_maybe_dequant_weights(
+    layer: "LinearBase", out_dtype: torch.dtype = torch.float32
+):
+    """Return layer's unquantized weights in [out, in] layout"""
+    from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+
+    weight = get_attribute_fallback(layer, ["weight", "qweight", "weight_packed"])
+
+    # Unquantized layer: just return base weights
+    if layer.quant_method is None or isinstance(
+        layer.quant_method, UnquantizedLinearMethod
+    ):
+        return weight.to(out_dtype)
+
+    # Simple Fp8 case: rescale with tensor or block weight scales
+    if (
+        isinstance(layer.quant_method, Fp8LinearMethod)
+        and not layer.quant_method.use_marlin
+    ):
+        weight_scales = get_attribute_fallback(
+            layer, ["weight_scale", "weight_scale_inv"]
+        )
+        dequant_weights = scaled_dequantize(
+            weight,
+            weight_scales,
+            group_shape=layer.weight_block_size,
+            out_dtype=out_dtype,
+        )
+        # per-tensor scaling stores weights in [in, out] layout
+        if not layer.quant_method.block_quant:
+            dequant_weights = dequant_weights.T
+        return dequant_weights
+
+    # NOTE: Most generic base case
+    # - Call the layer with identity matrix which returns unquantized weights.
+    # - Must be used with extra care when dealing with static activation quantization:
+    #   quantizing 1.0 may lead to over/underflows
+    # - Should only be used offline, since it's O(N^3)
+    assert hasattr(layer, "input_size_per_partition")
+    eye = torch.eye(
+        layer.input_size_per_partition,
+        dtype=out_dtype,
+        device=weight.device,
+    )
+    dequant_weights = layer.quant_method.apply(layer, eye, bias=None).to(out_dtype)
+    return dequant_weights.T
+
+
 def pack_quantized_values_into_int32(
     w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
 ):
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 5cd5cc566..aa2740cac 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -207,8 +207,9 @@ from vllm.model_executor.layers.batch_invariant import (
 )
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
-    LinearBase,
-    UnquantizedLinearMethod,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    get_and_maybe_dequant_weights,
 )
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_nvidia_artifactory
@@ -1184,35 +1185,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
         self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        def get_layer_weight(layer):
-            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
-            for attr in WEIGHT_NAMES:
-                if hasattr(layer, attr):
-                    return getattr(layer, attr)
-            raise AttributeError(
-                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
-            )
-
-        def get_and_maybe_dequant_weights(layer: LinearBase):
-            if layer.quant_method is not None and not isinstance(
-                layer.quant_method, UnquantizedLinearMethod
-            ):
-                # NOTE: This should only be used offline, since it's O(N^3)
-                eye = torch.eye(
-                    layer.input_size_per_partition,
-                    dtype=act_dtype,
-                    device=get_layer_weight(layer).device,
-                )
-                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
-                del eye
-                # standardize to (output, input)
-                return dequant_weights.T
-            return layer.weight
-
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
-        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        kv_b_proj_weight = get_and_maybe_dequant_weights(
+            self.kv_b_proj, out_dtype=act_dtype
+        ).T
+
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
         )
@@ -1600,35 +1579,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return attn_out, lse.transpose(0, 1).contiguous()
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        def get_layer_weight(layer):
-            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
-            for attr in WEIGHT_NAMES:
-                if hasattr(layer, attr):
-                    return getattr(layer, attr)
-            raise AttributeError(
-                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
-            )
-
-        def get_and_maybe_dequant_weights(layer: LinearBase):
-            if layer.quant_method is not None and not isinstance(
-                layer.quant_method, UnquantizedLinearMethod
-            ):
-                # NOTE: This should only be used offline, since it's O(N^3)
-                eye = torch.eye(
-                    layer.input_size_per_partition,
-                    dtype=act_dtype,
-                    device=get_layer_weight(layer).device,
-                )
-                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
-                del eye
-                # standardize to (output, input)
-                return dequant_weights.T
-            return layer.weight
-
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
-        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        kv_b_proj_weight = get_and_maybe_dequant_weights(
+            self.kv_b_proj, out_dtype=act_dtype
+        ).T
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),