[TPU][V1][Bugfix] Fix w8a8 recompilation with GSM8K (#15714)
Signed-off-by: NickLucche <nlucches@redhat.com>
@@ -97,7 +97,8 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):

```python
                                             block_size=-1,
                                             int4_weight=False,
                                             quantize_activation=True)

        # `quantized_matmul` output is fp32, cast it down to bf16 for perf
        out = out.to(x.dtype)
        # Explicitly capture control flow to make dynamo happy.
        # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
        return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
```
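For context, here is a minimal, runnable sketch of the `torch.cond` pattern this hunk relies on. The standalone `add_bias`/`no_add_bias` helpers and the sample tensor shapes below are illustrative stand-ins for the kernel's bound methods, not vLLM's actual code; the point is that `torch.cond` records both branches as graph control flow, so under `torch.compile`/export the bias vs. no-bias choice does not force a retrace the way a plain Python `if` would.

```python
import torch


def add_bias(out: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    return out + bias


def no_add_bias(out: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Both branches must accept the same operands and return tensors of
    # matching shape/dtype; clone() avoids returning an alias of an
    # input, which traced cond disallows.
    return out.clone()


out = torch.randn(4, 8)
bias = torch.randn(8)

# Mirrors the diff's
#   cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
# call: (predicate, true-branch fn, false-branch fn, operands).
y = torch.cond(bias is None, no_add_bias, add_bias, (out, bias))
assert torch.allclose(y, out + bias)
```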