[Quantization] fix: overflow with static per-tensor scaling (#29867)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Mickaël Seznec
2026-01-13 13:56:01 +01:00
committed by GitHub
parent 8c8653b672
commit a5bbbd2f24
2 changed files with 71 additions and 56 deletions

View File

@@ -207,8 +207,9 @@ from vllm.model_executor.layers.batch_invariant import (
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_and_maybe_dequant_weights,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_nvidia_artifactory
@@ -1184,35 +1185,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
    """Return the first weight-like attribute found on ``layer``.

    The attribute names in ``WEIGHT_NAMES`` are probed in order and the
    value of the first one present on ``layer`` is returned.

    Raises:
        AttributeError: if ``layer`` exposes none of the probed names.
    """
    WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
    # Pick the first candidate name that exists on the layer, if any.
    present = next((name for name in WEIGHT_NAMES if hasattr(layer, name)), None)
    if present is None:
        raise AttributeError(
            f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
        )
    return getattr(layer, present)
def get_and_maybe_dequant_weights(layer: LinearBase):
    """Return ``layer``'s weights, dequantizing through its quant method if needed.

    When ``layer.quant_method`` is ``None`` or an ``UnquantizedLinearMethod``,
    ``layer.weight`` is returned untouched. Otherwise an identity matrix of
    size ``layer.input_size_per_partition`` (dtype ``act_dtype``, taken from
    the enclosing scope) is pushed through ``quant_method.apply`` and the
    transposed result is returned.
    """
    method = layer.quant_method
    if method is None or isinstance(method, UnquantizedLinearMethod):
        return layer.weight

    # NOTE: This should only be used offline, since it's O(N^3)
    identity = torch.eye(
        layer.input_size_per_partition,
        dtype=act_dtype,
        device=get_layer_weight(layer).device,
    )
    dequantized = method.apply(layer, identity, bias=None)
    # Drop the identity matrix as soon as it has served its purpose.
    del identity
    # standardize to (output, input)
    return dequantized.T
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
kv_b_proj_weight = get_and_maybe_dequant_weights(
self.kv_b_proj, out_dtype=act_dtype
).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1600,35 +1579,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
return attn_out, lse.transpose(0, 1).contiguous()
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
    """Return the first weight-like attribute found on ``layer``.

    The attribute names in ``WEIGHT_NAMES`` are probed in order and the
    value of the first one present on ``layer`` is returned.

    Raises:
        AttributeError: if ``layer`` exposes none of the probed names.
    """
    WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
    # Pick the first candidate name that exists on the layer, if any.
    present = next((name for name in WEIGHT_NAMES if hasattr(layer, name)), None)
    if present is None:
        raise AttributeError(
            f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
        )
    return getattr(layer, present)
def get_and_maybe_dequant_weights(layer: LinearBase):
    """Return ``layer``'s weights, dequantizing through its quant method if needed.

    When ``layer.quant_method`` is ``None`` or an ``UnquantizedLinearMethod``,
    ``layer.weight`` is returned untouched. Otherwise an identity matrix of
    size ``layer.input_size_per_partition`` (dtype ``act_dtype``, taken from
    the enclosing scope) is pushed through ``quant_method.apply`` and the
    transposed result is returned.
    """
    method = layer.quant_method
    if method is None or isinstance(method, UnquantizedLinearMethod):
        return layer.weight

    # NOTE: This should only be used offline, since it's O(N^3)
    identity = torch.eye(
        layer.input_size_per_partition,
        dtype=act_dtype,
        device=get_layer_weight(layer).device,
    )
    dequantized = method.apply(layer, identity, bias=None)
    # Drop the identity matrix as soon as it has served its purpose.
    del identity
    # standardize to (output, input)
    return dequantized.T
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
kv_b_proj_weight = get_and_maybe_dequant_weights(
self.kv_b_proj, out_dtype=act_dtype
).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),