Use w8a8 quantized matmul Pallas kernel (#19170)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Author: XiongfeiWei
Date: 2025-07-14 20:06:33 -07:00
Committed by: GitHub
Parent: 946aadb4a0
Commit: d4170fad39
4 changed files with 50 additions and 19 deletions

@@ -90,16 +90,15 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         w_q, w_s, _, _, _ = self._get_weight_params(layer)
-        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
-        out = torch.ops.xla.quantized_matmul(x,
-                                             w_q,
-                                             w_s,
-                                             zero_point=None,
-                                             block_size=-1,
-                                             int4_weight=False,
-                                             quantize_activation=True)
-        # `quantized_matmul` output is fp32, cast it down to bf16 for perf
-        out = out.to(x.dtype)
+        # Required to register custom ops.
+        import torch_xla.experimental.custom_kernel  # noqa: F401
+        out = torch.ops.xla.quantized_matmul_int8(
+            x,
+            w_q,
+            w_s,
+            quantize_activation=True,
+        )
         # Explicitly capture control flow to make dynamo happy.
         # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method  # noqa: E501
         return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
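
For context, the new path is w8a8: the weight is stored as int8 with a per-output-channel scale, and the kernel quantizes the activations to int8 on the fly (quantize_activation=True). Below is a minimal, hypothetical sketch of exercising the op outside vLLM; the weight layout ([out_features, in_features]), the 1-D scale shape, and the symmetric rounding scheme are assumptions for illustration, not part of this commit.

import torch
import torch_xla.core.xla_model as xm
# Registers the Pallas-backed custom ops (torch.ops.xla.*), as in the diff above.
import torch_xla.experimental.custom_kernel  # noqa: F401

device = xm.xla_device()

# Assumed shapes: activations [batch, in_features], weight [out_features, in_features].
x = torch.randn(8, 256, dtype=torch.bfloat16, device=device)
w = torch.randn(512, 256, dtype=torch.bfloat16, device=device)

# Symmetric per-output-channel int8 quantization of the weight (assumed scheme).
w_s = w.abs().amax(dim=1) / 127.0                                   # one scale per output channel
w_q = torch.clamp(torch.round(w / w_s[:, None]), -128, 127).to(torch.int8)

# w8a8 matmul: activations are quantized to int8 inside the kernel.
out = torch.ops.xla.quantized_matmul_int8(
    x,
    w_q,
    w_s,
    quantize_activation=True,
)
out = out.to(x.dtype)  # optional: cast back in case the op returns a wider dtype (the old path needed this)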