[Quantization] fix: overflow with static per-tensor scaling (#29867)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
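The root cause is documented in the NOTE added to `get_and_maybe_dequant_weights` below: the generic dequantization path multiplies the layer by an identity matrix, and with static per-tensor activation quantization the 1.0 entries get quantized using a scale calibrated for real activations, which can saturate FP8. A minimal sketch of that failure mode (the scale value is hypothetical, not taken from vLLM):

import torch

# Hypothetical static activation scale, calibrated offline for small activations.
act_scale = 0.002
fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0 for e4m3
ones = torch.tensor([1.0])                       # identity-matrix entries
q = (ones / act_scale).clamp(-fp8_max, fp8_max)  # 500.0 saturates to 448.0
print(q * act_scale)                             # tensor([0.8960]), not 1.0

The fix below routes FP8 layers through an explicit weight-dequantization path, so the identity-matrix trick is only used as a last resort.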
@@ -5,7 +5,7 @@
 from collections.abc import Callable, Mapping
 from dataclasses import dataclass
 from types import MappingProxyType
-from typing import ClassVar, NamedTuple
+from typing import TYPE_CHECKING, ClassVar, NamedTuple
 
 import numpy
 import torch
@@ -15,6 +15,9 @@ from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.linear import LinearBase
+
 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
 
@@ -239,7 +242,7 @@ def scaled_dequantize(
     x_s: torch.Tensor,
     group_shape: GroupShape | None = None,
     out_dtype: torch.dtype = torch.float32,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     if group_shape is not None:
         group_shape = _normalize_quant_group_shape(x_q, group_shape)
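For context, the annotation change above corrects the declared return type: `scaled_dequantize` returns a single tensor (see the `return` line in the next hunk), not a tuple. A tiny illustrative call of the per-tensor path, with made-up values:

import torch

x_q = torch.tensor([[64, -32]], dtype=torch.int8)     # stand-in quantized values
x_s = torch.tensor(0.5)                               # per-tensor scale
x = (x_q.to(torch.float32) * x_s).to(torch.bfloat16)  # what the per-tensor path computes
print(x)  # tensor([[ 32., -16.]], dtype=torch.bfloat16) -- one tensor, not a tuple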
@@ -270,6 +273,62 @@ def scaled_dequantize(
     return (x_q.to(torch.float32) * x_s).to(out_dtype)
 
 
+def get_attribute_fallback(obj, attributes: list[str]):
+    for attr in attributes:
+        if hasattr(obj, attr):
+            return getattr(obj, attr)
+    raise AttributeError(f"'{obj}' has no recognized attributes: {attributes}.")
+
+
+def get_and_maybe_dequant_weights(
+    layer: "LinearBase", out_dtype: torch.dtype = torch.float32
+):
+    """Return layer's unquantized weights in [out, in] layout"""
+    from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+
+    weight = get_attribute_fallback(layer, ["weight", "qweight", "weight_packed"])
+
+    # Unquantized layer: just return base weights
+    if layer.quant_method is None or isinstance(
+        layer.quant_method, UnquantizedLinearMethod
+    ):
+        return weight.to(out_dtype)
+
+    # Simple Fp8 case: rescale with tensor or block weight scales
+    if (
+        isinstance(layer.quant_method, Fp8LinearMethod)
+        and not layer.quant_method.use_marlin
+    ):
+        weight_scales = get_attribute_fallback(
+            layer, ["weight_scale", "weight_scale_inv"]
+        )
+        dequant_weights = scaled_dequantize(
+            weight,
+            weight_scales,
+            group_shape=layer.weight_block_size,
+            out_dtype=out_dtype,
+        )
+        # per-tensor scaling stores weights in [in, out] layout
+        if not layer.quant_method.block_quant:
+            dequant_weights = dequant_weights.T
+        return dequant_weights
+
+    # NOTE: Most generic base case
+    # - Call the layer with identity matrix which returns unquantized weights.
+    # - Must be used with extra care when dealing with static activation quantization:
+    #   quantizing 1.0 may lead to over/underflows
+    # - Should only be used offline, since it's O(N^3)
+    assert hasattr(layer, "input_size_per_partition")
+    eye = torch.eye(
+        layer.input_size_per_partition,
+        dtype=out_dtype,
+        device=weight.device,
+    )
+    dequant_weights = layer.quant_method.apply(layer, eye, bias=None).to(out_dtype)
+    return dequant_weights.T
+
+
 def pack_quantized_values_into_int32(
     w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
 ):
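The Fp8 branch above relies on a layout quirk that is easy to miss: as its comment notes, per-tensor scaling stores the FP8 weight transposed, in [in, out] layout, so the helper transposes after dequantizing. A self-contained sketch of that round trip (our illustration, not vLLM code):

import torch

out_features, in_features = 4, 8
w_ref = torch.randn(out_features, in_features)
scale = w_ref.abs().max() / torch.finfo(torch.float8_e4m3fn).max
w_q = (w_ref / scale).to(torch.float8_e4m3fn).T   # stored as [in, out]
w_dq = (w_q.to(torch.float32) * scale).T          # rescale, transpose to [out, in]
print((w_ref - w_dq).abs().max())                 # small quantization error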
@@ -207,8 +207,9 @@ from vllm.model_executor.layers.batch_invariant import (
 )
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
-    LinearBase,
-    UnquantizedLinearMethod,
 )
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    get_and_maybe_dequant_weights,
+)
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_nvidia_artifactory
@@ -1184,35 +1185,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
         self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        def get_layer_weight(layer):
-            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
-            for attr in WEIGHT_NAMES:
-                if hasattr(layer, attr):
-                    return getattr(layer, attr)
-            raise AttributeError(
-                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
-            )
-
-        def get_and_maybe_dequant_weights(layer: LinearBase):
-            if layer.quant_method is not None and not isinstance(
-                layer.quant_method, UnquantizedLinearMethod
-            ):
-                # NOTE: This should only be used offline, since it's O(N^3)
-                eye = torch.eye(
-                    layer.input_size_per_partition,
-                    dtype=act_dtype,
-                    device=get_layer_weight(layer).device,
-                )
-                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
-                del eye
-                # standardize to (output, input)
-                return dequant_weights.T
-            return layer.weight
-
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
-        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        kv_b_proj_weight = get_and_maybe_dequant_weights(
+            self.kv_b_proj, out_dtype=act_dtype
+        ).T
 
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
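The comment retained above is the rationale for dequantizing at all: the MLA path needs two batched matmuls over slices of the `kv_b_proj` weight (`W_UV` and `W_UK_T`), and no quantized bmm kernels exist, so a 16-bit copy is kept. A rough, self-contained sketch of the shape manipulation the surrounding code performs (dimensions and variable names are illustrative, not vLLM's exact code):

import torch

kv_lora_rank, num_heads, qk_nope_head_dim, v_head_dim = 512, 16, 128, 128
# dequantized kv_b_proj weight, already transposed to [kv_lora_rank, out]
w = torch.randn(kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim),
                dtype=torch.bfloat16)
w = w.view(kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim)
# split each head's columns into the K-nope and V up-projections
W_UK, W_UV = w.split([qk_nope_head_dim, v_head_dim], dim=-1)
# the bmm runs in 16-bit since quantized bmm kernels are unavailable
q_nope = torch.randn(num_heads, 3, qk_nope_head_dim, dtype=torch.bfloat16)
ql_nope = torch.bmm(q_nope, W_UK.permute(1, 2, 0))
print(ql_nope.shape)  # torch.Size([16, 3, 512])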
@@ -1600,35 +1579,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         return attn_out, lse.transpose(0, 1).contiguous()
 
     def process_weights_after_loading(self, act_dtype: torch.dtype):
-        def get_layer_weight(layer):
-            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
-            for attr in WEIGHT_NAMES:
-                if hasattr(layer, attr):
-                    return getattr(layer, attr)
-            raise AttributeError(
-                f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
-            )
-
-        def get_and_maybe_dequant_weights(layer: LinearBase):
-            if layer.quant_method is not None and not isinstance(
-                layer.quant_method, UnquantizedLinearMethod
-            ):
-                # NOTE: This should only be used offline, since it's O(N^3)
-                eye = torch.eye(
-                    layer.input_size_per_partition,
-                    dtype=act_dtype,
-                    device=get_layer_weight(layer).device,
-                )
-                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
-                del eye
-                # standardize to (output, input)
-                return dequant_weights.T
-            return layer.weight
-
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
-        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        kv_b_proj_weight = get_and_maybe_dequant_weights(
+            self.kv_b_proj, out_dtype=act_dtype
+        ).T
         assert kv_b_proj_weight.shape == (
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),