[Quantization] fix: overflow with static per-tensor scaling (#29867)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: Mickaël Seznec
Date: 2026-01-13 13:56:01 +01:00
Committed by: GitHub
Parent: 8c8653b672
Commit: a5bbbd2f24
2 changed files with 71 additions and 56 deletions
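
Context for reviewers: `get_and_maybe_dequant_weights` recovers a layer's full-precision weights by running an identity matrix through the layer. With static per-tensor activation scaling, the identity's 1.0 entries are quantized with a fixed scale, which can saturate the FP8 range and corrupt the recovered weights; this change adds a direct dequantization path for Fp8 layers so the identity-matrix trick becomes a last resort. A minimal sketch of the failure mode (illustrative only; the scale value and the explicit clamp are assumptions, not code from this commit):

```python
# Why quant_method.apply(layer, eye) can corrupt weights under static
# per-tensor activation quantization: float8_e4m3 tops out at 448, so a
# saturating quantizer clamps the scaled identity entries there.
import torch

FP8_MAX = 448.0
scale = torch.tensor(1e-3)          # assumed static activation scale
eye = torch.eye(4)
q = (eye / scale).clamp(max=FP8_MAX).to(torch.float8_e4m3fn)
print(q.to(torch.float32) * scale)  # diagonal comes back as 0.448, not 1.0
```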

@@ -5,7 +5,7 @@
 from collections.abc import Callable, Mapping
 from dataclasses import dataclass
 from types import MappingProxyType
-from typing import ClassVar, NamedTuple
+from typing import TYPE_CHECKING, ClassVar, NamedTuple
 
 import numpy
 import torch
@@ -15,6 +15,9 @@ from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.linear import LinearBase
+
 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
 
@@ -239,7 +242,7 @@ def scaled_dequantize(
     x_s: torch.Tensor,
     group_shape: GroupShape | None = None,
     out_dtype: torch.dtype = torch.float32,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     if group_shape is not None:
         group_shape = _normalize_quant_group_shape(x_q, group_shape)
@@ -270,6 +273,62 @@ def scaled_dequantize(
     return (x_q.to(torch.float32) * x_s).to(out_dtype)
 
 
+def get_attribute_fallback(obj, attributes: list[str]):
+    for attr in attributes:
+        if hasattr(obj, attr):
+            return getattr(obj, attr)
+    raise AttributeError(f"'{obj}' has no recognized attributes: {attributes}.")
+
+
+def get_and_maybe_dequant_weights(
+    layer: "LinearBase", out_dtype: torch.dtype = torch.float32
+):
+    """Return layer's unquantized weights in [out, in] layout"""
+    from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+
+    weight = get_attribute_fallback(layer, ["weight", "qweight", "weight_packed"])
+
+    # Unquantized layer: just return base weights
+    if layer.quant_method is None or isinstance(
+        layer.quant_method, UnquantizedLinearMethod
+    ):
+        return weight.to(out_dtype)
+
+    # Simple Fp8 case: rescale with tensor or block weight scales
+    if (
+        isinstance(layer.quant_method, Fp8LinearMethod)
+        and not layer.quant_method.use_marlin
+    ):
+        weight_scales = get_attribute_fallback(
+            layer, ["weight_scale", "weight_scale_inv"]
+        )
+        dequant_weights = scaled_dequantize(
+            weight,
+            weight_scales,
+            group_shape=layer.weight_block_size,
+            out_dtype=out_dtype,
+        )
+        # per-tensor scaling stores weights in [in, out] layout
+        if not layer.quant_method.block_quant:
+            dequant_weights = dequant_weights.T
+        return dequant_weights
+
+    # NOTE: Most generic base case
+    # - Call the layer with identity matrix which returns unquantized weights.
+    # - Must be used with extra care when dealing with static activation
+    #   quantization: quantizing 1.0 may lead to over/underflows
+    # - Should only be used offline, since it's O(N^3)
+    assert hasattr(layer, "input_size_per_partition")
+    eye = torch.eye(
+        layer.input_size_per_partition,
+        dtype=out_dtype,
+        device=weight.device,
+    )
+    dequant_weights = layer.quant_method.apply(layer, eye, bias=None).to(out_dtype)
+    return dequant_weights.T
+
+
 def pack_quantized_values_into_int32(
     w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
 ):
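
For reference, a self-contained sketch of the identity-matrix fallback the NOTE above describes (a plain `torch.nn.Linear` stands in for a vLLM layer; this is an illustration, not code from the commit): since a linear layer computes `y = x @ W.T`, feeding it the identity returns `W.T`, and one transpose recovers the `[out, in]` weights, at O(N^3) cost.

```python
import torch

lin = torch.nn.Linear(8, 16, bias=False)  # computes y = x @ W.T; W is [out, in]
eye = torch.eye(lin.in_features)
recovered = lin(eye).T                    # (eye @ W.T).T == W
assert torch.allclose(recovered, lin.weight)
```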