[Quantization] fix: overflow with static per-tensor scaling (#29867)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Mickaël Seznec
2026-01-13 13:56:01 +01:00
committed by GitHub
parent 8c8653b672
commit a5bbbd2f24
2 changed files with 71 additions and 56 deletions
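Background for the fix, as a minimal sketch (not part of the diff; `input_scale` is a made-up calibration value): the previous code recovered unquantized weights by pushing an identity matrix through `layer.quant_method.apply()`. With static per-tensor activation quantization, the 1.0 entries of that identity matrix are quantized with a scale calibrated for the real activation range, which can push them past the FP8 representable maximum (or underflow them to zero), so the reconstructed weights come out wrong.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

# Hypothetical static per-tensor activation scale, calibrated offline on
# activations much smaller than 1.0.
input_scale = 1e-3

# The identity-matrix trick feeds 1.0 entries through activation quantization:
x = torch.eye(4)
x_q = (x / input_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

# 1.0 / 1e-3 = 1000 > 448, so the diagonal saturates; "dequantizing" it back
# gives ~0.448 instead of 1.0, and the weights recovered through this matmul
# are scaled wrong. The fix below dequantizes FP8 weights directly from their
# weight scales instead of going through apply() in that case.
print(x_q.to(torch.float32) * input_scale)
```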


@@ -5,7 +5,7 @@
 from collections.abc import Callable, Mapping
 from dataclasses import dataclass
 from types import MappingProxyType
-from typing import ClassVar, NamedTuple
+from typing import TYPE_CHECKING, ClassVar, NamedTuple
 
 import numpy
 import torch
@@ -15,6 +15,9 @@ from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.linear import LinearBase
+
 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
@@ -239,7 +242,7 @@ def scaled_dequantize(
     x_s: torch.Tensor,
     group_shape: GroupShape | None = None,
     out_dtype: torch.dtype = torch.float32,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     if group_shape is not None:
         group_shape = _normalize_quant_group_shape(x_q, group_shape)
@@ -270,6 +273,62 @@ def scaled_dequantize(
     return (x_q.to(torch.float32) * x_s).to(out_dtype)
 
 
+def get_attribute_fallback(obj, attributes: list[str]):
+    for attr in attributes:
+        if hasattr(obj, attr):
+            return getattr(obj, attr)
+    raise AttributeError(f"'{obj}' has no recognized attributes: {attributes}.")
+
+
+def get_and_maybe_dequant_weights(
+    layer: "LinearBase", out_dtype: torch.dtype = torch.float32
+):
+    """Return layer's unquantized weights in [out, in] layout"""
+    from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+
+    weight = get_attribute_fallback(layer, ["weight", "qweight", "weight_packed"])
+
+    # Unquantized layer: just return base weights
+    if layer.quant_method is None or isinstance(
+        layer.quant_method, UnquantizedLinearMethod
+    ):
+        return weight.to(out_dtype)
+
+    # Simple Fp8 case: rescale with tensor or block weight scales
+    if (
+        isinstance(layer.quant_method, Fp8LinearMethod)
+        and not layer.quant_method.use_marlin
+    ):
+        weight_scales = get_attribute_fallback(
+            layer, ["weight_scale", "weight_scale_inv"]
+        )
+        dequant_weights = scaled_dequantize(
+            weight,
+            weight_scales,
+            group_shape=layer.weight_block_size,
+            out_dtype=out_dtype,
+        )
+        # per-tensor scaling stores weights in [in, out] layout
+        if not layer.quant_method.block_quant:
+            dequant_weights = dequant_weights.T
+        return dequant_weights
+
+    # NOTE: Most generic base case
+    # - Call the layer with identity matrix which returns unquantized weights.
+    # - Must be used with extra care when dealing with static activation
+    #   quantization: quantizing 1.0 may lead to over/underflows
+    # - Should only be used offline, since it's O(N^3)
+    assert hasattr(layer, "input_size_per_partition")
+    eye = torch.eye(
+        layer.input_size_per_partition,
+        dtype=out_dtype,
+        device=weight.device,
+    )
+    dequant_weights = layer.quant_method.apply(layer, eye, bias=None).to(out_dtype)
+    return dequant_weights.T
+
+
 def pack_quantized_values_into_int32(
     w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
 ):
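A usage sketch of the new helper in `quant_utils` (the wrapper `bf16_weights` below is hypothetical, not part of the diff):

```python
import torch

from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_and_maybe_dequant_weights,
)


def bf16_weights(layer: LinearBase) -> torch.Tensor:
    # Returns the layer's weights in [out, in] layout as bf16, dequantizing
    # through weight scales for plain FP8 layers and only falling back to the
    # identity-matrix trick for other quantization methods.
    w = get_and_maybe_dequant_weights(layer, out_dtype=torch.bfloat16)
    assert w.shape[1] == layer.input_size_per_partition
    return w
```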


@@ -207,8 +207,9 @@ from vllm.model_executor.layers.batch_invariant import (
 )
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
-    LinearBase,
-    UnquantizedLinearMethod,
 )
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    get_and_maybe_dequant_weights,
+)
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_nvidia_artifactory
@@ -1184,35 +1185,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
         self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        def get_layer_weight(layer):
-            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
-            for attr in WEIGHT_NAMES:
-                if hasattr(layer, attr):
-                    return getattr(layer, attr)
-            raise AttributeError(
-                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
-            )
-
-        def get_and_maybe_dequant_weights(layer: LinearBase):
-            if layer.quant_method is not None and not isinstance(
-                layer.quant_method, UnquantizedLinearMethod
-            ):
-                # NOTE: This should only be used offline, since it's O(N^3)
-                eye = torch.eye(
-                    layer.input_size_per_partition,
-                    dtype=act_dtype,
-                    device=get_layer_weight(layer).device,
-                )
-                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
-                del eye
-                # standardize to (output, input)
-                return dequant_weights.T
-            return layer.weight
-
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
-        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        kv_b_proj_weight = get_and_maybe_dequant_weights(
+            self.kv_b_proj, out_dtype=act_dtype
+        ).T
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1600,35 +1579,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return attn_out, lse.transpose(0, 1).contiguous()
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        def get_layer_weight(layer):
-            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
-            for attr in WEIGHT_NAMES:
-                if hasattr(layer, attr):
-                    return getattr(layer, attr)
-            raise AttributeError(
-                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
-            )
-
-        def get_and_maybe_dequant_weights(layer: LinearBase):
-            if layer.quant_method is not None and not isinstance(
-                layer.quant_method, UnquantizedLinearMethod
-            ):
-                # NOTE: This should only be used offline, since it's O(N^3)
-                eye = torch.eye(
-                    layer.input_size_per_partition,
-                    dtype=act_dtype,
-                    device=get_layer_weight(layer).device,
-                )
-                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
-                del eye
-                # standardize to (output, input)
-                return dequant_weights.T
-            return layer.weight
-
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
-        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        kv_b_proj_weight = get_and_maybe_dequant_weights(
+            self.kv_b_proj, out_dtype=act_dtype
+        ).T
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
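Downstream of this hunk (unchanged code, not shown in the diff), the dequantized `kv_b_proj_weight` feeds the 16-bit `W_UK`/`W_UV` copies mentioned in the comment. Purely as an illustration of the shape bookkeeping implied by the assert above, with made-up dimensions:

```python
import torch

# Hypothetical MLA dimensions, for illustration only.
kv_lora_rank, num_heads, qk_nope_head_dim, v_head_dim = 512, 16, 128, 128

# Matches the asserted shape: [kv_lora_rank, num_heads * (qk_nope + v_head)]
kv_b_proj_weight = torch.randn(
    kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), dtype=torch.bfloat16
)

# Split per head into the nope (W_UK) and value (W_UV) parts used by the bmm's.
w = kv_b_proj_weight.view(kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim)
W_UK, W_UV = w.split([qk_nope_head_dim, v_head_dim], dim=-1)
assert W_UK.shape == (kv_lora_rank, num_heads, qk_nope_head_dim)
assert W_UV.shape == (kv_lora_rank, num_heads, v_head_dim)
```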