diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py
index 32192225f..020098dff 100644
--- a/vllm/model_executor/layers/quantization/utils/int8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py
@@ -122,15 +122,17 @@ def _per_token_quant_int8(
 
 
 def per_token_quant_int8(x):
+    original_shape = x.shape
+    if x.dim() > 2:
+        x = x.view(-1, original_shape[-1])
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
-    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
-    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    x_q = torch.empty((M, N), device=x.device, dtype=torch.int8)
+    scales = torch.empty((M, 1), device=x.device, dtype=torch.float32)
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
-
-    assert x.is_contiguous()
+    x = x.contiguous()
     _per_token_quant_int8[(M,)](
         x,
         x_q,
@@ -142,7 +144,8 @@ def per_token_quant_int8(x):
         num_warps=num_warps,
         num_stages=1,
     )
-
+    x_q = x_q.view(*original_shape)
+    scales = scales.view(*original_shape[:-1], 1)
     return x_q, scales
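
For reviewers: the flatten-then-restore flow in this patch can be summarized in plain PyTorch. The sketch below is illustrative only, not the Triton path from the diff (`_per_token_quant_int8` does the actual row-wise quantization on device). The `per_token_quant_int8_reference` name, the `max|row| / 127` scale formula, and the eps clamp are assumptions of this sketch, chosen to show how the new shape handling behaves for inputs with more than two dimensions.

```python
# Hypothetical pure-PyTorch reference for the patched per_token_quant_int8.
# The real function launches the _per_token_quant_int8 Triton kernel; symmetric
# per-row int8 quantization (scale = max|row| / 127) is assumed here.
import torch


def per_token_quant_int8_reference(x: torch.Tensor):
    original_shape = x.shape
    if x.dim() > 2:
        # Same move as the diff: collapse leading dims into M rows of width N.
        x = x.view(-1, original_shape[-1])
    # One float32 scale per row; the clamp guards all-zero rows in this sketch.
    scales = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / 127.0
    scales = scales.clamp(min=1e-10)
    x_q = torch.round(x / scales).to(torch.int8)
    # Restore the caller's shape; scales stay broadcastable over the last dim.
    return x_q.view(*original_shape), scales.view(*original_shape[:-1], 1)


# Example: a (batch, seq_len, hidden) activation tensor now round-trips cleanly.
x = torch.randn(2, 4, 8)
x_q, scales = per_token_quant_int8_reference(x)
assert x_q.shape == (2, 4, 8)
assert scales.shape == (2, 4, 1)
```

One note on the patch itself: `Tensor.view` requires storage that can be reinterpreted without copying, so a non-contiguous input with more than two dims would still raise before reaching the new `x = x.contiguous()` line; `.reshape` would avoid that at the cost of a possible copy.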