diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index c27ce3494..48beb977c 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -247,8 +247,8 @@ def scaled_dequantize( if group_shape is not None: group_shape = _normalize_quant_group_shape(x_q, group_shape) - if x_s.ndim == 0: # scalar - x_s = x_s.unsqueeze(-1).unsqueeze(-1) # convert to (1, 1) tensor + if x_s.numel() == 1: # scalar + x_s = x_s.reshape(1, 1) # normalize all scalar-like tensors to (1, 1) if x_s.ndim == 1: if group_shape is None: raise AssertionError(