[TPU] Add support for online w8a8 quantization (#22425)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import ModelWeightParameter
 
-ACTIVATION_SCHEMES = ["none"]
+ACTIVATION_SCHEMES = ["none", "dynamic"]
 
 
 class Int8TpuConfig(QuantizationConfig):
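Note: ACTIVATION_SCHEMES is the list the config validates the requested scheme against, with "dynamic" newly allowed. A minimal sketch of that validation, assuming the usual pattern (the actual check inside Int8TpuConfig may differ):

ACTIVATION_SCHEMES = ["none", "dynamic"]

def _validate_activation_scheme(activation_scheme: str) -> str:
    # Reject anything outside the supported schemes; "dynamic" is newly accepted.
    if activation_scheme not in ACTIVATION_SCHEMES:
        raise ValueError(
            f"Unsupported activation scheme {activation_scheme!r}, "
            f"expected one of {ACTIVATION_SCHEMES}")
    return activation_scheme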
@@ -61,6 +61,9 @@ class TPUInt8LinearMethod(LinearMethodBase):
 
     def __init__(self, quant_config: Int8TpuConfig):
         self.quant_config = quant_config
+        self.quantize_activation = False
+        if self.quant_config.activation_scheme == 'dynamic':
+            self.quantize_activation = True
 
     def create_weights(self, layer: Module, input_size_per_partition: int,
                        output_partition_sizes: list[int], input_size: int,
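Note: a minimal sketch of how the new flag is derived, assuming Int8TpuConfig accepts an activation_scheme keyword and lives at the module path shown (both assumptions for illustration):

from vllm.model_executor.layers.quantization.tpu_int8 import (
    Int8TpuConfig, TPUInt8LinearMethod)

# Assumed constructor kwarg; only "none" and "dynamic" are valid schemes.
quant_config = Int8TpuConfig(activation_scheme="dynamic")
method = TPUInt8LinearMethod(quant_config)
assert method.quantize_activation  # stays False for activation_scheme="none"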
@@ -107,7 +110,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         try:
-            import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
+            import torch_xla.experimental.custom_kernel  # noqa: F401
         except ImportError as err:
             raise ImportError(
                 "Please install torch_xla by following the instructions at "
@@ -115,7 +118,8 @@ class TPUInt8LinearMethod(LinearMethodBase):
                 "to run vLLM on TPU.") from err
         weight = layer.weight
         scale = layer.scale
-        out = torch.ops.xla.quantized_matmul(x, weight, scale)
+        out = torch.ops.xla.quantized_matmul_int8(
+            x, weight, scale, quantize_activation=self.quantize_activation)
         if bias is not None:
             out = out + bias
         return out
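Note: with quantize_activation=True, quantized_matmul_int8 applies dynamic (per-token) int8 quantization to the activations on top of the existing int8 weights, which is what makes this online w8a8. A rough reference of those semantics, purely for illustration; this is not the XLA kernel, and the kernel's rounding and accumulation details may differ:

import torch

def w8a8_dynamic_reference(x: torch.Tensor, weight: torch.Tensor,
                           scale: torch.Tensor) -> torch.Tensor:
    # x: [tokens, in_features] float; weight: [out_features, in_features] int8;
    # scale: [out_features] per-output-channel weight scale.
    # Dynamic activation quantization: one scale per token (row of x).
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / x_scale), -128, 127)
    # Float accumulation for portability; a real kernel accumulates in int32.
    acc = x_q @ weight.to(x.dtype).t()
    return acc * x_scale * scale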