[Quantization] fix: overflow with static per-tensor scaling (#29867)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Mickaël Seznec
2026-01-13 13:56:01 +01:00
committed by GitHub
parent 8c8653b672
commit a5bbbd2f24
2 changed files with 71 additions and 56 deletions

View File

@@ -5,7 +5,7 @@
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from types import MappingProxyType
from typing import ClassVar, NamedTuple
from typing import TYPE_CHECKING, ClassVar, NamedTuple
import numpy
import torch
@@ -15,6 +15,9 @@ from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
if TYPE_CHECKING:
from vllm.model_executor.layers.linear import LinearBase
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
@@ -239,7 +242,7 @@ def scaled_dequantize(
x_s: torch.Tensor,
group_shape: GroupShape | None = None,
out_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
if group_shape is not None:
group_shape = _normalize_quant_group_shape(x_q, group_shape)
@@ -270,6 +273,62 @@ def scaled_dequantize(
return (x_q.to(torch.float32) * x_s).to(out_dtype)
def get_attribute_fallback(obj, attributes: list[str]):
for attr in attributes:
if hasattr(obj, attr):
return getattr(obj, attr)
raise AttributeError(f"'{obj}' has no recognized attributes: {attributes}.")
def get_and_maybe_dequant_weights(
layer: "LinearBase", out_dtype: torch.dtype = torch.float32
):
"""Return layer's unquantized weights in [out, in] layout"""
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
weight = get_attribute_fallback(layer, ["weight", "qweight", "weight_packed"])
# Unquantized layer: just return base weights
if layer.quant_method is None or isinstance(
layer.quant_method, UnquantizedLinearMethod
):
return weight.to(out_dtype)
# Simple Fp8 case: rescale with tensor or block weight scales
if (
isinstance(layer.quant_method, Fp8LinearMethod)
and not layer.quant_method.use_marlin
):
weight_scales = get_attribute_fallback(
layer, ["weight_scale", "weight_scale_inv"]
)
dequant_weights = scaled_dequantize(
weight,
weight_scales,
group_shape=layer.weight_block_size,
out_dtype=out_dtype,
)
# per-tensor scaling stores weights in [in, out] layout
if not layer.quant_method.block_quant:
dequant_weights = dequant_weights.T
return dequant_weights
# NOTE: Most generic base case
# - Call the layer with identity matrix which returns unquantized weights.
# - Must be used with extra care when dealing with static activation quantization:
# quantizing 1.0 may lead to over/underflows
# - Should only be used offline, since it's O(N^3)
assert hasattr(layer, "input_size_per_partition")
eye = torch.eye(
layer.input_size_per_partition,
dtype=out_dtype,
device=weight.device,
)
dequant_weights = layer.quant_method.apply(layer, eye, bias=None).to(out_dtype)
return dequant_weights.T
def pack_quantized_values_into_int32(
w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
):

View File

@@ -207,8 +207,9 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_and_maybe_dequant_weights,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_nvidia_artifactory
@@ -1184,35 +1185,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
if hasattr(layer, attr):
return getattr(layer, attr)
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
)
def get_and_maybe_dequant_weights(layer: LinearBase):
if layer.quant_method is not None and not isinstance(
layer.quant_method, UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(
layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device,
)
dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
kv_b_proj_weight = get_and_maybe_dequant_weights(
self.kv_b_proj, out_dtype=act_dtype
).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1600,35 +1579,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out, lse.transpose(0, 1).contiguous()
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
if hasattr(layer, attr):
return getattr(layer, attr)
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
)
def get_and_maybe_dequant_weights(layer: LinearBase):
if layer.quant_method is not None and not isinstance(
layer.quant_method, UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(
layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device,
)
dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
kv_b_proj_weight = get_and_maybe_dequant_weights(
self.kv_b_proj, out_dtype=act_dtype
).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),