Use w8a8 quantized matmul Pallas kernel (#19170)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
@@ -90,16 +90,15 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         w_q, w_s, _, _, _ = self._get_weight_params(layer)
 
-        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
-        out = torch.ops.xla.quantized_matmul(x,
-                                             w_q,
-                                             w_s,
-                                             zero_point=None,
-                                             block_size=-1,
-                                             int4_weight=False,
-                                             quantize_activation=True)
-        # `quantized_matmul` output is fp32, cast it down to bf16 for perf
-        out = out.to(x.dtype)
+        # Required to register custom ops.
+        import torch_xla.experimental.custom_kernel  # noqa: F401
+        out = torch.ops.xla.quantized_matmul_int8(
+            x,
+            w_q,
+            w_s,
+            quantize_activation=True,
+        )
 
         # Explicitly capture control flow to make dynamo happy.
         # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
         return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
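As context for the change above: a minimal, hedged PyTorch sketch of the w8a8 arithmetic that the int8 quantized matmul path is expected to perform. The helper name `w8a8_matmul_reference`, the symmetric per-output-channel weight scales, and the dynamic per-tensor activation quantization are illustrative assumptions; this is a plain-PyTorch reference, not the Pallas kernel implementation.

# Hedged reference sketch, NOT the Pallas kernel. Assumes symmetric
# per-output-channel int8 weights `w_q` with float scales `w_s`, and dynamic
# per-tensor int8 activation quantization (mirroring quantize_activation=True
# in the diff above).
import torch


def w8a8_matmul_reference(x: torch.Tensor, w_q: torch.Tensor,
                          w_s: torch.Tensor) -> torch.Tensor:
    # Dynamically quantize activations to int8 (symmetric, per-tensor).
    x_scale = x.abs().amax() / 127.0
    x_q = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
    # Integer matmul accumulated in int32, then rescaled back to float.
    # w_q: [out_features, in_features], w_s: [out_features]
    acc = torch.matmul(x_q.to(torch.int32), w_q.to(torch.int32).t())
    return acc.to(torch.float32) * x_scale * w_s


# Toy usage with shapes matching the kernel's call site (x @ w_q.T * w_s).
w = torch.randn(8, 16)
w_s = w.abs().amax(dim=1) / 127.0
w_q = torch.clamp(torch.round(w / w_s[:, None]), -128, 127).to(torch.int8)
x = torch.randn(4, 16)
out = w8a8_matmul_reference(x, w_q, w_s)  # float32 output, shape [4, 8]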